From 0c2a5116710a5728ca2aa31979e3086a032e9523 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Sat, 21 Jan 2023 18:07:36 +0100 Subject: [PATCH 01/25] wip SwapCrossAttnProcessor --- .../swap_cross_attention_processor.py | 160 ++++++++++++++++++ .../diffusion/cross_attention_control.py | 5 +- 2 files changed, 163 insertions(+), 2 deletions(-) create mode 100644 ldm/invoke/generator/swap_cross_attention_processor.py diff --git a/ldm/invoke/generator/swap_cross_attention_processor.py b/ldm/invoke/generator/swap_cross_attention_processor.py new file mode 100644 index 0000000000..e3aa5bc484 --- /dev/null +++ b/ldm/invoke/generator/swap_cross_attention_processor.py @@ -0,0 +1,160 @@ + +""" +# base implementation + +class CrossAttnProcessor: + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + +""" +import enum +from dataclasses import field, dataclass + +import torch + +from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor + +class AttentionType(enum.Enum): + SELF = 1 + TOKENS = 2 + +@dataclass +class SwapCrossAttnContext: + + cross_attention_types_to_do: list[AttentionType] + modified_text_embeddings: torch.Tensor + index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt + mask: torch.Tensor # in the target space of the index_map + + def __int__(self, + cac_types_to_do: [AttentionType], + modified_text_embeddings: torch.Tensor, + index_map: torch.Tensor, + mask: torch.Tensor): + self.cross_attention_types_to_do = cac_types_to_do + self.modified_text_embeddings = modified_text_embeddings + self.index_map = index_map + self.mask = mask + + def wants_cross_attention_control(self, attn_type: AttentionType) -> bool: + return attn_type in self.cross_attention_types_to_do + + +class SwapCrossAttnProcessor(CrossAttnProcessor): + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, + # kwargs + cross_attention_swap_context_provider: SwapCrossAttnContext=None): + + if cross_attention_swap_context_provider is None: + raise RuntimeError("a SwapCrossAttnContext instance must be passed via attention processor kwargs") + + attention_type = AttentionType.SELF if encoder_hidden_states is None else AttentionType.TOKENS + # if cross-attention control is not in play, just call through to the base implementation. 
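+        # (self-attention calls arrive with encoder_hidden_states=None; they only take the swap path
+        # below if AttentionType.SELF was explicitly requested in cross_attention_types_to_do.)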
+ if not cross_attention_swap_context_provider.wants_cross_attention_control(attention_type): + return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask) + + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + + # helper function + def get_attention_probs(embeddings): + this_key = attn.to_k(embeddings) + this_key = attn.head_to_batch_dim(this_key) + return attn.get_attention_scores(query, this_key, attention_mask) + + if attention_type == AttentionType.SELF: + # self attention has no remapping, it just bluntly copies the whole tensor + attention_probs = get_attention_probs(hidden_states) + value = attn.to_v(hidden_states) + else: + # tokens (cross) attention + # first, find attention probabilities for the "original" prompt + original_text_embeddings = encoder_hidden_states + original_attention_probs = get_attention_probs(original_text_embeddings) + + # then, find attention probabilities for the "modified" prompt + modified_text_embeddings = cross_attention_swap_context_provider.modified_text_embeddings + modified_attention_probs = get_attention_probs(modified_text_embeddings) + + # because the prompt modifications may result in token sequences shifted forwards or backwards, + # the original attention probabilities must be remapped to account for token index changes in the + # modified prompt + remapped_original_attention_probs = torch.index_select(original_attention_probs, -1, + cross_attention_swap_context_provider.index_map) + + # only some tokens taken from the original attention probabilities. this is controlled by the mask. + mask = cross_attention_swap_context_provider.mask + inverse_mask = 1 - mask + attention_probs = \ + remapped_original_attention_probs * mask + \ + modified_attention_probs * inverse_mask + + # for the "value" just use the modified text embeddings. 
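+            # (prompt-to-prompt style: the attention maps are blended between the two prompts above,
+            # but the values always come from the modified prompt, so swapped-in tokens contribute
+            # their own content while unedited tokens keep the original spatial attention layout.)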
+ value = attn.to_v(modified_text_embeddings) + + value = attn.head_to_batch_dim(value) + + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class P2PCrossAttentionProc: + + def __init__(self, head_size, upcast_attention, attn_maps_reweight): + super().__init__(head_size=head_size, upcast_attention=upcast_attention) + self.attn_maps_reweight = attn_maps_reweight + + def __call__(self, hidden_states, query_proj, key_proj, value_proj, encoder_hidden_states, modified_text_embeddings): + batch_size, sequence_length, _ = hidden_states.shape + query = query_proj(hidden_states) + + context = context if context is not None else hidden_states + attention_probs = [] + original_text_embeddings = encoder_hidden_states + for context in [original_text_embeddings, modified_text_embeddings]: + key = key_proj(original_text_embeddings) + value = self.value_proj(original_text_embeddings) + + query = self.head_to_batch_dim(query, self.head_size) + key = self.head_to_batch_dim(key, self.head_size) + value = self.head_to_batch_dim(value, self.head_size) + + attention_probs.append(self.get_attention_scores(query, key)) + + merged_probs = self.attn_maps_reweight * torch.cat(attention_probs) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = self.batch_to_head_dim(hidden_states) + return hidden_states + +proc = P2PCrossAttentionProc(unet.config.head_size, unet.config.upcast_attention, 0.6) diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index 03d5a5bcec..6eae824301 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -333,10 +333,10 @@ def setup_cross_attention_control(model, context: Context): indices[b0:b1] = indices_target[a0:a1] mask[b0:b1] = 1 - context.register_cross_attention_modules(model) + #context.register_cross_attention_modules(model) context.cross_attention_mask = mask.to(device) context.cross_attention_index_map = indices.to(device) - inject_attention_function(model, context) + #inject_attention_function(model, context) def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]: @@ -445,6 +445,7 @@ def get_mem_free_total(device): return mem_free_total + class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin): def __init__(self, **kwargs): From bffe199ad72735b780eef4d184927e37463d6a09 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Sat, 21 Jan 2023 20:54:18 +0100 Subject: [PATCH 02/25] SwapCrossAttnProcessor working - tested on mac CPU (MPS doesn't work) --- ldm/invoke/generator/diffusers_pipeline.py | 13 +- .../swap_cross_attention_processor.py | 160 ---------------- .../diffusion/cross_attention_control.py | 174 +++++++++++++++++- .../diffusion/shared_invokeai_diffusion.py | 60 +++++- 4 files changed, 226 insertions(+), 181 deletions(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 54e9d555af..6f3cd14550 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -24,9 +24,6 @@ from ...models.diffusion import cross_attention_control from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver from 
...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter -# monkeypatch diffusers CrossAttention 🙈 -# this is to make prompt2prompt and (future) attention maps work -attention.CrossAttention = cross_attention_control.InvokeAIDiffusersCrossAttention from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput @@ -295,7 +292,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward) + self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward, is_running_diffusers=True) use_full_precision = (precision == 'float32' or precision == 'autocast') self.textual_inversion_manager = TextualInversionManager(tokenizer=self.tokenizer, text_encoder=self.text_encoder, @@ -389,6 +386,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): attention_map_saver: Optional[AttentionMapSaver] = None self.invokeai_diffuser.remove_attention_map_saving() + for i, t in enumerate(self.progress_bar(timesteps)): batched_t.fill_(t) step_output = self.step(batched_t, latents, conditioning_data, @@ -447,7 +445,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): return step_output - def _unet_forward(self, latents, t, text_embeddings): + def _unet_forward(self, latents, t, text_embeddings, cross_attention_kwargs: Optional[dict[str,Any]] = None): """predict the noise residual""" if is_inpainting_model(self.unet) and latents.size(1) == 4: # Pad out normal non-inpainting inputs for an inpainting model. @@ -460,7 +458,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype) ).add_mask_channels(latents) - return self.unet(latents, t, encoder_hidden_states=text_embeddings).sample + return self.unet(sample=latents, + timestep=t, + encoder_hidden_states=text_embeddings, + cross_attention_kwargs=cross_attention_kwargs).sample def img2img_from_embeddings(self, init_image: Union[torch.FloatTensor, PIL.Image.Image], diff --git a/ldm/invoke/generator/swap_cross_attention_processor.py b/ldm/invoke/generator/swap_cross_attention_processor.py index e3aa5bc484..e69de29bb2 100644 --- a/ldm/invoke/generator/swap_cross_attention_processor.py +++ b/ldm/invoke/generator/swap_cross_attention_processor.py @@ -1,160 +0,0 @@ - -""" -# base implementation - -class CrossAttnProcessor: - def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - - query = attn.to_q(hidden_states) - query = attn.head_to_batch_dim(query) - - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - -""" -import enum 
-from dataclasses import field, dataclass - -import torch - -from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor - -class AttentionType(enum.Enum): - SELF = 1 - TOKENS = 2 - -@dataclass -class SwapCrossAttnContext: - - cross_attention_types_to_do: list[AttentionType] - modified_text_embeddings: torch.Tensor - index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt - mask: torch.Tensor # in the target space of the index_map - - def __int__(self, - cac_types_to_do: [AttentionType], - modified_text_embeddings: torch.Tensor, - index_map: torch.Tensor, - mask: torch.Tensor): - self.cross_attention_types_to_do = cac_types_to_do - self.modified_text_embeddings = modified_text_embeddings - self.index_map = index_map - self.mask = mask - - def wants_cross_attention_control(self, attn_type: AttentionType) -> bool: - return attn_type in self.cross_attention_types_to_do - - -class SwapCrossAttnProcessor(CrossAttnProcessor): - def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, - # kwargs - cross_attention_swap_context_provider: SwapCrossAttnContext=None): - - if cross_attention_swap_context_provider is None: - raise RuntimeError("a SwapCrossAttnContext instance must be passed via attention processor kwargs") - - attention_type = AttentionType.SELF if encoder_hidden_states is None else AttentionType.TOKENS - # if cross-attention control is not in play, just call through to the base implementation. - if not cross_attention_swap_context_provider.wants_cross_attention_control(attention_type): - return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask) - - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - - query = attn.to_q(hidden_states) - query = attn.head_to_batch_dim(query) - - # helper function - def get_attention_probs(embeddings): - this_key = attn.to_k(embeddings) - this_key = attn.head_to_batch_dim(this_key) - return attn.get_attention_scores(query, this_key, attention_mask) - - if attention_type == AttentionType.SELF: - # self attention has no remapping, it just bluntly copies the whole tensor - attention_probs = get_attention_probs(hidden_states) - value = attn.to_v(hidden_states) - else: - # tokens (cross) attention - # first, find attention probabilities for the "original" prompt - original_text_embeddings = encoder_hidden_states - original_attention_probs = get_attention_probs(original_text_embeddings) - - # then, find attention probabilities for the "modified" prompt - modified_text_embeddings = cross_attention_swap_context_provider.modified_text_embeddings - modified_attention_probs = get_attention_probs(modified_text_embeddings) - - # because the prompt modifications may result in token sequences shifted forwards or backwards, - # the original attention probabilities must be remapped to account for token index changes in the - # modified prompt - remapped_original_attention_probs = torch.index_select(original_attention_probs, -1, - cross_attention_swap_context_provider.index_map) - - # only some tokens taken from the original attention probabilities. this is controlled by the mask. - mask = cross_attention_swap_context_provider.mask - inverse_mask = 1 - mask - attention_probs = \ - remapped_original_attention_probs * mask + \ - modified_attention_probs * inverse_mask - - # for the "value" just use the modified text embeddings. 
- value = attn.to_v(modified_text_embeddings) - - value = attn.head_to_batch_dim(value) - - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - - -class P2PCrossAttentionProc: - - def __init__(self, head_size, upcast_attention, attn_maps_reweight): - super().__init__(head_size=head_size, upcast_attention=upcast_attention) - self.attn_maps_reweight = attn_maps_reweight - - def __call__(self, hidden_states, query_proj, key_proj, value_proj, encoder_hidden_states, modified_text_embeddings): - batch_size, sequence_length, _ = hidden_states.shape - query = query_proj(hidden_states) - - context = context if context is not None else hidden_states - attention_probs = [] - original_text_embeddings = encoder_hidden_states - for context in [original_text_embeddings, modified_text_embeddings]: - key = key_proj(original_text_embeddings) - value = self.value_proj(original_text_embeddings) - - query = self.head_to_batch_dim(query, self.head_size) - key = self.head_to_batch_dim(key, self.head_size) - value = self.head_to_batch_dim(value, self.head_size) - - attention_probs.append(self.get_attention_scores(query, key)) - - merged_probs = self.attn_maps_reweight * torch.cat(attention_probs) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = self.batch_to_head_dim(hidden_states) - return hidden_states - -proc = P2PCrossAttentionProc(unet.config.head_size, unet.config.upcast_attention, 0.6) diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index 6eae824301..4b89b5bd56 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -9,6 +9,7 @@ from torch import nn from diffusers.models.unet_2d_condition import UNet2DConditionModel from ldm.invoke.devices import torch_dtype + # adapted from bloc97's CrossAttentionControl colab # https://github.com/bloc97/CrossAttentionControl @@ -304,11 +305,16 @@ class InvokeAICrossAttentionMixin: -def remove_cross_attention_control(model): - remove_attention_function(model) +def remove_cross_attention_control(model, is_running_diffusers: bool): + if is_running_diffusers: + unet = model + print("** need to know what cross attn processor to use by default, None in the following line is wrong") + unet.set_attn_processor(CrossAttnProcessor()) + else: + remove_attention_function(model) -def setup_cross_attention_control(model, context: Context): +def setup_cross_attention_control(model, context: Context, is_running_diffusers = False): """ Inject attention parameters and functions into the passed in model to enable cross attention editing. 
@@ -333,10 +339,16 @@ def setup_cross_attention_control(model, context: Context): indices[b0:b1] = indices_target[a0:a1] mask[b0:b1] = 1 - #context.register_cross_attention_modules(model) context.cross_attention_mask = mask.to(device) context.cross_attention_index_map = indices.to(device) - #inject_attention_function(model, context) + if is_running_diffusers: + unet = model + unet.set_attn_processor(SwapCrossAttnProcessor()) + else: + context.register_cross_attention_modules(model) + inject_attention_function(model, context) + + def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]: @@ -461,3 +473,155 @@ class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, hidden_states = self.reshape_batch_dim_to_heads(attention_result) return hidden_states + + + + +## 🧨diffusers implementation follows + + +""" +# base implementation + +class CrossAttnProcessor: + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + +""" +import enum +from dataclasses import field, dataclass + +import torch + +from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor +from ldm.models.diffusion.cross_attention_control import CrossAttentionType + + +@dataclass +class SwapCrossAttnContext: + modified_text_embeddings: torch.Tensor + index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt + mask: torch.Tensor # in the target space of the index_map + cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=[]) + + def __int__(self, + cac_types_to_do: [CrossAttentionType], + modified_text_embeddings: torch.Tensor, + index_map: torch.Tensor, + mask: torch.Tensor): + self.cross_attention_types_to_do = cac_types_to_do + self.modified_text_embeddings = modified_text_embeddings + self.index_map = index_map + self.mask = mask + + def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool: + return attn_type in self.cross_attention_types_to_do + + @classmethod + def make_mask_and_index_map(cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int) \ + -> tuple[torch.Tensor, torch.Tensor]: + + # mask=1 means use original prompt attention, mask=0 means use modified prompt attention + mask = torch.zeros(max_length) + indices_target = torch.arange(max_length, dtype=torch.long) + indices = torch.arange(max_length, dtype=torch.long) + for name, a0, a1, b0, b1 in edit_opcodes: + if b0 < max_length: + if name == "equal": + # these tokens remain the same as in the original prompt + indices[b0:b1] = indices_target[a0:a1] + mask[b0:b1] = 1 + + return mask, indices + + +class 
SwapCrossAttnProcessor(CrossAttnProcessor): + + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, + # kwargs + swap_cross_attn_context: SwapCrossAttnContext=None): + + attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS + + # if cross-attention control is not in play, just call through to the base implementation. + if swap_cross_attn_context is None or not swap_cross_attn_context.wants_cross_attention_control(attention_type): + #print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass") + return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask) + #else: + # print(f"SwapCrossAttnContext for {attention_type} active") + + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + + # helper function + def get_attention_probs(embeddings): + this_key = attn.to_k(embeddings) + this_key = attn.head_to_batch_dim(this_key) + return attn.get_attention_scores(query, this_key, attention_mask) + + if attention_type == CrossAttentionType.SELF: + # self attention has no remapping, it just bluntly copies the whole tensor + attention_probs = get_attention_probs(hidden_states) + value = attn.to_v(hidden_states) + else: + # tokens (cross) attention + # first, find attention probabilities for the "original" prompt + original_text_embeddings = encoder_hidden_states + original_attention_probs = get_attention_probs(original_text_embeddings) + + # then, find attention probabilities for the "modified" prompt + modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings + modified_attention_probs = get_attention_probs(modified_text_embeddings) + + # because the prompt modifications may result in token sequences shifted forwards or backwards, + # the original attention probabilities must be remapped to account for token index changes in the + # modified prompt + remapped_original_attention_probs = torch.index_select(original_attention_probs, -1, + swap_cross_attn_context.index_map) + + # only some tokens taken from the original attention probabilities. this is controlled by the mask. + mask = swap_cross_attn_context.mask + inverse_mask = 1 - mask + attention_probs = \ + remapped_original_attention_probs * mask + \ + modified_attention_probs * inverse_mask + + # for the "value" just use the modified text embeddings. 
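+            # (note that the original prompt's values are never used; only its attention probabilities
+            # survive, and only at the token positions where mask == 1.)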
+ value = attn.to_v(modified_text_embeddings) + + value = attn.head_to_batch_dim(value) + + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index 3be6b10170..e4932f6ad8 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -1,14 +1,14 @@ import math from dataclasses import dataclass from math import ceil -from typing import Callable, Optional, Union +from typing import Callable, Optional, Union, Any import numpy as np import torch from ldm.models.diffusion.cross_attention_control import Arguments, \ remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, \ - CrossAttentionType + CrossAttentionType, SwapCrossAttnContext from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver @@ -30,24 +30,28 @@ class InvokeAIDiffuserComponent: debug_thresholding = False + @dataclass class ExtraConditioningInfo: - def __init__(self, tokens_count_including_eos_bos:int, cross_attention_control_args: Optional[Arguments]): - self.tokens_count_including_eos_bos = tokens_count_including_eos_bos - self.cross_attention_control_args = cross_attention_control_args + + tokens_count_including_eos_bos: int + cross_attention_control_args: Optional[Arguments] = None @property def wants_cross_attention_control(self): return self.cross_attention_control_args is not None + def __init__(self, model, model_forward_callback: - Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] - ): + Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str,Any]]], torch.Tensor], + is_running_diffusers: bool=False, + ): """ :param model: the unet model to pass through to cross attention control :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. 
most likely, this should simply call model.forward(x, sigma, conditioning) """ self.conditioning = None self.model = model + self.is_running_diffusers = is_running_diffusers self.model_forward_callback = model_forward_callback self.cross_attention_control_context = None @@ -57,12 +61,14 @@ class InvokeAIDiffuserComponent: arguments=self.conditioning.cross_attention_control_args, step_count=step_count ) - setup_cross_attention_control(self.model, self.cross_attention_control_context) + setup_cross_attention_control(self.model, + self.cross_attention_control_context, + is_running_diffusers=self.is_running_diffusers) def remove_cross_attention_control(self): self.conditioning = None self.cross_attention_control_context = None - remove_cross_attention_control(self.model) + remove_cross_attention_control(self.model, is_running_diffusers=self.is_running_diffusers) def setup_attention_map_saving(self, saver: AttentionMapSaver): def callback(slice, dim, offset, slice_size, key): @@ -168,7 +174,41 @@ class InvokeAIDiffuserComponent: return unconditioned_next_x, conditioned_next_x - def apply_cross_attention_controlled_conditioning(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do): + def apply_cross_attention_controlled_conditioning(self, + x: torch.Tensor, + sigma, + unconditioning, + conditioning, + cross_attention_control_types_to_do): + if self.is_running_diffusers: + return self.apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do) + else: + return self.apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do) + + def apply_cross_attention_controlled_conditioning__diffusers(self, + x: torch.Tensor, + sigma, + unconditioning, + conditioning, + cross_attention_control_types_to_do): + context: Context = self.cross_attention_control_context + + cross_attn_processor_context = SwapCrossAttnContext(modified_text_embeddings=context.arguments.edited_conditioning, + index_map=context.cross_attention_index_map, + mask=context.cross_attention_mask, + cross_attention_types_to_do=[]) + # no cross attention for unconditioning (negative prompt) + unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, + {"swap_cross_attn_context": cross_attn_processor_context}) + + # do requested cross attention types for conditioning (positive prompt) + cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do + conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, + {"swap_cross_attn_context": cross_attn_processor_context}) + return unconditioned_next_x, conditioned_next_x + + + def apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do): # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do) # slower non-batched path (20% slower on mac MPS) # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of From 313b206ff8f2a2099cbc5dae7050bcd6b2576590 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Sun, 22 Jan 2023 18:12:11 +0100 Subject: [PATCH 03/25] squash float16/float32 mismatch on linux --- ldm/models/diffusion/cross_attention_control.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ldm/models/diffusion/cross_attention_control.py 
b/ldm/models/diffusion/cross_attention_control.py index 4b89b5bd56..45294ac993 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -329,7 +329,7 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers # urgh. should this be hardcoded? max_length = 77 # mask=1 means use base prompt attention, mask=0 means use edited prompt attention - mask = torch.zeros(max_length) + mask = torch.zeros(max_length, dtype=torch_dtype()) indices_target = torch.arange(max_length, dtype=torch.long) indices = torch.arange(max_length, dtype=torch.long) for name, a0, a1, b0, b1 in context.arguments.edit_opcodes: @@ -338,7 +338,7 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers # these tokens have not been edited indices[b0:b1] = indices_target[a0:a1] mask[b0:b1] = 1 - +b context.cross_attention_mask = mask.to(device) context.cross_attention_index_map = indices.to(device) if is_running_diffusers: From c0610f7cb99341a73d0b554c8583ef027cf6db79 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Sun, 22 Jan 2023 18:19:01 +0100 Subject: [PATCH 04/25] pass missing value --- ldm/models/diffusion/cross_attention_control.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index 45294ac993..fcb9f52dde 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -329,7 +329,7 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers # urgh. should this be hardcoded? max_length = 77 # mask=1 means use base prompt attention, mask=0 means use edited prompt attention - mask = torch.zeros(max_length, dtype=torch_dtype()) + mask = torch.zeros(max_length, dtype=torch_dtype(device)) indices_target = torch.arange(max_length, dtype=torch.long) indices = torch.arange(max_length, dtype=torch.long) for name, a0, a1, b0, b1 in context.arguments.edit_opcodes: @@ -338,7 +338,7 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers # these tokens have not been edited indices[b0:b1] = indices_target[a0:a1] mask[b0:b1] = 1 -b + context.cross_attention_mask = mask.to(device) context.cross_attention_index_map = indices.to(device) if is_running_diffusers: From 63c6019f9281056ca79ec92ed9dadac00abe8837 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Tue, 24 Jan 2023 14:46:32 +0100 Subject: [PATCH 05/25] sliced attention processor wip (untested) --- .../swap_cross_attention_processor.py | 0 .../diffusion/cross_attention_control.py | 94 ++++++++++++++++++- 2 files changed, 93 insertions(+), 1 deletion(-) delete mode 100644 ldm/invoke/generator/swap_cross_attention_processor.py diff --git a/ldm/invoke/generator/swap_cross_attention_processor.py b/ldm/invoke/generator/swap_cross_attention_processor.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index fcb9f52dde..b1c1cd63d9 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -514,7 +514,7 @@ from dataclasses import field, dataclass import torch -from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor +from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor from 
ldm.models.diffusion.cross_attention_control import CrossAttentionType @@ -625,3 +625,95 @@ class SwapCrossAttnProcessor(CrossAttnProcessor): return hidden_states + + +class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): + + def __init__(self, slice_size = 1e6): + self.slice_count = slice_size + + # TODO: dynamically pick slice size based on memory conditions + + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, + # kwargs + swap_cross_attn_context: SwapCrossAttnContext=None): + + attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS + + # if cross-attention control is not in play, just call through to the base implementation. + if attention_type is CrossAttentionType.SELF or \ + swap_cross_attn_context is None or \ + not swap_cross_attn_context.wants_cross_attention_control(attention_type): + #print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass") + return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask) + #else: + # print(f"SwapCrossAttnContext for {attention_type} active") + + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + + original_text_embeddings = encoder_hidden_states + original_text_key = attn.to_k(original_text_embeddings) + original_text_key = attn.head_to_batch_dim(original_text_key) + modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings + modified_text_key = attn.to_k(modified_text_embeddings) + modified_text_key = attn.head_to_batch_dim(original_text_key) + + # for the "value" just use the modified text embeddings. + value = attn.to_v(modified_text_embeddings) + value = attn.head_to_batch_dim(value) + + # compute slices and prepare output tensor + batch_size_attention = query.shape[0] + dim = query.shape[-1] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype + ) + + # do slices + for i in range(hidden_states.shape[0] // self.slice_size): + start_idx = i * self.slice_size + end_idx = min(hidden_states.shape[0], (i + 1) * self.slice_size) + + query_slice = query[start_idx:end_idx] + attention_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + # first, find attention probabilities for the "original" prompt + original_text_key_slice = original_text_key[start_idx:end_idx] + original_attention_probs_slice = attn.get_attention_scores(query_slice, original_text_key_slice, attention_mask_slice) + + # then, find attention probabilities for the "modified" prompt + modified_text_key_slice = modified_text_key[start_idx:end_idx] + modified_attention_probs_slice = attn.get_attention_scores(query_slice, modified_text_key_slice, attention_mask_slice) + + # because the prompt modifications may result in token sequences shifted forwards or backwards, + # the original attention probabilities must be remapped to account for token index changes in the + # modified prompt + remapped_original_attention_probs_slice = torch.index_select(original_attention_probs_slice, -1, + swap_cross_attn_context.index_map) + + # only some tokens taken from the original attention probabilities. this is controlled by the mask. 
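+            # (illustration, not part of the algorithm: for an edit like "a cat" -> "a dog", the opcodes
+            #  mark every position except the swapped token as "equal", so index_map is the identity,
+            #  mask is 1 everywhere except at the swapped position, and each slice keeps the original
+            #  attention for unchanged tokens while taking the modified attention only where the prompts differ.)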
+ mask = swap_cross_attn_context.mask + inverse_mask = 1 - mask + attention_probs_slice = \ + remapped_original_attention_probs_slice * mask + \ + modified_attention_probs_slice * inverse_mask + + value_slice = value[start_idx:end_idx] + hidden_states_slice = torch.bmm(attention_probs_slice, value_slice) + + hidden_states[start_idx:end_idx] = hidden_states_slice + + + # done + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states From a4aea1540b906faee6687a819dc9a2d007ca9bfa Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Wed, 25 Jan 2023 14:51:08 +0100 Subject: [PATCH 06/25] more wip sliced attention (.swap doesn't work yet) --- ldm/invoke/generator/diffusers_pipeline.py | 60 +++++++++---------- .../diffusion/cross_attention_control.py | 25 ++++---- .../diffusion/shared_invokeai_diffusion.py | 39 +++++++++--- 3 files changed, 75 insertions(+), 49 deletions(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 6f3cd14550..e5ce403cb7 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -306,6 +306,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if is_xformers_available() and not Globals.disable_xformers: self.enable_xformers_memory_efficient_attention() + else: + slice_size = 2 + self.enable_attention_slicing(slice_size=slice_size) def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int, conditioning_data: ConditioningData, @@ -370,43 +373,40 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if additional_guidance is None: additional_guidance = [] extra_conditioning_info = conditioning_data.extra - if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: - self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, - step_count=len(self.scheduler.timesteps)) - else: - self.invokeai_diffuser.remove_cross_attention_control() + with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info, + step_count=len(self.scheduler.timesteps), + do_attention_map_saving=False): - yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps, - latents=latents) + yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps, + latents=latents) - batch_size = latents.shape[0] - batched_t = torch.full((batch_size,), timesteps[0], - dtype=timesteps.dtype, device=self.unet.device) - latents = self.scheduler.add_noise(latents, noise, batched_t) + batch_size = latents.shape[0] + batched_t = torch.full((batch_size,), timesteps[0], + dtype=timesteps.dtype, device=self.unet.device) + latents = self.scheduler.add_noise(latents, noise, batched_t) - attention_map_saver: Optional[AttentionMapSaver] = None - self.invokeai_diffuser.remove_attention_map_saving() + attention_map_saver: Optional[AttentionMapSaver] = None - for i, t in enumerate(self.progress_bar(timesteps)): - batched_t.fill_(t) - step_output = self.step(batched_t, latents, conditioning_data, - step_index=i, - total_step_count=len(timesteps), - additional_guidance=additional_guidance) - latents = step_output.prev_sample - predicted_original = getattr(step_output, 'pred_original_sample', None) + for i, t in enumerate(self.progress_bar(timesteps)): + 
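+                # each pass denoises one scheduler timestep; any swap attention processors installed by
+                # custom_attention_context() above stay active for the whole loop and are restored on exit.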
batched_t.fill_(t) + step_output = self.step(batched_t, latents, conditioning_data, + step_index=i, + total_step_count=len(timesteps), + additional_guidance=additional_guidance) + latents = step_output.prev_sample + predicted_original = getattr(step_output, 'pred_original_sample', None) - if i == len(timesteps)-1 and extra_conditioning_info is not None: - eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1 - attention_map_token_ids = range(1, eos_token_index) - attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:]) - self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver) + # TODO resuscitate attention map saving + #if i == len(timesteps)-1 and extra_conditioning_info is not None: + # eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1 + # attention_map_token_ids = range(1, eos_token_index) + # attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:]) + # self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver) - yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents, - predicted_original=predicted_original, attention_map_saver=attention_map_saver) + yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents, + predicted_original=predicted_original, attention_map_saver=attention_map_saver) - self.invokeai_diffuser.remove_attention_map_saving() - return latents, attention_map_saver + return latents, attention_map_saver @torch.inference_mode() def step(self, t: torch.Tensor, latents: torch.Tensor, diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index b1c1cd63d9..08c62060c9 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -7,6 +7,7 @@ import torch import diffusers from torch import nn from diffusers.models.unet_2d_condition import UNet2DConditionModel +from diffusers.models.cross_attention import AttnProcessor from ldm.invoke.devices import torch_dtype @@ -305,11 +306,10 @@ class InvokeAICrossAttentionMixin: -def remove_cross_attention_control(model, is_running_diffusers: bool): +def remove_cross_attention_control(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None): if is_running_diffusers: unet = model - print("** need to know what cross attn processor to use by default, None in the following line is wrong") - unet.set_attn_processor(CrossAttnProcessor()) + unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor()) else: remove_attention_function(model) @@ -343,10 +343,16 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers context.cross_attention_index_map = indices.to(device) if is_running_diffusers: unet = model - unet.set_attn_processor(SwapCrossAttnProcessor()) + old_attn_processors = unet.attn_processors + # try to re-use an existing slice size + default_slice_size = 4 + slice_size = next((p for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size) + unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size)) + return old_attn_processors else: context.register_cross_attention_modules(model) inject_attention_function(model, context) + return None @@ -509,13 +515,11 @@ class CrossAttnProcessor: return hidden_states """ -import enum from dataclasses import field, dataclass 
import torch -from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor -from ldm.models.diffusion.cross_attention_control import CrossAttentionType +from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor, AttnProcessor @dataclass @@ -523,7 +527,7 @@ class SwapCrossAttnContext: modified_text_embeddings: torch.Tensor index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt mask: torch.Tensor # in the target space of the index_map - cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=[]) + cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list) def __int__(self, cac_types_to_do: [CrossAttentionType], @@ -629,9 +633,6 @@ class SwapCrossAttnProcessor(CrossAttnProcessor): class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): - def __init__(self, slice_size = 1e6): - self.slice_count = slice_size - # TODO: dynamically pick slice size based on memory conditions def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, @@ -660,7 +661,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): original_text_key = attn.head_to_batch_dim(original_text_key) modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings modified_text_key = attn.to_k(modified_text_embeddings) - modified_text_key = attn.head_to_batch_dim(original_text_key) + modified_text_key = attn.head_to_batch_dim(modified_text_key) # for the "value" just use the modified text embeddings. value = attn.to_v(modified_text_embeddings) diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index e4932f6ad8..0c91df9528 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -1,11 +1,13 @@ import math +from contextlib import contextmanager from dataclasses import dataclass from math import ceil -from typing import Callable, Optional, Union, Any +from typing import Callable, Optional, Union, Any, Dict import numpy as np import torch +from diffusers.models.cross_attention import AttnProcessor from ldm.models.diffusion.cross_attention_control import Arguments, \ remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, \ CrossAttentionType, SwapCrossAttnContext @@ -55,20 +57,43 @@ class InvokeAIDiffuserComponent: self.model_forward_callback = model_forward_callback self.cross_attention_control_context = None - def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int): + @contextmanager + def custom_attention_context(self, + extra_conditioning_info: Optional[ExtraConditioningInfo], + step_count: int, + do_attention_map_saving: bool): + do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control + old_attn_processor = None + if do_swap: + old_attn_processor = self.setup_cross_attention_control(extra_conditioning_info, + step_count=step_count) + try: + yield None + finally: + self.remove_cross_attention_control(old_attn_processor) + # TODO resuscitate attention map saving + #self.remove_attention_map_saving() + + def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]: + """ + setup cross attention .swap control. 
for diffusers this replaces the attention processor, so + the previous attention processor is returned so that the caller can restore it later. + """ self.conditioning = conditioning self.cross_attention_control_context = Context( arguments=self.conditioning.cross_attention_control_args, step_count=step_count ) - setup_cross_attention_control(self.model, - self.cross_attention_control_context, - is_running_diffusers=self.is_running_diffusers) + return setup_cross_attention_control(self.model, + self.cross_attention_control_context, + is_running_diffusers=self.is_running_diffusers) - def remove_cross_attention_control(self): + def remove_cross_attention_control(self, restore_attention_processor: Optional['AttnProcessor']=None): self.conditioning = None self.cross_attention_control_context = None - remove_cross_attention_control(self.model, is_running_diffusers=self.is_running_diffusers) + remove_cross_attention_control(self.model, + is_running_diffusers=self.is_running_diffusers, + restore_attention_processor=restore_attention_processor) def setup_attention_map_saving(self, saver: AttentionMapSaver): def callback(slice, dim, offset, slice_size, key): From 1f5ad1b05edca2f76e9860776b495fd9fd5d0cfe Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Wed, 25 Jan 2023 21:38:27 +0100 Subject: [PATCH 07/25] sliced swap working --- .../diffusion/cross_attention_control.py | 48 +++++++++---------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index 08c62060c9..8b5467e85f 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -346,7 +346,7 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers old_attn_processors = unet.attn_processors # try to re-use an existing slice size default_slice_size = 4 - slice_size = next((p for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size) + slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size) unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size)) return old_attn_processors else: @@ -654,22 +654,23 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) query = attn.to_q(hidden_states) + dim = query.shape[-1] query = attn.head_to_batch_dim(query) original_text_embeddings = encoder_hidden_states - original_text_key = attn.to_k(original_text_embeddings) - original_text_key = attn.head_to_batch_dim(original_text_key) modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings + original_text_key = attn.to_k(original_text_embeddings) modified_text_key = attn.to_k(modified_text_embeddings) - modified_text_key = attn.head_to_batch_dim(modified_text_key) + #original_value = attn.to_v(original_text_embeddings) + modified_value = attn.to_v(modified_text_embeddings) - # for the "value" just use the modified text embeddings. 
- value = attn.to_v(modified_text_embeddings) - value = attn.head_to_batch_dim(value) + original_text_key = attn.head_to_batch_dim(original_text_key) + modified_text_key = attn.head_to_batch_dim(modified_text_key) + #original_value = attn.head_to_batch_dim(original_value) + modified_value = attn.head_to_batch_dim(modified_value) # compute slices and prepare output tensor batch_size_attention = query.shape[0] - dim = query.shape[-1] hidden_states = torch.zeros( (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype ) @@ -677,36 +678,31 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): # do slices for i in range(hidden_states.shape[0] // self.slice_size): start_idx = i * self.slice_size - end_idx = min(hidden_states.shape[0], (i + 1) * self.slice_size) + end_idx = (i + 1) * self.slice_size query_slice = query[start_idx:end_idx] - attention_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + original_key_slice = original_text_key[start_idx:end_idx] + modified_key_slice = modified_text_key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - # first, find attention probabilities for the "original" prompt - original_text_key_slice = original_text_key[start_idx:end_idx] - original_attention_probs_slice = attn.get_attention_scores(query_slice, original_text_key_slice, attention_mask_slice) - - # then, find attention probabilities for the "modified" prompt - modified_text_key_slice = modified_text_key[start_idx:end_idx] - modified_attention_probs_slice = attn.get_attention_scores(query_slice, modified_text_key_slice, attention_mask_slice) + original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice) + modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice) # because the prompt modifications may result in token sequences shifted forwards or backwards, # the original attention probabilities must be remapped to account for token index changes in the # modified prompt - remapped_original_attention_probs_slice = torch.index_select(original_attention_probs_slice, -1, - swap_cross_attn_context.index_map) + remapped_original_attn_slice = torch.index_select(original_attn_slice, -1, + swap_cross_attn_context.index_map) # only some tokens taken from the original attention probabilities. this is controlled by the mask. 
mask = swap_cross_attn_context.mask inverse_mask = 1 - mask - attention_probs_slice = \ - remapped_original_attention_probs_slice * mask + \ - modified_attention_probs_slice * inverse_mask + attn_slice = \ + remapped_original_attn_slice * mask + \ + modified_attn_slice * inverse_mask - value_slice = value[start_idx:end_idx] - hidden_states_slice = torch.bmm(attention_probs_slice, value_slice) - - hidden_states[start_idx:end_idx] = hidden_states_slice + attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx]) + hidden_states[start_idx:end_idx] = attn_slice # done From 34a3f4a8203ad3fb602c0daad5704c0808a212df Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Wed, 25 Jan 2023 21:47:17 +0100 Subject: [PATCH 08/25] cleanup --- ldm/invoke/generator/diffusers_pipeline.py | 3 +- .../diffusion/cross_attention_control.py | 77 ++----------------- 2 files changed, 8 insertions(+), 72 deletions(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index e5ce403cb7..a3d5ae3c07 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -307,9 +307,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if is_xformers_available() and not Globals.disable_xformers: self.enable_xformers_memory_efficient_attention() else: - slice_size = 2 + slice_size = 4 # or 2, or 8. i chose this arbitrarily. self.enable_attention_slicing(slice_size=slice_size) + def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int, conditioning_data: ConditioningData, *, diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index 8b5467e85f..c248343040 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -560,77 +560,6 @@ class SwapCrossAttnContext: return mask, indices -class SwapCrossAttnProcessor(CrossAttnProcessor): - - def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, - # kwargs - swap_cross_attn_context: SwapCrossAttnContext=None): - - attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS - - # if cross-attention control is not in play, just call through to the base implementation. 
- if swap_cross_attn_context is None or not swap_cross_attn_context.wants_cross_attention_control(attention_type): - #print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass") - return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask) - #else: - # print(f"SwapCrossAttnContext for {attention_type} active") - - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - - query = attn.to_q(hidden_states) - query = attn.head_to_batch_dim(query) - - # helper function - def get_attention_probs(embeddings): - this_key = attn.to_k(embeddings) - this_key = attn.head_to_batch_dim(this_key) - return attn.get_attention_scores(query, this_key, attention_mask) - - if attention_type == CrossAttentionType.SELF: - # self attention has no remapping, it just bluntly copies the whole tensor - attention_probs = get_attention_probs(hidden_states) - value = attn.to_v(hidden_states) - else: - # tokens (cross) attention - # first, find attention probabilities for the "original" prompt - original_text_embeddings = encoder_hidden_states - original_attention_probs = get_attention_probs(original_text_embeddings) - - # then, find attention probabilities for the "modified" prompt - modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings - modified_attention_probs = get_attention_probs(modified_text_embeddings) - - # because the prompt modifications may result in token sequences shifted forwards or backwards, - # the original attention probabilities must be remapped to account for token index changes in the - # modified prompt - remapped_original_attention_probs = torch.index_select(original_attention_probs, -1, - swap_cross_attn_context.index_map) - - # only some tokens taken from the original attention probabilities. this is controlled by the mask. - mask = swap_cross_attn_context.mask - inverse_mask = 1 - mask - attention_probs = \ - remapped_original_attention_probs * mask + \ - modified_attention_probs * inverse_mask - - # for the "value" just use the modified text embeddings. 
- value = attn.to_v(modified_text_embeddings) - - value = attn.head_to_batch_dim(value) - - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - - - class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): # TODO: dynamically pick slice size based on memory conditions @@ -714,3 +643,9 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): hidden_states = attn.to_out[1](hidden_states) return hidden_states + + +class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser): + + def __init__(self): + super(SwapCrossAttnProcessor, self).__init__(slice_size=1e6) # big number so we never slice From 41aed57449a821f472c9a739ed11a318366a0b0c Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Wed, 25 Jan 2023 22:27:23 +0100 Subject: [PATCH 09/25] wip tracking down MPS slicing support --- ldm/invoke/generator/diffusers_pipeline.py | 10 ++- .../diffusion/cross_attention_control.py | 90 +++++++++++++++++-- 2 files changed, 92 insertions(+), 8 deletions(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index a3d5ae3c07..a16cbe594b 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -307,8 +307,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if is_xformers_available() and not Globals.disable_xformers: self.enable_xformers_memory_efficient_attention() else: - slice_size = 4 # or 2, or 8. i chose this arbitrarily. - self.enable_attention_slicing(slice_size=slice_size) + if torch.backends.mps.is_available(): + # until pytorch #91617 is fixed, slicing is borked on MPS + # https://github.com/pytorch/pytorch/issues/91617 + # fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline. + pass + else: + slice_size = 4 # or 2, or 8. i chose this arbitrarily. 
+ self.enable_attention_slicing(slice_size=slice_size) def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int, diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index c248343040..a1b680c411 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -344,10 +344,14 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers if is_running_diffusers: unet = model old_attn_processors = unet.attn_processors - # try to re-use an existing slice size - default_slice_size = 4 - slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size) - unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size)) + if torch.backends.mps.is_available(): + # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS + unet.set_attn_processor(SwapCrossAttnProcessor()) + else: + # try to re-use an existing slice size + default_slice_size = 4 + slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size) + unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size)) return old_attn_processors else: context.register_cross_attention_modules(model) @@ -605,7 +609,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): ) # do slices - for i in range(hidden_states.shape[0] // self.slice_size): + for i in range(max(1,hidden_states.shape[0] // self.slice_size)): start_idx = i * self.slice_size end_idx = (i + 1) * self.slice_size @@ -630,6 +634,8 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): remapped_original_attn_slice * mask + \ modified_attn_slice * inverse_mask + del remapped_original_attn_slice, modified_attn_slice + attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx]) hidden_states[start_idx:end_idx] = attn_slice @@ -648,4 +654,76 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser): def __init__(self): - super(SwapCrossAttnProcessor, self).__init__(slice_size=1e6) # big number so we never slice + super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9)) + + # theoretically this class could simply inherit from SlicedSwapCrossAttnProcesser + # and consist wholly of an __init__ method that just calls super().__init__(slice_size=1000000000) + # - such a giant slice size would resolve to 'no slicing' at runtime. + # however, pytorch MPS is borked until https://github.com/kulinseth/pytorch/pull/222 is merged into + # mainline pytorch. so for now this has to be a full implementation. + + def no__call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, + # kwargs + swap_cross_attn_context: SwapCrossAttnContext=None): + + attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS + + # if cross-attention control is not in play, just call through to the base implementation. 
+ if attention_type == CrossAttentionType.SELF or \ + swap_cross_attn_context is None or \ + not swap_cross_attn_context.wants_cross_attention_control(attention_type): + #print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass") + return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask) + #else: + # print(f"SwapCrossAttnContext for {attention_type} active") + + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + + # helper function + def get_attention_probs(embeddings): + this_key = attn.to_k(embeddings) + this_key = attn.head_to_batch_dim(this_key) + return attn.get_attention_scores(query, this_key, attention_mask) + + # tokens (cross) attention + # first, find attention probabilities for the "original" prompt + original_text_embeddings = encoder_hidden_states + original_attention_probs = get_attention_probs(original_text_embeddings) + + # then, find attention probabilities for the "modified" prompt + modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings + modified_attention_probs = get_attention_probs(modified_text_embeddings) + + # because the prompt modifications may result in token sequences shifted forwards or backwards, + # the original attention probabilities must be remapped to account for token index changes in the + # modified prompt + remapped_original_attention_probs = torch.index_select(original_attention_probs, -1, + swap_cross_attn_context.index_map).clone() + + # only some tokens taken from the original attention probabilities. this is controlled by the mask. + mask = swap_cross_attn_context.mask + inverse_mask = 1 - mask.clone() + attention_probs = \ + remapped_original_attention_probs * mask + \ + modified_attention_probs * inverse_mask + + # for the "value" just use the modified text embeddings. + value = attn.to_v(modified_text_embeddings) + + value = attn.head_to_batch_dim(value) + + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + From 95d147c5df4a4e70699fb8debd430f0955bb99b0 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Wed, 25 Jan 2023 23:03:30 +0100 Subject: [PATCH 10/25] MPS support: negatory --- .../diffusion/cross_attention_control.py | 73 +------------------ 1 file changed, 1 insertion(+), 72 deletions(-) diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index a1b680c411..9712ddf1bd 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -654,76 +654,5 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser): def __init__(self): - super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9)) - - # theoretically this class could simply inherit from SlicedSwapCrossAttnProcesser - # and consist wholly of an __init__ method that just calls super().__init__(slice_size=1000000000) - # - such a giant slice size would resolve to 'no slicing' at runtime. - # however, pytorch MPS is borked until https://github.com/kulinseth/pytorch/pull/222 is merged into - # mainline pytorch. so for now this has to be a full implementation. 
- - def no__call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, - # kwargs - swap_cross_attn_context: SwapCrossAttnContext=None): - - attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS - - # if cross-attention control is not in play, just call through to the base implementation. - if attention_type == CrossAttentionType.SELF or \ - swap_cross_attn_context is None or \ - not swap_cross_attn_context.wants_cross_attention_control(attention_type): - #print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass") - return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask) - #else: - # print(f"SwapCrossAttnContext for {attention_type} active") - - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - - query = attn.to_q(hidden_states) - query = attn.head_to_batch_dim(query) - - # helper function - def get_attention_probs(embeddings): - this_key = attn.to_k(embeddings) - this_key = attn.head_to_batch_dim(this_key) - return attn.get_attention_scores(query, this_key, attention_mask) - - # tokens (cross) attention - # first, find attention probabilities for the "original" prompt - original_text_embeddings = encoder_hidden_states - original_attention_probs = get_attention_probs(original_text_embeddings) - - # then, find attention probabilities for the "modified" prompt - modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings - modified_attention_probs = get_attention_probs(modified_text_embeddings) - - # because the prompt modifications may result in token sequences shifted forwards or backwards, - # the original attention probabilities must be remapped to account for token index changes in the - # modified prompt - remapped_original_attention_probs = torch.index_select(original_attention_probs, -1, - swap_cross_attn_context.index_map).clone() - - # only some tokens taken from the original attention probabilities. this is controlled by the mask. - mask = swap_cross_attn_context.mask - inverse_mask = 1 - mask.clone() - attention_probs = \ - remapped_original_attention_probs * mask + \ - modified_attention_probs * inverse_mask - - # for the "value" just use the modified text embeddings. - value = attn.to_v(modified_text_embeddings) - - value = attn.head_to_batch_dim(value) - - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - + super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9)) # massive slice size = don't slice From 5e7ed964d2524352c1ddf66b3bef75adb4a2fcbc Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Wed, 25 Jan 2023 23:49:38 +0100 Subject: [PATCH 11/25] wip updating docs --- docs/features/PROMPTS.md | 40 +++++++++---------- .../diffusion/cross_attention_control.py | 6 +-- 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/docs/features/PROMPTS.md b/docs/features/PROMPTS.md index 79166a9b94..5413cc5e55 100644 --- a/docs/features/PROMPTS.md +++ b/docs/features/PROMPTS.md @@ -239,28 +239,24 @@ Generate an image with a given prompt, record the seed of the image, and then use the `prompt2prompt` syntax to substitute words in the original prompt for words in a new prompt. This works for `img2img` as well. 
-- `a ("fluffy cat").swap("smiling dog") eating a hotdog`.
-  - quotes optional: `a (fluffy cat).swap(smiling dog) eating a hotdog`.
-  - for single word substitutions parentheses are also optional:
-    `a cat.swap(dog) eating a hotdog`.
-- Supports options `s_start`, `s_end`, `t_start`, `t_end` (each 0-1) loosely
-  corresponding to bloc97's `prompt_edit_spatial_start/_end` and
-  `prompt_edit_tokens_start/_end` but with the math swapped to make it easier to
-  intuitively understand.
-  - Example usage:`a (cat).swap(dog, s_end=0.3) eating a hotdog` - the `s_end`
-    argument means that the "spatial" (self-attention) edit will stop having any
-    effect after 30% (=0.3) of the steps have been done, leaving Stable
-    Diffusion with 70% of the steps where it is free to decide for itself how to
-    reshape the cat-form into a dog form.
-  - The numbers represent a percentage through the step sequence where the edits
-    should happen. 0 means the start (noisy starting image), 1 is the end (final
-    image).
-  - For img2img, the step sequence does not start at 0 but instead at
-    (1-strength) - so if strength is 0.7, s_start and s_end must both be
-    greater than 0.3 (1-0.7) to have any effect.
-- Convenience option `shape_freedom` (0-1) to specify how much "freedom" Stable
-  Diffusion should have to change the shape of the subject being swapped.
-  - `a (cat).swap(dog, shape_freedom=0.5) eating a hotdog`.
+For example, consider the prompt `a cat.swap(dog) playing with a ball in the forest`. Normally, because of the way the words in a prompt interact with each other during stable diffusion image generation, these two prompts would generate different compositions:
+  - `a cat playing with a ball in the forest`
+  - `a dog playing with a ball in the forest`
+
+| `a cat playing with a ball in the forest` | `a dog playing with a ball in the forest` |
+| --- | --- |
+| img | img |
+
+
+  - For multiple word swaps, use parentheses: `a (fluffy cat).swap(barking dog) playing with a ball in the forest`.
+  - To swap a comma, use quotes: `a ("fluffy, grey cat").swap("big, barking dog") playing with a ball in the forest`.
+- Supports options `t_start` and `t_end` (each 0-1) loosely corresponding to bloc97's `prompt_edit_tokens_start/_end`, but with the math swapped to make it easier to
+  intuitively understand. `t_start` and `t_end` control on which steps cross-attention control runs. With the default values `t_start=0` and `t_end=1`, cross-attention control is active on every step of image generation. Other values can be used to turn cross-attention control off for part of the image generation process.
+  - For example, if doing a diffusion with 10 steps with the prompt `a cat.swap(dog, t_start=0.3, t_end=1.0) playing with a ball in the forest`, the first 3 steps will be run as `a cat playing with a ball in the forest`, while the last 7 steps will run as `a dog playing with a ball in the forest`, but the pixels that represent `dog` will be locked to the pixels that would have represented `cat` if the `cat` prompt had been used instead.
+  - Conversely, for `a cat.swap(dog, t_start=0, t_end=0.7) playing with a ball in the forest`, the first 7 steps will run as `a dog playing with a ball in the forest` with the pixels that represent `dog` locked to the same pixels that would have represented `cat` if the `cat` prompt was being used instead. The final 3 steps will just run `a cat playing with a ball in the forest`.
+ > For img2img, the step sequence does not start at 0 but instead at `(1.0-strength)` - so if the img2img `strength` is `0.7`, `t_start` and `t_end` must both be greater than `0.3` (`1.0-0.7`) to have any effect. + +Prompt2prompt `.swap()` is not compatible with xformers, which will be temporarily disabled when doing a `.swap()` - so you should expect to use more VRAM and run slower that with xformers enabled. The `prompt2prompt` code is based off [bloc97's colab](https://github.com/bloc97/CrossAttentionControl). diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index 9712ddf1bd..99ef1d49bc 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -594,12 +594,12 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings original_text_key = attn.to_k(original_text_embeddings) modified_text_key = attn.to_k(modified_text_embeddings) - #original_value = attn.to_v(original_text_embeddings) + original_value = attn.to_v(original_text_embeddings) modified_value = attn.to_v(modified_text_embeddings) original_text_key = attn.head_to_batch_dim(original_text_key) modified_text_key = attn.head_to_batch_dim(modified_text_key) - #original_value = attn.head_to_batch_dim(original_value) + original_value = attn.head_to_batch_dim(original_value) modified_value = attn.head_to_batch_dim(modified_value) # compute slices and prepare output tensor @@ -636,7 +636,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): del remapped_original_attn_slice, modified_attn_slice - attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx]) + attn_slice = torch.bmm(attn_slice, original_value[start_idx:end_idx]) hidden_states[start_idx:end_idx] = attn_slice From 8ed8bf52d092326b532fac6d6f9db3f66f4f0291 Mon Sep 17 00:00:00 2001 From: damian Date: Thu, 26 Jan 2023 17:04:22 +0100 Subject: [PATCH 12/25] use 'auto' slice size --- ldm/invoke/generator/diffusers_pipeline.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index a16cbe594b..b7ad925c8c 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -313,8 +313,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline. pass else: - slice_size = 4 # or 2, or 8. i chose this arbitrarily. 
- self.enable_attention_slicing(slice_size=slice_size) + self.enable_attention_slicing(slice_size='auto') def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int, From 729752620b99c6231759723376d88c3fb262fca0 Mon Sep 17 00:00:00 2001 From: damian Date: Thu, 26 Jan 2023 17:27:33 +0100 Subject: [PATCH 13/25] trying out JPPhoto's patch on vast.ai --- ldm/invoke/generator/diffusers_pipeline.py | 13 ++++++++++--- ldm/models/diffusion/shared_invokeai_diffusion.py | 5 ++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index b7ad925c8c..971a6b4604 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -304,6 +304,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): textual_inversion_manager=self.textual_inversion_manager ) + self._enable_memory_efficient_attention() + + + def _enable_memory_efficient_attention(self): + """ + if xformers is available, use it, otherwise use sliced attention. + """ if is_xformers_available() and not Globals.disable_xformers: self.enable_xformers_memory_efficient_attention() else: @@ -315,7 +322,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): else: self.enable_attention_slicing(slice_size='auto') - def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int, conditioning_data: ConditioningData, *, @@ -360,6 +366,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device) timesteps = self.scheduler.timesteps infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState) + self._enable_memory_efficient_attention() result: PipelineIntermediateState = infer_latents_from_embeddings( latents, timesteps, conditioning_data, noise=noise, @@ -380,8 +387,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): additional_guidance = [] extra_conditioning_info = conditioning_data.extra with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info, - step_count=len(self.scheduler.timesteps), - do_attention_map_saving=False): + step_count=len(self.scheduler.timesteps) + ): yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps, latents=latents) diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index 0c91df9528..10ad328575 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -60,13 +60,12 @@ class InvokeAIDiffuserComponent: @contextmanager def custom_attention_context(self, extra_conditioning_info: Optional[ExtraConditioningInfo], - step_count: int, - do_attention_map_saving: bool): + step_count: int): do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control old_attn_processor = None if do_swap: old_attn_processor = self.setup_cross_attention_control(extra_conditioning_info, - step_count=step_count) + step_count=step_count) try: yield None finally: From fb312f9ed3bc2fd75942a7f6c877e1c5bcead7b5 Mon Sep 17 00:00:00 2001 From: damian Date: Thu, 26 Jan 2023 17:30:23 +0100 Subject: [PATCH 14/25] use the correct value - whoops --- ldm/models/diffusion/cross_attention_control.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
--git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index 99ef1d49bc..770c71f110 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -636,7 +636,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): del remapped_original_attn_slice, modified_attn_slice - attn_slice = torch.bmm(attn_slice, original_value[start_idx:end_idx]) + attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx]) hidden_states[start_idx:end_idx] = attn_slice From c381788ab9cc6e6acdc9ef87bf3d2b714917ddb0 Mon Sep 17 00:00:00 2001 From: damian Date: Thu, 26 Jan 2023 17:44:27 +0100 Subject: [PATCH 15/25] don't restore None --- ldm/models/diffusion/shared_invokeai_diffusion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index 10ad328575..1ecbd1c488 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -69,7 +69,8 @@ class InvokeAIDiffuserComponent: try: yield None finally: - self.remove_cross_attention_control(old_attn_processor) + if old_attn_processor is not None: + self.remove_cross_attention_control(old_attn_processor) # TODO resuscitate attention map saving #self.remove_attention_map_saving() From e090c0dc105f611ed9326c66dfd55098da0c2da7 Mon Sep 17 00:00:00 2001 From: damian Date: Thu, 26 Jan 2023 17:46:51 +0100 Subject: [PATCH 16/25] try without setting every time --- ldm/invoke/generator/diffusers_pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 971a6b4604..bdd1e5c76b 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -366,7 +366,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device) timesteps = self.scheduler.timesteps infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState) - self._enable_memory_efficient_attention() result: PipelineIntermediateState = infer_latents_from_embeddings( latents, timesteps, conditioning_data, noise=noise, From 1bb5b4ab322da5a0021a8804687db6c68388cbe1 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Fri, 27 Jan 2023 11:52:05 -0800 Subject: [PATCH 17/25] fix dimension errors when inpainting model is used with hires-fix --- ldm/invoke/generator/txt2img2img.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py index 47692a6bbb..4923e7daf5 100644 --- a/ldm/invoke/generator/txt2img2img.py +++ b/ldm/invoke/generator/txt2img2img.py @@ -3,10 +3,10 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator ''' import math -from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error from typing import Callable, Optional import torch +from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error from ldm.invoke.generator.base import Generator from ldm.invoke.generator.diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline, \ @@ -116,16 +116,20 @@ class Txt2Img2Img(Generator): scaled_height = height device = self.model.device + + channels = 
self.latent_channels + if channels == 9: + channels = 4 # we don't really want noise for all the mask channels if self.use_mps_noise or device.type == 'mps': return torch.randn([1, - self.latent_channels, + channels, scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor], dtype=self.torch_dtype(), device='cpu').to(device) else: return torch.randn([1, - self.latent_channels, + channels, scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor], dtype=self.torch_dtype(), From 09b6104bfd80212fe0ab528b9354b49094d1248b Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Fri, 27 Jan 2023 12:04:12 -0800 Subject: [PATCH 18/25] refactor(txt2img2img): factor out tensor shape --- ldm/invoke/generator/txt2img2img.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py index 4923e7daf5..058628ba1c 100644 --- a/ldm/invoke/generator/txt2img2img.py +++ b/ldm/invoke/generator/txt2img2img.py @@ -115,22 +115,13 @@ class Txt2Img2Img(Generator): scaled_width = width scaled_height = height - device = self.model.device - + device = self.model.device channels = self.latent_channels if channels == 9: channels = 4 # we don't really want noise for all the mask channels + shape = (1, channels, + scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor) if self.use_mps_noise or device.type == 'mps': - return torch.randn([1, - channels, - scaled_height // self.downsampling_factor, - scaled_width // self.downsampling_factor], - dtype=self.torch_dtype(), - device='cpu').to(device) + return torch.randn(shape, dtype=self.torch_dtype(), device='cpu').to(device) else: - return torch.randn([1, - channels, - scaled_height // self.downsampling_factor, - scaled_width // self.downsampling_factor], - dtype=self.torch_dtype(), - device=device) + return torch.randn(shape, dtype=self.torch_dtype(), device=device) From 07e03b31b7610a27d709b366a491e57085d9b393 Mon Sep 17 00:00:00 2001 From: Jonathan <34005131+JPPhoto@users.noreply.github.com> Date: Sun, 29 Jan 2023 12:27:01 -0600 Subject: [PATCH 19/25] Update --hires_fix (#2414) * Update --hires_fix Change `--hires_fix` to calculate initial width and height based on the model's resolution (if available) and with a minimum size. --- ldm/invoke/generator/txt2img2img.py | 31 ++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py index 47692a6bbb..696ba03596 100644 --- a/ldm/invoke/generator/txt2img2img.py +++ b/ldm/invoke/generator/txt2img2img.py @@ -38,10 +38,6 @@ class Txt2Img2Img(Generator): uc, c, cfg_scale, extra_conditioning_info, threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None) .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)) - scale_dim = min(width, height) - scale = 512 / scale_dim - - init_width, init_height = trim_to_multiple_of(scale * width, scale * height) def make_image(x_T): @@ -54,6 +50,10 @@ class Txt2Img2Img(Generator): # TODO: threshold = threshold, ) + # Get our initial generation width and height directly from the latent output so + # the message below is accurate. 
+ init_width = first_pass_latent_output.size()[3] * self.downsampling_factor + init_height = first_pass_latent_output.size()[2] * self.downsampling_factor print( f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling" ) @@ -106,11 +106,24 @@ class Txt2Img2Img(Generator): def get_noise(self,width,height,scale = True): # print(f"Get noise: {width}x{height}") if scale: - trained_square = 512 * 512 - actual_square = width * height - scale = math.sqrt(trained_square / actual_square) - scaled_width = math.ceil(scale * width / 64) * 64 - scaled_height = math.ceil(scale * height / 64) * 64 + # Scale the input width and height for the initial generation + # Make their area equivalent to the model's resolution area (e.g. 512*512 = 262144), + # while keeping the minimum dimension at least 0.5 * resolution (e.g. 512*0.5 = 256) + + aspect = width / height + dimension = self.model.unet.config.sample_size * self.model.vae_scale_factor + min_dimension = math.floor(dimension * 0.5) + model_area = dimension * dimension # hardcoded for now since all models are trained on square images + + if aspect > 1.0: + init_height = max(min_dimension, math.sqrt(model_area / aspect)) + init_width = init_height * aspect + else: + init_width = max(min_dimension, math.sqrt(model_area * aspect)) + init_height = init_width / aspect + + scaled_width, scaled_height = trim_to_multiple_of(math.floor(init_width), math.floor(init_height)) + else: scaled_width = width scaled_height = height From d9ed0f60054b7b0fa69f79bae4844d679200ac5d Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 29 Jan 2023 21:30:50 -0500 Subject: [PATCH 20/25] fix documentation of huggingface cache location (#2430) * fix documentation of huggingface cache location --------- Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com> --- docs/CHANGELOG.md | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index e9461264db..491e3b76e4 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -52,12 +52,17 @@ introduces several changes you should know about. path: models/diffusers/hakurei-haifu-diffusion-1.4 ``` -2. The format of the models directory has changed to mimic the - HuggingFace cache directory. By default, diffusers models are - now automatically downloaded and retrieved from the directory - `ROOTDIR/models/diffusers`, while other models are stored in - the directory `ROOTDIR/models/hub`. This organization is the - same as that used by HuggingFace for its cache management. +2. In order of precedence, InvokeAI will now use HF_HOME, then + XDG_CACHE_HOME, then finally default to `ROOTDIR/models` to + store HuggingFace diffusers models. + + Consequently, the format of the models directory has changed to + mimic the HuggingFace cache directory. When HF_HOME and XDG_HOME + are not set, diffusers models are now automatically downloaded + and retrieved from the directory `ROOTDIR/models/diffusers`, + while other models are stored in the directory + `ROOTDIR/models/hub`. This organization is the same as that used + by HuggingFace for its cache management. This allows you to share diffusers and ckpt model files easily with other machine learning applications that use the HuggingFace @@ -66,7 +71,13 @@ introduces several changes you should know about. cache models in. 
To tell InvokeAI to use the standard HuggingFace cache directory, you would set HF_HOME like this (Linux/Mac): - `export HF_HOME=~/.cache/hugging_face` + `export HF_HOME=~/.cache/huggingface` + + Both HuggingFace and InvokeAI will fall back to the XDG_CACHE_HOME + environment variable if HF_HOME is not set; this path + takes precedence over `ROOTDIR/models` to allow for the same sharing + with other machine learning applications that use HuggingFace + libraries. 3. If you upgrade to InvokeAI 2.3.* from an earlier version, there will be a one-time migration from the old models directory format From 27ee939e4b3f0929687057f50f75009dde5d09d1 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Mon, 30 Jan 2023 14:50:57 +0100 Subject: [PATCH 21/25] with diffusers cac, always run the original prompt on the first step --- ldm/models/diffusion/cross_attention_control.py | 7 ++++++- ldm/models/diffusion/shared_invokeai_diffusion.py | 5 ++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index 99ef1d49bc..cecbc869e9 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -108,7 +108,7 @@ class Context: return self.tokens_cross_attention_action == Context.Action.APPLY return False - def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\ + def get_active_cross_attention_control_types_for_step(self, percent_through:Optional[float]=None, step_size:Optional[float]=None)\ -> list[CrossAttentionType]: """ Should cross-attention control be applied on the given step? @@ -117,6 +117,11 @@ class Context: """ if percent_through is None: return [CrossAttentionType.SELF, CrossAttentionType.TOKENS] + if step_size is not None: + # adjust percent_through to ignore the first step + percent_through = (percent_through - step_size) / (1.0 - step_size) + if percent_through < 0: + return [] opts = self.arguments.edit_options to_control = [] diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index 0c91df9528..c4a571a21d 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -141,13 +141,16 @@ class InvokeAIDiffuserComponent: if step_index is not None and total_step_count is not None: # 🧨diffusers codepath percent_through = step_index / total_step_count # will never reach 1.0 - this is deliberate + step_size_percent = 1 / total_step_count else: # legacy compvis codepath # TODO remove when compvis codepath support is dropped if step_index is None and sigma is None: raise ValueError(f"Either step_index or sigma is required when doing cross attention control, but both are None.") percent_through = self.estimate_percent_through(step_index, sigma) - cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through) + # legacy code path supports s_* so we don't need step_size_percent + step_size_percent = None + cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through, step_size=step_size_percent) wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0) wants_hybrid_conditioning = isinstance(conditioning, dict) From 478c37953423984600d2c47b60dfffa46075b385 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Mon, 30 Jan 2023 15:30:01 +0100 Subject: [PATCH 22/25] for cac make 
t_start=0.1 the default --- ldm/invoke/prompt_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index 7729be78b3..420295c0b6 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -155,7 +155,7 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment): default_options = { 's_start': 0.0, 's_end': 0.2062994740159002, # ~= shape_freedom=0.5 - 't_start': 0.0, + 't_start': 0.1, 't_end': 1.0 } merged_options = default_options From 17d73d09c0abb1aa4541c91df885737cbc62651c Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Mon, 30 Jan 2023 15:38:03 +0100 Subject: [PATCH 23/25] Revert "with diffusers cac, always run the original prompt on the first step" This reverts commit 27ee939e4b3f0929687057f50f75009dde5d09d1. --- ldm/models/diffusion/cross_attention_control.py | 7 +------ ldm/models/diffusion/shared_invokeai_diffusion.py | 5 +---- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index 4bb0bc4829..770c71f110 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -108,7 +108,7 @@ class Context: return self.tokens_cross_attention_action == Context.Action.APPLY return False - def get_active_cross_attention_control_types_for_step(self, percent_through:Optional[float]=None, step_size:Optional[float]=None)\ + def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\ -> list[CrossAttentionType]: """ Should cross-attention control be applied on the given step? @@ -117,11 +117,6 @@ class Context: """ if percent_through is None: return [CrossAttentionType.SELF, CrossAttentionType.TOKENS] - if step_size is not None: - # adjust percent_through to ignore the first step - percent_through = (percent_through - step_size) / (1.0 - step_size) - if percent_through < 0: - return [] opts = self.arguments.edit_options to_control = [] diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index 160d3ae6e5..1ecbd1c488 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -141,16 +141,13 @@ class InvokeAIDiffuserComponent: if step_index is not None and total_step_count is not None: # 🧨diffusers codepath percent_through = step_index / total_step_count # will never reach 1.0 - this is deliberate - step_size_percent = 1 / total_step_count else: # legacy compvis codepath # TODO remove when compvis codepath support is dropped if step_index is None and sigma is None: raise ValueError(f"Either step_index or sigma is required when doing cross attention control, but both are None.") percent_through = self.estimate_percent_through(step_index, sigma) - # legacy code path supports s_* so we don't need step_size_percent - step_size_percent = None - cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through, step_size=step_size_percent) + cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through) wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0) wants_hybrid_conditioning = isinstance(conditioning, dict) From d044d4c577a235b2528a3ea21b38dcba6b6371ec Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Mon, 30 Jan 2023 16:23:44 +0100 Subject: [PATCH 
24/25] rename override/restore methods to better reflect what they actually do --- .../diffusion/cross_attention_control.py | 4 ++-- ldm/models/diffusion/ddim.py | 4 ++-- ldm/models/diffusion/ksampler.py | 4 ++-- ldm/models/diffusion/plms.py | 4 ++-- .../diffusion/shared_invokeai_diffusion.py | 24 +++++++++---------- 5 files changed, 20 insertions(+), 20 deletions(-) diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index 770c71f110..37f0ebfa1d 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -306,7 +306,7 @@ class InvokeAICrossAttentionMixin: -def remove_cross_attention_control(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None): +def restore_default_cross_attention(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None): if is_running_diffusers: unet = model unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor()) @@ -314,7 +314,7 @@ def remove_cross_attention_control(model, is_running_diffusers: bool, restore_at remove_attention_function(model) -def setup_cross_attention_control(model, context: Context, is_running_diffusers = False): +def override_cross_attention(model, context: Context, is_running_diffusers = False): """ Inject attention parameters and functions into the passed in model to enable cross attention editing. diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index e5a502f977..304009c1d3 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -19,9 +19,9 @@ class DDIMSampler(Sampler): all_timesteps_count = kwargs.get('all_timesteps_count', t_enc) if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: - self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count) + self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count) else: - self.invokeai_diffuser.remove_cross_attention_control() + self.invokeai_diffuser.restore_default_cross_attention() # This is the central routine diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index 0038c481e8..f98ca8de21 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -43,9 +43,9 @@ class CFGDenoiser(nn.Module): extra_conditioning_info = kwargs.get('extra_conditioning_info', None) if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: - self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc) + self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = t_enc) else: - self.invokeai_diffuser.remove_cross_attention_control() + self.invokeai_diffuser.restore_default_cross_attention() def forward(self, x, sigma, uncond, cond, cond_scale): diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index 5124badcd1..9edd333780 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -21,9 +21,9 @@ class PLMSSampler(Sampler): all_timesteps_count = kwargs.get('all_timesteps_count', t_enc) if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: - self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count) + 
self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count) else: - self.invokeai_diffuser.remove_cross_attention_control() + self.invokeai_diffuser.restore_default_cross_attention() # this is the essential routine diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index 1ecbd1c488..f37bec789e 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -9,7 +9,7 @@ import torch from diffusers.models.cross_attention import AttnProcessor from ldm.models.diffusion.cross_attention_control import Arguments, \ - remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, \ + restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \ CrossAttentionType, SwapCrossAttnContext from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver @@ -64,17 +64,17 @@ class InvokeAIDiffuserComponent: do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control old_attn_processor = None if do_swap: - old_attn_processor = self.setup_cross_attention_control(extra_conditioning_info, - step_count=step_count) + old_attn_processor = self.override_cross_attention(extra_conditioning_info, + step_count=step_count) try: yield None finally: if old_attn_processor is not None: - self.remove_cross_attention_control(old_attn_processor) + self.restore_default_cross_attention(old_attn_processor) # TODO resuscitate attention map saving #self.remove_attention_map_saving() - def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]: + def override_cross_attention(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]: """ setup cross attention .swap control. for diffusers this replaces the attention processor, so the previous attention processor is returned so that the caller can restore it later. @@ -84,16 +84,16 @@ class InvokeAIDiffuserComponent: arguments=self.conditioning.cross_attention_control_args, step_count=step_count ) - return setup_cross_attention_control(self.model, - self.cross_attention_control_context, - is_running_diffusers=self.is_running_diffusers) + return override_cross_attention(self.model, + self.cross_attention_control_context, + is_running_diffusers=self.is_running_diffusers) - def remove_cross_attention_control(self, restore_attention_processor: Optional['AttnProcessor']=None): + def restore_default_cross_attention(self, restore_attention_processor: Optional['AttnProcessor']=None): self.conditioning = None self.cross_attention_control_context = None - remove_cross_attention_control(self.model, - is_running_diffusers=self.is_running_diffusers, - restore_attention_processor=restore_attention_processor) + restore_default_cross_attention(self.model, + is_running_diffusers=self.is_running_diffusers, + restore_attention_processor=restore_attention_processor) def setup_attention_map_saving(self, saver: AttentionMapSaver): def callback(slice, dim, offset, slice_size, key): From fc8e3dbcd33adf92b6930284b32e9e56bf0d1e6c Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 31 Jan 2023 09:59:58 -0500 Subject: [PATCH 25/25] fix crash when editing name of model - fixes a spurious "unknown model name" error when trying to edit the short name of an existing model. 
- relaxes naming requirements to include the ':' and '/' characters in model names --- ldm/invoke/CLI.py | 4 ++-- ldm/invoke/model_manager.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ldm/invoke/CLI.py b/ldm/invoke/CLI.py index f5f0d5b86b..4a7b56eb5c 100644 --- a/ldm/invoke/CLI.py +++ b/ldm/invoke/CLI.py @@ -786,8 +786,8 @@ def _get_model_name(existing_names,completer,default_name:str='')->str: model_name = input(f'Short name for this model [{default_name}]: ').strip() if len(model_name)==0: model_name = default_name - if not re.match('^[\w._+-]+$',model_name): - print('** model name must contain only words, digits and the characters "._+-" **') + if not re.match('^[\w._+:/-]+$',model_name): + print('** model name must contain only words, digits and the characters "._+:/-" **') elif model_name != default_name and model_name in existing_names: print(f'** the name {model_name} is already in use. Pick another.') else: diff --git a/ldm/invoke/model_manager.py b/ldm/invoke/model_manager.py index 811e768764..dbc690ec54 100644 --- a/ldm/invoke/model_manager.py +++ b/ldm/invoke/model_manager.py @@ -125,7 +125,7 @@ class ModelManager(object): Set the default model. The change will not take effect until you call model_manager.commit() ''' - assert model_name in self.models,f"unknown model '{model_name}'" + assert model_name in self.model_names(), f"unknown model '{model_name}'" config = self.config for model in config:
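
A note on the core operation implemented by `SlicedSwapCrossAttnProcesser` in the patches above: each slice of the original prompt's attention map is re-indexed onto the edited prompt's token positions and then blended with the edited prompt's attention map under a per-token mask. The sketch below restates that blend in isolation with plain PyTorch on toy tensors; it is a minimal illustration only, and the function name, shapes and example index map are assumptions chosen for demonstration. The real processor additionally slices over the batch-times-heads dimension and reuses the key/value projections of the wrapped `CrossAttention` module.

```python
import torch


def blend_swapped_attention(original_probs: torch.Tensor,
                            modified_probs: torch.Tensor,
                            index_map: torch.Tensor,
                            mask: torch.Tensor) -> torch.Tensor:
    """Blend two attention maps for a .swap() edit (illustrative helper, not InvokeAI API).

    original_probs / modified_probs: (batch*heads, query_len, token_len) attention probabilities
    index_map: 1-D LongTensor mapping original-prompt token positions to modified-prompt positions
    mask: (token_len,) tensor, 1.0 where the original attention is kept, 0.0 where the edit wins
    """
    # line the original prompt's token columns up with the modified prompt's token order
    remapped_original = torch.index_select(original_probs, -1, index_map)
    # keep original attention for shared tokens, take the modified attention for swapped tokens
    return remapped_original * mask + modified_probs * (1 - mask)


if __name__ == "__main__":
    batch_x_heads, query_len, token_len = 2, 4, 6
    original = torch.softmax(torch.randn(batch_x_heads, query_len, token_len), dim=-1)
    modified = torch.softmax(torch.randn(batch_x_heads, query_len, token_len), dim=-1)
    # identity mapping except that tokens 2 and 3 trade places in the modified prompt
    index_map = torch.tensor([0, 1, 3, 2, 4, 5])
    # keep the original attention everywhere except token position 3, which is being swapped
    mask = torch.ones(token_len)
    mask[3] = 0.0
    blended = blend_swapped_attention(original, modified, index_map, mask)
    print(blended.shape)  # torch.Size([2, 4, 6])
```

Keeping the blend as a single multiply-and-add per slice is what lets the sliced and non-sliced code paths stay identical; as the final form of the patches above shows, `SwapCrossAttnProcessor` simply inherits from `SlicedSwapCrossAttnProcesser` with a massive slice size so the slicing loop runs exactly once.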