diff --git a/docs/features/PROMPTS.md b/docs/features/PROMPTS.md index 79166a9b94..5413cc5e55 100644 --- a/docs/features/PROMPTS.md +++ b/docs/features/PROMPTS.md @@ -239,28 +239,24 @@ Generate an image with a given prompt, record the seed of the image, and then use the `prompt2prompt` syntax to substitute words in the original prompt for words in a new prompt. This works for `img2img` as well. -- `a ("fluffy cat").swap("smiling dog") eating a hotdog`. - - quotes optional: `a (fluffy cat).swap(smiling dog) eating a hotdog`. - - for single word substitutions parentheses are also optional: - `a cat.swap(dog) eating a hotdog`. -- Supports options `s_start`, `s_end`, `t_start`, `t_end` (each 0-1) loosely - corresponding to bloc97's `prompt_edit_spatial_start/_end` and - `prompt_edit_tokens_start/_end` but with the math swapped to make it easier to - intuitively understand. - - Example usage:`a (cat).swap(dog, s_end=0.3) eating a hotdog` - the `s_end` - argument means that the "spatial" (self-attention) edit will stop having any - effect after 30% (=0.3) of the steps have been done, leaving Stable - Diffusion with 70% of the steps where it is free to decide for itself how to - reshape the cat-form into a dog form. - - The numbers represent a percentage through the step sequence where the edits - should happen. 0 means the start (noisy starting image), 1 is the end (final - image). - - For img2img, the step sequence does not start at 0 but instead at - (1-strength) - so if strength is 0.7, s_start and s_end must both be - greater than 0.3 (1-0.7) to have any effect. -- Convenience option `shape_freedom` (0-1) to specify how much "freedom" Stable - Diffusion should have to change the shape of the subject being swapped. - - `a (cat).swap(dog, shape_freedom=0.5) eating a hotdog`. +For example, consider the prompt `a cat.swap(dog) playing with a ball in the forest`. Normally, because of the way words interact with each other during Stable Diffusion image generation, these two prompts would generate different compositions: + - `a cat playing with a ball in the forest` + - `a dog playing with a ball in the forest` + +| `a cat playing with a ball in the forest` | `a dog playing with a ball in the forest` | +| --- | --- | +| img | img | + + + - For multiple word swaps, use parentheses: `a (fluffy cat).swap(barking dog) playing with a ball in the forest`. + - To swap a comma, use quotes: `a ("fluffy, grey cat").swap("big, barking dog") playing with a ball in the forest`. +- Supports options `t_start` and `t_end` (each 0-1) loosely corresponding to bloc97's `prompt_edit_tokens_start/_end` but with the math swapped to make it easier to + intuitively understand. `t_start` and `t_end` control which steps cross-attention control runs on. With the default values `t_start=0` and `t_end=1`, cross-attention control is active on every step of image generation. Other values can be used to turn cross-attention control off for part of the image generation process. + - For example, for a 10-step diffusion of the prompt `a cat.swap(dog, t_start=0.3, t_end=1.0) playing with a ball in the forest`, the first 3 steps will be run as `a cat playing with a ball in the forest`, while the last 7 steps will run as `a dog playing with a ball in the forest`, but the pixels that represent `dog` will be locked to the pixels that would have represented `cat` if the `cat` prompt had been used instead.
+ - Conversely, for `a cat.swap(dog, t_start=0, t_end=0.7) playing with a ball in the forest`, the first 7 steps will run as `a dog playing with a ball in the forest` with the pixels that represent `dog` locked to the same pixels that would have represented `cat` if the `cat` prompt had been used instead. The final 3 steps will just run `a cat playing with a ball in the forest`. + > For img2img, the step sequence does not start at 0 but instead at `(1.0-strength)` - so if the img2img `strength` is `0.7`, `t_start` and `t_end` must both be greater than `0.3` (`1.0-0.7`) to have any effect. + +Prompt2prompt `.swap()` is not compatible with xformers, which will be temporarily disabled when doing a `.swap()` - so you should expect to use more VRAM and run slower than with xformers enabled. The `prompt2prompt` code is based on [bloc97's colab](https://github.com/bloc97/CrossAttentionControl). diff --git a/ldm/invoke/CLI.py b/ldm/invoke/CLI.py index f5f0d5b86b..4a7b56eb5c 100644 --- a/ldm/invoke/CLI.py +++ b/ldm/invoke/CLI.py @@ -786,8 +786,8 @@ def _get_model_name(existing_names,completer,default_name:str='')->str: model_name = input(f'Short name for this model [{default_name}]: ').strip() if len(model_name)==0: model_name = default_name - if not re.match('^[\w._+-]+$',model_name): - print('** model name must contain only words, digits and the characters "._+-" **') + if not re.match('^[\w._+:/-]+$',model_name): + print('** model name must contain only words, digits and the characters "._+:/-" **') elif model_name != default_name and model_name in existing_names: print(f'** the name {model_name} is already in use. Pick another.') else: diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 388d7a3342..1ef89913c5 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -24,9 +24,6 @@ from ...models.diffusion import cross_attention_control from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter -# monkeypatch diffusers CrossAttention 🙈 -# this is to make prompt2prompt and (future) attention maps work -attention.CrossAttention = cross_attention_control.InvokeAIDiffusersCrossAttention from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput @@ -295,7 +292,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward) + self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward, is_running_diffusers=True) use_full_precision = (precision == 'float32' or precision == 'autocast') self.textual_inversion_manager = TextualInversionManager(tokenizer=self.tokenizer, text_encoder=self.text_encoder, @@ -307,8 +304,23 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): textual_inversion_manager=self.textual_inversion_manager ) + self._enable_memory_efficient_attention() + + + def _enable_memory_efficient_attention(self): + """ + if xformers is available, use it, otherwise use sliced attention.
+ """ if is_xformers_available() and not Globals.disable_xformers: self.enable_xformers_memory_efficient_attention() + else: + if torch.backends.mps.is_available(): + # until pytorch #91617 is fixed, slicing is borked on MPS + # https://github.com/pytorch/pytorch/issues/91617 + # fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline. + pass + else: + self.enable_attention_slicing(slice_size='auto') def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int, conditioning_data: ConditioningData, @@ -373,42 +385,40 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if additional_guidance is None: additional_guidance = [] extra_conditioning_info = conditioning_data.extra - if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: - self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, - step_count=len(self.scheduler.timesteps)) - else: - self.invokeai_diffuser.remove_cross_attention_control() + with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info, + step_count=len(self.scheduler.timesteps) + ): - yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps, - latents=latents) + yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps, + latents=latents) - batch_size = latents.shape[0] - batched_t = torch.full((batch_size,), timesteps[0], - dtype=timesteps.dtype, device=self.unet.device) - latents = self.scheduler.add_noise(latents, noise, batched_t) + batch_size = latents.shape[0] + batched_t = torch.full((batch_size,), timesteps[0], + dtype=timesteps.dtype, device=self.unet.device) + latents = self.scheduler.add_noise(latents, noise, batched_t) - attention_map_saver: Optional[AttentionMapSaver] = None - self.invokeai_diffuser.remove_attention_map_saving() - for i, t in enumerate(self.progress_bar(timesteps)): - batched_t.fill_(t) - step_output = self.step(batched_t, latents, conditioning_data, - step_index=i, - total_step_count=len(timesteps), - additional_guidance=additional_guidance) - latents = step_output.prev_sample - predicted_original = getattr(step_output, 'pred_original_sample', None) + attention_map_saver: Optional[AttentionMapSaver] = None - if i == len(timesteps)-1 and extra_conditioning_info is not None: - eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1 - attention_map_token_ids = range(1, eos_token_index) - attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:]) - self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver) + for i, t in enumerate(self.progress_bar(timesteps)): + batched_t.fill_(t) + step_output = self.step(batched_t, latents, conditioning_data, + step_index=i, + total_step_count=len(timesteps), + additional_guidance=additional_guidance) + latents = step_output.prev_sample + predicted_original = getattr(step_output, 'pred_original_sample', None) - yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents, - predicted_original=predicted_original, attention_map_saver=attention_map_saver) + # TODO resuscitate attention map saving + #if i == len(timesteps)-1 and extra_conditioning_info is not None: + # eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1 + # attention_map_token_ids = range(1, eos_token_index) + # attention_map_saver = 
AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:]) + # self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver) - self.invokeai_diffuser.remove_attention_map_saving() - return latents, attention_map_saver + yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents, + predicted_original=predicted_original, attention_map_saver=attention_map_saver) + + return latents, attention_map_saver @torch.inference_mode() def step(self, t: torch.Tensor, latents: torch.Tensor, @@ -447,7 +457,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): return step_output - def _unet_forward(self, latents, t, text_embeddings): + def _unet_forward(self, latents, t, text_embeddings, cross_attention_kwargs: Optional[dict[str,Any]] = None): """predict the noise residual""" if is_inpainting_model(self.unet) and latents.size(1) == 4: # Pad out normal non-inpainting inputs for an inpainting model. @@ -460,7 +470,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype) ).add_mask_channels(latents) - return self.unet(latents, t, encoder_hidden_states=text_embeddings).sample + return self.unet(sample=latents, + timestep=t, + encoder_hidden_states=text_embeddings, + cross_attention_kwargs=cross_attention_kwargs).sample def img2img_from_embeddings(self, init_image: Union[torch.FloatTensor, PIL.Image.Image], @@ -531,6 +544,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB')) init_image = init_image.to(device=device, dtype=latents_dtype) + mask = mask.to(device=device, dtype=latents_dtype) if init_image.dim() == 3: init_image = init_image.unsqueeze(0) @@ -549,17 +563,22 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if mask.dim() == 3: mask = mask.unsqueeze(0) - mask = tv_resize(mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR) \ + latent_mask = tv_resize(mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR) \ .to(device=device, dtype=latents_dtype) guidance: List[Callable] = [] if is_inpainting_model(self.unet): + # You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint + # (that's why there's a mask!) but it seems to really want that blanked out. + masked_init_image = init_image * torch.where(mask < 0.5, 1, 0) + masked_latents = self.non_noised_latents_from_image(masked_init_image, device=device, dtype=latents_dtype) + # TODO: we should probably pass this in so we don't have to try/finally around setting it. 
self.invokeai_diffuser.model_forward_callback = \ - AddsMaskLatents(self._unet_forward, mask, init_image_latents) + AddsMaskLatents(self._unet_forward, latent_mask, masked_latents) else: - guidance.append(AddsMaskGuidance(mask, init_image_latents, self.scheduler, noise)) + guidance.append(AddsMaskGuidance(latent_mask, init_image_latents, self.scheduler, noise)) try: result_latents, result_attention_maps = self.latents_from_embeddings( diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py index 696ba03596..1c398fb95d 100644 --- a/ldm/invoke/generator/txt2img2img.py +++ b/ldm/invoke/generator/txt2img2img.py @@ -3,10 +3,10 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator ''' import math -from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error from typing import Callable, Optional import torch +from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error from ldm.invoke.generator.base import Generator from ldm.invoke.generator.diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline, \ @@ -128,18 +128,13 @@ class Txt2Img2Img(Generator): scaled_width = width scaled_height = height - device = self.model.device + device = self.model.device + channels = self.latent_channels + if channels == 9: + channels = 4 # we don't really want noise for all the mask channels + shape = (1, channels, + scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor) if self.use_mps_noise or device.type == 'mps': - return torch.randn([1, - self.latent_channels, - scaled_height // self.downsampling_factor, - scaled_width // self.downsampling_factor], - dtype=self.torch_dtype(), - device='cpu').to(device) + return torch.randn(shape, dtype=self.torch_dtype(), device='cpu').to(device) else: - return torch.randn([1, - self.latent_channels, - scaled_height // self.downsampling_factor, - scaled_width // self.downsampling_factor], - dtype=self.torch_dtype(), - device=device) + return torch.randn(shape, dtype=self.torch_dtype(), device=device) diff --git a/ldm/invoke/model_manager.py b/ldm/invoke/model_manager.py index 811e768764..dbc690ec54 100644 --- a/ldm/invoke/model_manager.py +++ b/ldm/invoke/model_manager.py @@ -125,7 +125,7 @@ class ModelManager(object): Set the default model. 
The change will not take effect until you call model_manager.commit() ''' - assert model_name in self.models,f"unknown model '{model_name}'" + assert model_name in self.model_names(), f"unknown model '{model_name}'" config = self.config for model in config: diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index 7729be78b3..420295c0b6 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -155,7 +155,7 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment): default_options = { 's_start': 0.0, 's_end': 0.2062994740159002, # ~= shape_freedom=0.5 - 't_start': 0.0, + 't_start': 0.1, 't_end': 1.0 } merged_options = default_options diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index 03d5a5bcec..37f0ebfa1d 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -7,8 +7,10 @@ import torch import diffusers from torch import nn from diffusers.models.unet_2d_condition import UNet2DConditionModel +from diffusers.models.cross_attention import AttnProcessor from ldm.invoke.devices import torch_dtype + # adapted from bloc97's CrossAttentionControl colab # https://github.com/bloc97/CrossAttentionControl @@ -304,11 +306,15 @@ class InvokeAICrossAttentionMixin: -def remove_cross_attention_control(model): - remove_attention_function(model) +def restore_default_cross_attention(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None): + if is_running_diffusers: + unet = model + unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor()) + else: + remove_attention_function(model) -def setup_cross_attention_control(model, context: Context): +def override_cross_attention(model, context: Context, is_running_diffusers = False): """ Inject attention parameters and functions into the passed in model to enable cross attention editing. @@ -323,7 +329,7 @@ def setup_cross_attention_control(model, context: Context): # urgh. should this be hardcoded? 
max_length = 77 # mask=1 means use base prompt attention, mask=0 means use edited prompt attention - mask = torch.zeros(max_length) + mask = torch.zeros(max_length, dtype=torch_dtype(device)) indices_target = torch.arange(max_length, dtype=torch.long) indices = torch.arange(max_length, dtype=torch.long) for name, a0, a1, b0, b1 in context.arguments.edit_opcodes: @@ -333,10 +339,26 @@ def setup_cross_attention_control(model, context: Context): indices[b0:b1] = indices_target[a0:a1] mask[b0:b1] = 1 - context.register_cross_attention_modules(model) context.cross_attention_mask = mask.to(device) context.cross_attention_index_map = indices.to(device) - inject_attention_function(model, context) + if is_running_diffusers: + unet = model + old_attn_processors = unet.attn_processors + if torch.backends.mps.is_available(): + # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS + unet.set_attn_processor(SwapCrossAttnProcessor()) + else: + # try to re-use an existing slice size + default_slice_size = 4 + slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size) + unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size)) + return old_attn_processors + else: + context.register_cross_attention_modules(model) + inject_attention_function(model, context) + return None + + def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]: @@ -445,6 +467,7 @@ def get_mem_free_total(device): return mem_free_total + class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin): def __init__(self, **kwargs): @@ -460,3 +483,176 @@ class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, hidden_states = self.reshape_batch_dim_to_heads(attention_result) return hidden_states + + + + +## 🧨diffusers implementation follows + + +""" +# base implementation + +class CrossAttnProcessor: + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + +""" +from dataclasses import field, dataclass + +import torch + +from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor, AttnProcessor + + +@dataclass +class SwapCrossAttnContext: + modified_text_embeddings: torch.Tensor + index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt + mask: torch.Tensor # in the target space of the index_map + cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list) + + def __int__(self, + cac_types_to_do: [CrossAttentionType], + 
modified_text_embeddings: torch.Tensor, + index_map: torch.Tensor, + mask: torch.Tensor): + self.cross_attention_types_to_do = cac_types_to_do + self.modified_text_embeddings = modified_text_embeddings + self.index_map = index_map + self.mask = mask + + def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool: + return attn_type in self.cross_attention_types_to_do + + @classmethod + def make_mask_and_index_map(cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int) \ + -> tuple[torch.Tensor, torch.Tensor]: + + # mask=1 means use original prompt attention, mask=0 means use modified prompt attention + mask = torch.zeros(max_length) + indices_target = torch.arange(max_length, dtype=torch.long) + indices = torch.arange(max_length, dtype=torch.long) + for name, a0, a1, b0, b1 in edit_opcodes: + if b0 < max_length: + if name == "equal": + # these tokens remain the same as in the original prompt + indices[b0:b1] = indices_target[a0:a1] + mask[b0:b1] = 1 + + return mask, indices + + +class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): + + # TODO: dynamically pick slice size based on memory conditions + + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, + # kwargs + swap_cross_attn_context: SwapCrossAttnContext=None): + + attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS + + # if cross-attention control is not in play, just call through to the base implementation. + if attention_type is CrossAttentionType.SELF or \ + swap_cross_attn_context is None or \ + not swap_cross_attn_context.wants_cross_attention_control(attention_type): + #print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass") + return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask) + #else: + # print(f"SwapCrossAttnContext for {attention_type} active") + + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + + original_text_embeddings = encoder_hidden_states + modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings + original_text_key = attn.to_k(original_text_embeddings) + modified_text_key = attn.to_k(modified_text_embeddings) + original_value = attn.to_v(original_text_embeddings) + modified_value = attn.to_v(modified_text_embeddings) + + original_text_key = attn.head_to_batch_dim(original_text_key) + modified_text_key = attn.head_to_batch_dim(modified_text_key) + original_value = attn.head_to_batch_dim(original_value) + modified_value = attn.head_to_batch_dim(modified_value) + + # compute slices and prepare output tensor + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype + ) + + # do slices + for i in range(max(1,hidden_states.shape[0] // self.slice_size)): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + + query_slice = query[start_idx:end_idx] + original_key_slice = original_text_key[start_idx:end_idx] + modified_key_slice = modified_text_key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice) + 
modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice) + + # because the prompt modifications may result in token sequences shifted forwards or backwards, + # the original attention probabilities must be remapped to account for token index changes in the + # modified prompt + remapped_original_attn_slice = torch.index_select(original_attn_slice, -1, + swap_cross_attn_context.index_map) + + # only some tokens taken from the original attention probabilities. this is controlled by the mask. + mask = swap_cross_attn_context.mask + inverse_mask = 1 - mask + attn_slice = \ + remapped_original_attn_slice * mask + \ + modified_attn_slice * inverse_mask + + del remapped_original_attn_slice, modified_attn_slice + + attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx]) + hidden_states[start_idx:end_idx] = attn_slice + + + # done + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser): + + def __init__(self): + super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9)) # massive slice size = don't slice + diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index e5a502f977..304009c1d3 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -19,9 +19,9 @@ class DDIMSampler(Sampler): all_timesteps_count = kwargs.get('all_timesteps_count', t_enc) if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: - self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count) + self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count) else: - self.invokeai_diffuser.remove_cross_attention_control() + self.invokeai_diffuser.restore_default_cross_attention() # This is the central routine diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index 0038c481e8..f98ca8de21 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -43,9 +43,9 @@ class CFGDenoiser(nn.Module): extra_conditioning_info = kwargs.get('extra_conditioning_info', None) if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: - self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc) + self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = t_enc) else: - self.invokeai_diffuser.remove_cross_attention_control() + self.invokeai_diffuser.restore_default_cross_attention() def forward(self, x, sigma, uncond, cond, cond_scale): diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index 5124badcd1..9edd333780 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -21,9 +21,9 @@ class PLMSSampler(Sampler): all_timesteps_count = kwargs.get('all_timesteps_count', t_enc) if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: - self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count) + self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count) else: - self.invokeai_diffuser.remove_cross_attention_control() + 
self.invokeai_diffuser.restore_default_cross_attention() # this is the essential routine diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index 3be6b10170..f37bec789e 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -1,14 +1,16 @@ import math +from contextlib import contextmanager from dataclasses import dataclass from math import ceil -from typing import Callable, Optional, Union +from typing import Callable, Optional, Union, Any, Dict import numpy as np import torch +from diffusers.models.cross_attention import AttnProcessor from ldm.models.diffusion.cross_attention_control import Arguments, \ - remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, \ - CrossAttentionType + restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \ + CrossAttentionType, SwapCrossAttnContext from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver @@ -30,39 +32,68 @@ class InvokeAIDiffuserComponent: debug_thresholding = False + @dataclass class ExtraConditioningInfo: - def __init__(self, tokens_count_including_eos_bos:int, cross_attention_control_args: Optional[Arguments]): - self.tokens_count_including_eos_bos = tokens_count_including_eos_bos - self.cross_attention_control_args = cross_attention_control_args + + tokens_count_including_eos_bos: int + cross_attention_control_args: Optional[Arguments] = None @property def wants_cross_attention_control(self): return self.cross_attention_control_args is not None + def __init__(self, model, model_forward_callback: - Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] - ): + Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str,Any]]], torch.Tensor], + is_running_diffusers: bool=False, + ): """ :param model: the unet model to pass through to cross attention control :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning) """ self.conditioning = None self.model = model + self.is_running_diffusers = is_running_diffusers self.model_forward_callback = model_forward_callback self.cross_attention_control_context = None - def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int): + @contextmanager + def custom_attention_context(self, + extra_conditioning_info: Optional[ExtraConditioningInfo], + step_count: int): + do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control + old_attn_processor = None + if do_swap: + old_attn_processor = self.override_cross_attention(extra_conditioning_info, + step_count=step_count) + try: + yield None + finally: + if old_attn_processor is not None: + self.restore_default_cross_attention(old_attn_processor) + # TODO resuscitate attention map saving + #self.remove_attention_map_saving() + + def override_cross_attention(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]: + """ + setup cross attention .swap control. for diffusers this replaces the attention processor, so + the previous attention processor is returned so that the caller can restore it later. 
+ """ self.conditioning = conditioning self.cross_attention_control_context = Context( arguments=self.conditioning.cross_attention_control_args, step_count=step_count ) - setup_cross_attention_control(self.model, self.cross_attention_control_context) + return override_cross_attention(self.model, + self.cross_attention_control_context, + is_running_diffusers=self.is_running_diffusers) - def remove_cross_attention_control(self): + def restore_default_cross_attention(self, restore_attention_processor: Optional['AttnProcessor']=None): self.conditioning = None self.cross_attention_control_context = None - remove_cross_attention_control(self.model) + restore_default_cross_attention(self.model, + is_running_diffusers=self.is_running_diffusers, + restore_attention_processor=restore_attention_processor) def setup_attention_map_saving(self, saver: AttentionMapSaver): def callback(slice, dim, offset, slice_size, key): @@ -168,7 +199,41 @@ class InvokeAIDiffuserComponent: return unconditioned_next_x, conditioned_next_x - def apply_cross_attention_controlled_conditioning(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do): + def apply_cross_attention_controlled_conditioning(self, + x: torch.Tensor, + sigma, + unconditioning, + conditioning, + cross_attention_control_types_to_do): + if self.is_running_diffusers: + return self.apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do) + else: + return self.apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do) + + def apply_cross_attention_controlled_conditioning__diffusers(self, + x: torch.Tensor, + sigma, + unconditioning, + conditioning, + cross_attention_control_types_to_do): + context: Context = self.cross_attention_control_context + + cross_attn_processor_context = SwapCrossAttnContext(modified_text_embeddings=context.arguments.edited_conditioning, + index_map=context.cross_attention_index_map, + mask=context.cross_attention_mask, + cross_attention_types_to_do=[]) + # no cross attention for unconditioning (negative prompt) + unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, + {"swap_cross_attn_context": cross_attn_processor_context}) + + # do requested cross attention types for conditioning (positive prompt) + cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do + conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, + {"swap_cross_attn_context": cross_attn_processor_context}) + return unconditioned_next_x, conditioned_next_x + + + def apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do): # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do) # slower non-batched path (20% slower on mac MPS) # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
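For reviewers who want to see the core prompt2prompt mechanism in isolation, here is a minimal, self-contained sketch (not part of this diff; the opcodes, tensor shapes, and values are invented for illustration) of what `SwapCrossAttnContext.make_mask_and_index_map` and `SlicedSwapCrossAttnProcesser` do together: build a mask and index map from the edit opcodes, remap the original prompt's attention scores onto the modified prompt's token positions with `torch.index_select`, and blend the two attention maps according to the mask.

```python
import torch

# Hypothetical example data - a real run derives these from the tokenized prompts.
max_length = 8
# difflib-style opcodes (name, a0, a1, b0, b1) describing how the original prompt
# maps onto the modified prompt, e.g. "a cat playing ..." -> "a dog playing ...":
# only token index 1 differs, everything else is unchanged.
edit_opcodes = [("equal", 0, 1, 0, 1), ("replace", 1, 2, 1, 2), ("equal", 2, 8, 2, 8)]

# mask=1 means "use the original prompt's attention", mask=0 means "use the modified prompt's".
mask = torch.zeros(max_length)
indices = torch.arange(max_length, dtype=torch.long)
indices_target = torch.arange(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in edit_opcodes:
    if b0 < max_length and name == "equal":
        indices[b0:b1] = indices_target[a0:a1]
        mask[b0:b1] = 1

# Stand-ins for per-query-position attention scores over the prompt tokens.
original_attn = torch.rand(4, max_length)   # attention computed against the original prompt
modified_attn = torch.rand(4, max_length)   # attention computed against the modified prompt

# Re-align the original attention to the modified prompt's token positions,
# then keep it only where the two prompts agree (mask == 1).
remapped_original = torch.index_select(original_attn, -1, indices)
blended_attn = remapped_original * mask + modified_attn * (1 - mask)
print(blended_attn.shape)  # torch.Size([4, 8])
```

In the actual processor this blending happens per attention slice, before the `torch.bmm` against the value projections of the modified prompt.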