From bffe199ad72735b780eef4d184927e37463d6a09 Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Sat, 21 Jan 2023 20:54:18 +0100
Subject: [PATCH] SwapCrossAttnProcessor working - tested on mac CPU (MPS doesn't work)

---
 ldm/invoke/generator/diffusers_pipeline.py   |  13 +-
 .../swap_cross_attention_processor.py        | 160 ----------------
 .../diffusion/cross_attention_control.py     | 174 +++++++++++++++++-
 .../diffusion/shared_invokeai_diffusion.py   |  60 +++++-
 4 files changed, 226 insertions(+), 181 deletions(-)

diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py
index 54e9d555af..6f3cd14550 100644
--- a/ldm/invoke/generator/diffusers_pipeline.py
+++ b/ldm/invoke/generator/diffusers_pipeline.py
@@ -24,9 +24,6 @@ from ...models.diffusion import cross_attention_control
 from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
 from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
 
-# monkeypatch diffusers CrossAttention 🙈
-# this is to make prompt2prompt and (future) attention maps work
-attention.CrossAttention = cross_attention_control.InvokeAIDiffusersCrossAttention
 
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@@ -295,7 +292,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
-        self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
+        self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward, is_running_diffusers=True)
         use_full_precision = (precision == 'float32' or precision == 'autocast')
         self.textual_inversion_manager = TextualInversionManager(tokenizer=self.tokenizer,
                                                                  text_encoder=self.text_encoder,
@@ -389,6 +386,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         attention_map_saver: Optional[AttentionMapSaver] = None
 
         self.invokeai_diffuser.remove_attention_map_saving()
+
         for i, t in enumerate(self.progress_bar(timesteps)):
             batched_t.fill_(t)
             step_output = self.step(batched_t, latents, conditioning_data,
@@ -447,7 +445,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
         return step_output
 
-    def _unet_forward(self, latents, t, text_embeddings):
+    def _unet_forward(self, latents, t, text_embeddings, cross_attention_kwargs: Optional[dict[str,Any]] = None):
         """predict the noise residual"""
         if is_inpainting_model(self.unet) and latents.size(1) == 4:
             # Pad out normal non-inpainting inputs for an inpainting model.
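
The hunk above adds a cross_attention_kwargs parameter to _unet_forward, and the next hunk threads it into the UNet call. In the diffusers releases that ship diffusers.models.cross_attention (the module this patch builds on), the dictionary passed as cross_attention_kwargs to UNet2DConditionModel.forward is forwarded as extra keyword arguments into the installed attention processor's __call__, which is how swap_cross_attn_context reaches SwapCrossAttnProcessor later in this patch. A minimal sketch of that mechanism, not InvokeAI code; the processor name and the "note" kwarg are illustrative assumptions:

import torch
from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor

class KwargEchoProcessor(CrossAttnProcessor):
    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, note: str = None):
        # `note` arrives here via unet(..., cross_attention_kwargs={"note": ...})
        if note is not None:
            print(f"processor saw note={note!r} (cross-attention={encoder_hidden_states is not None})")
        # delegate the actual attention computation to the stock implementation
        return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.set_attn_processor(KwargEchoProcessor())
with torch.no_grad():
    noise_pred = unet(torch.randn(1, 4, 64, 64), 10,
                      encoder_hidden_states=torch.randn(1, 77, 768),
                      cross_attention_kwargs={"note": "hello"}).sample

Every key in the dict must match a keyword the installed processor accepts, which is why the swap context below always travels under the fixed name swap_cross_attn_context.
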
@@ -460,7 +458,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype)
             ).add_mask_channels(latents)
 
-        return self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
+        return self.unet(sample=latents,
+                         timestep=t,
+                         encoder_hidden_states=text_embeddings,
+                         cross_attention_kwargs=cross_attention_kwargs).sample
 
     def img2img_from_embeddings(self,
                                 init_image: Union[torch.FloatTensor, PIL.Image.Image],
diff --git a/ldm/invoke/generator/swap_cross_attention_processor.py b/ldm/invoke/generator/swap_cross_attention_processor.py
index e3aa5bc484..e69de29bb2 100644
--- a/ldm/invoke/generator/swap_cross_attention_processor.py
+++ b/ldm/invoke/generator/swap_cross_attention_processor.py
@@ -1,160 +0,0 @@
-
-"""
-# base implementation
-
-class CrossAttnProcessor:
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
-
-        query = attn.to_q(hidden_states)
-        query = attn.head_to_batch_dim(query)
-
-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
-
-        attention_probs = attn.get_attention_scores(query, key, attention_mask)
-        hidden_states = torch.bmm(attention_probs, value)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        return hidden_states
-
-"""
-import enum
-from dataclasses import field, dataclass
-
-import torch
-
-from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor
-
-class AttentionType(enum.Enum):
-    SELF = 1
-    TOKENS = 2
-
-@dataclass
-class SwapCrossAttnContext:
-
-    cross_attention_types_to_do: list[AttentionType]
-    modified_text_embeddings: torch.Tensor
-    index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt
-    mask: torch.Tensor # in the target space of the index_map
-
-    def __int__(self,
-                cac_types_to_do: [AttentionType],
-                modified_text_embeddings: torch.Tensor,
-                index_map: torch.Tensor,
-                mask: torch.Tensor):
-        self.cross_attention_types_to_do = cac_types_to_do
-        self.modified_text_embeddings = modified_text_embeddings
-        self.index_map = index_map
-        self.mask = mask
-
-    def wants_cross_attention_control(self, attn_type: AttentionType) -> bool:
-        return attn_type in self.cross_attention_types_to_do
-
-
-class SwapCrossAttnProcessor(CrossAttnProcessor):
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
-                 # kwargs
-                 cross_attention_swap_context_provider: SwapCrossAttnContext=None):
-
-        if cross_attention_swap_context_provider is None:
-            raise RuntimeError("a SwapCrossAttnContext instance must be passed via attention processor kwargs")
-
-        attention_type = AttentionType.SELF if encoder_hidden_states is None else AttentionType.TOKENS
-        # if cross-attention control is not in play, just call through to the base implementation.
-        if not cross_attention_swap_context_provider.wants_cross_attention_control(attention_type):
-            return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
-
-        batch_size, sequence_length, _ = hidden_states.shape
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
-
-        query = attn.to_q(hidden_states)
-        query = attn.head_to_batch_dim(query)
-
-        # helper function
-        def get_attention_probs(embeddings):
-            this_key = attn.to_k(embeddings)
-            this_key = attn.head_to_batch_dim(this_key)
-            return attn.get_attention_scores(query, this_key, attention_mask)
-
-        if attention_type == AttentionType.SELF:
-            # self attention has no remapping, it just bluntly copies the whole tensor
-            attention_probs = get_attention_probs(hidden_states)
-            value = attn.to_v(hidden_states)
-        else:
-            # tokens (cross) attention
-            # first, find attention probabilities for the "original" prompt
-            original_text_embeddings = encoder_hidden_states
-            original_attention_probs = get_attention_probs(original_text_embeddings)
-
-            # then, find attention probabilities for the "modified" prompt
-            modified_text_embeddings = cross_attention_swap_context_provider.modified_text_embeddings
-            modified_attention_probs = get_attention_probs(modified_text_embeddings)
-
-            # because the prompt modifications may result in token sequences shifted forwards or backwards,
-            # the original attention probabilities must be remapped to account for token index changes in the
-            # modified prompt
-            remapped_original_attention_probs = torch.index_select(original_attention_probs, -1,
-                                                                   cross_attention_swap_context_provider.index_map)
-
-            # only some tokens taken from the original attention probabilities. this is controlled by the mask.
-            mask = cross_attention_swap_context_provider.mask
-            inverse_mask = 1 - mask
-            attention_probs = \
-                remapped_original_attention_probs * mask + \
-                modified_attention_probs * inverse_mask
-
-            # for the "value" just use the modified text embeddings.
-            value = attn.to_v(modified_text_embeddings)
-
-        value = attn.head_to_batch_dim(value)
-
-        hidden_states = torch.bmm(attention_probs, value)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        return hidden_states
-
-
-class P2PCrossAttentionProc:
-
-    def __init__(self, head_size, upcast_attention, attn_maps_reweight):
-        super().__init__(head_size=head_size, upcast_attention=upcast_attention)
-        self.attn_maps_reweight = attn_maps_reweight
-
-    def __call__(self, hidden_states, query_proj, key_proj, value_proj, encoder_hidden_states, modified_text_embeddings):
-        batch_size, sequence_length, _ = hidden_states.shape
-        query = query_proj(hidden_states)
-
-        context = context if context is not None else hidden_states
-        attention_probs = []
-        original_text_embeddings = encoder_hidden_states
-        for context in [original_text_embeddings, modified_text_embeddings]:
-            key = key_proj(original_text_embeddings)
-            value = self.value_proj(original_text_embeddings)
-
-            query = self.head_to_batch_dim(query, self.head_size)
-            key = self.head_to_batch_dim(key, self.head_size)
-            value = self.head_to_batch_dim(value, self.head_size)
-
-            attention_probs.append(self.get_attention_scores(query, key))
-
-        merged_probs = self.attn_maps_reweight * torch.cat(attention_probs)
-        hidden_states = torch.bmm(attention_probs, value)
-        hidden_states = self.batch_to_head_dim(hidden_states)
-        return hidden_states
-
-proc = P2PCrossAttentionProc(unet.config.head_size, unet.config.upcast_attention, 0.6)
diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py
index 6eae824301..4b89b5bd56 100644
--- a/ldm/models/diffusion/cross_attention_control.py
+++ b/ldm/models/diffusion/cross_attention_control.py
@@ -9,6 +9,7 @@ from torch import nn
 
 from diffusers.models.unet_2d_condition import UNet2DConditionModel
 from ldm.invoke.devices import torch_dtype
 
+
 # adapted from bloc97's CrossAttentionControl colab
 # https://github.com/bloc97/CrossAttentionControl
@@ -304,11 +305,16 @@ class InvokeAICrossAttentionMixin:
 
 
 
-def remove_cross_attention_control(model):
-    remove_attention_function(model)
+def remove_cross_attention_control(model, is_running_diffusers: bool):
+    if is_running_diffusers:
+        unet = model
+        print("** need to know what cross attn processor to use by default, None in the following line is wrong")
+        unet.set_attn_processor(CrossAttnProcessor())
+    else:
+        remove_attention_function(model)
 
 
-def setup_cross_attention_control(model, context: Context):
+def setup_cross_attention_control(model, context: Context, is_running_diffusers = False):
     """
     Inject attention parameters and functions into the passed in model
     to enable cross attention editing.
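
Regarding the "** need to know what cross attn processor to use by default" note in the hunk above: one possible approach, an assumption on my part rather than something this patch does, is to snapshot the UNet's existing processors before installing SwapCrossAttnProcessor and restore that mapping on teardown, instead of unconditionally installing a fresh CrossAttnProcessor(). This assumes the installed diffusers version exposes the attn_processors property alongside set_attn_processor. A sketch:

from diffusers.models.cross_attention import CrossAttnProcessor
from ldm.models.diffusion.cross_attention_control import SwapCrossAttnProcessor  # added later in this patch

def install_swap_processor(unet):
    # snapshot whatever is currently installed (default, xformers, sliced, ...)
    previous = dict(unet.attn_processors)
    unet.set_attn_processor(SwapCrossAttnProcessor())
    return previous

def restore_attn_processors(unet, previous):
    # undo install_swap_processor() with the exact processors it replaced,
    # falling back to the stock processor if no snapshot was taken
    unet.set_attn_processor(previous if previous else CrossAttnProcessor())
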
@@ -333,10 +339,16 @@ def setup_cross_attention_control(model, context: Context):
                 indices[b0:b1] = indices_target[a0:a1]
                 mask[b0:b1] = 1
 
-    #context.register_cross_attention_modules(model)
     context.cross_attention_mask = mask.to(device)
     context.cross_attention_index_map = indices.to(device)
-    #inject_attention_function(model, context)
+    if is_running_diffusers:
+        unet = model
+        unet.set_attn_processor(SwapCrossAttnProcessor())
+    else:
+        context.register_cross_attention_modules(model)
+        inject_attention_function(model, context)
+
+
 
 
 def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
@@ -461,3 +473,155 @@ class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention,
         hidden_states = self.reshape_batch_dim_to_heads(attention_result)
         return hidden_states
+
+
+
+
+## 🧨diffusers implementation follows
+
+
+"""
+# base implementation
+
+class CrossAttnProcessor:
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query)
+
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+"""
+import enum
+from dataclasses import field, dataclass
+
+import torch
+
+from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor
+from ldm.models.diffusion.cross_attention_control import CrossAttentionType
+
+
+@dataclass
+class SwapCrossAttnContext:
+    modified_text_embeddings: torch.Tensor
+    index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt
+    mask: torch.Tensor # in the target space of the index_map
+    cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)
+
+    def __int__(self,
+                cac_types_to_do: [CrossAttentionType],
+                modified_text_embeddings: torch.Tensor,
+                index_map: torch.Tensor,
+                mask: torch.Tensor):
+        self.cross_attention_types_to_do = cac_types_to_do
+        self.modified_text_embeddings = modified_text_embeddings
+        self.index_map = index_map
+        self.mask = mask
+
+    def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool:
+        return attn_type in self.cross_attention_types_to_do
+
+    @classmethod
+    def make_mask_and_index_map(cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int) \
+            -> tuple[torch.Tensor, torch.Tensor]:
+
+        # mask=1 means use original prompt attention, mask=0 means use modified prompt attention
+        mask = torch.zeros(max_length)
+        indices_target = torch.arange(max_length, dtype=torch.long)
+        indices = torch.arange(max_length, dtype=torch.long)
+        for name, a0, a1, b0, b1 in edit_opcodes:
+            if b0 < max_length:
+                if name == "equal":
+                    # these tokens remain the same as in the original prompt
+                    indices[b0:b1] = indices_target[a0:a1]
+                    mask[b0:b1] = 1
+
+        return mask, indices
+
+
+class SwapCrossAttnProcessor(CrossAttnProcessor):
+
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
+                 # kwargs
+                 swap_cross_attn_context: SwapCrossAttnContext=None):
+
+        attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
+
+        # if cross-attention control is not in play, just call through to the base implementation.
+        if swap_cross_attn_context is None or not swap_cross_attn_context.wants_cross_attention_control(attention_type):
+            #print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
+            return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
+        #else:
+        #    print(f"SwapCrossAttnContext for {attention_type} active")
+
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query)
+
+        # helper function
+        def get_attention_probs(embeddings):
+            this_key = attn.to_k(embeddings)
+            this_key = attn.head_to_batch_dim(this_key)
+            return attn.get_attention_scores(query, this_key, attention_mask)
+
+        if attention_type == CrossAttentionType.SELF:
+            # self attention has no remapping, it just bluntly copies the whole tensor
+            attention_probs = get_attention_probs(hidden_states)
+            value = attn.to_v(hidden_states)
+        else:
+            # tokens (cross) attention
+            # first, find attention probabilities for the "original" prompt
+            original_text_embeddings = encoder_hidden_states
+            original_attention_probs = get_attention_probs(original_text_embeddings)
+
+            # then, find attention probabilities for the "modified" prompt
+            modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
+            modified_attention_probs = get_attention_probs(modified_text_embeddings)
+
+            # because the prompt modifications may result in token sequences shifted forwards or backwards,
+            # the original attention probabilities must be remapped to account for token index changes in the
+            # modified prompt
+            remapped_original_attention_probs = torch.index_select(original_attention_probs, -1,
+                                                                   swap_cross_attn_context.index_map)
+
+            # only some tokens taken from the original attention probabilities. this is controlled by the mask.
+            mask = swap_cross_attn_context.mask
+            inverse_mask = 1 - mask
+            attention_probs = \
+                remapped_original_attention_probs * mask + \
+                modified_attention_probs * inverse_mask
+
+            # for the "value" just use the modified text embeddings.
+            value = attn.to_v(modified_text_embeddings)
+
+        value = attn.head_to_batch_dim(value)
+
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py
index 3be6b10170..e4932f6ad8 100644
--- a/ldm/models/diffusion/shared_invokeai_diffusion.py
+++ b/ldm/models/diffusion/shared_invokeai_diffusion.py
@@ -1,14 +1,14 @@
 import math
 from dataclasses import dataclass
 from math import ceil
-from typing import Callable, Optional, Union
+from typing import Callable, Optional, Union, Any
 
 import numpy as np
 import torch
 
 from ldm.models.diffusion.cross_attention_control import Arguments, \
     remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, \
-    CrossAttentionType
+    CrossAttentionType, SwapCrossAttnContext
 from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
 
@@ -30,24 +30,28 @@ class InvokeAIDiffuserComponent:
 
     debug_thresholding = False
 
+
+    @dataclass
     class ExtraConditioningInfo:
-        def __init__(self, tokens_count_including_eos_bos:int, cross_attention_control_args: Optional[Arguments]):
-            self.tokens_count_including_eos_bos = tokens_count_including_eos_bos
-            self.cross_attention_control_args = cross_attention_control_args
+
+        tokens_count_including_eos_bos: int
+        cross_attention_control_args: Optional[Arguments] = None
 
         @property
         def wants_cross_attention_control(self):
             return self.cross_attention_control_args is not None
 
+
     def __init__(self, model, model_forward_callback:
-                 Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
-                 ):
+                 Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str,Any]]], torch.Tensor],
+                 is_running_diffusers: bool=False,
+                 ):
         """
         :param model: the unet model to pass through to cross attention control
        :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly.
        most likely, this should simply call model.forward(x, sigma, conditioning)
         """
         self.conditioning = None
         self.model = model
+        self.is_running_diffusers = is_running_diffusers
         self.model_forward_callback = model_forward_callback
         self.cross_attention_control_context = None
@@ -57,12 +61,14 @@ class InvokeAIDiffuserComponent:
             arguments=self.conditioning.cross_attention_control_args,
             step_count=step_count
         )
-        setup_cross_attention_control(self.model, self.cross_attention_control_context)
+        setup_cross_attention_control(self.model,
+                                      self.cross_attention_control_context,
+                                      is_running_diffusers=self.is_running_diffusers)
 
     def remove_cross_attention_control(self):
         self.conditioning = None
         self.cross_attention_control_context = None
-        remove_cross_attention_control(self.model)
+        remove_cross_attention_control(self.model, is_running_diffusers=self.is_running_diffusers)
 
     def setup_attention_map_saving(self, saver: AttentionMapSaver):
         def callback(slice, dim, offset, slice_size, key):
@@ -168,7 +174,41 @@ class InvokeAIDiffuserComponent:
 
         return unconditioned_next_x, conditioned_next_x
 
-    def apply_cross_attention_controlled_conditioning(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
+    def apply_cross_attention_controlled_conditioning(self,
+                                                      x: torch.Tensor,
+                                                      sigma,
+                                                      unconditioning,
+                                                      conditioning,
+                                                      cross_attention_control_types_to_do):
+        if self.is_running_diffusers:
+            return self.apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
+        else:
+            return self.apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
+
+    def apply_cross_attention_controlled_conditioning__diffusers(self,
+                                                                 x: torch.Tensor,
+                                                                 sigma,
+                                                                 unconditioning,
+                                                                 conditioning,
+                                                                 cross_attention_control_types_to_do):
+        context: Context = self.cross_attention_control_context
+
+        cross_attn_processor_context = SwapCrossAttnContext(modified_text_embeddings=context.arguments.edited_conditioning,
+                                                            index_map=context.cross_attention_index_map,
+                                                            mask=context.cross_attention_mask,
+                                                            cross_attention_types_to_do=[])
+        # no cross attention for unconditioning (negative prompt)
+        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning,
+                                                           {"swap_cross_attn_context": cross_attn_processor_context})
+
+        # do requested cross attention types for conditioning (positive prompt)
+        cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
+        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning,
+                                                         {"swap_cross_attn_context": cross_attn_processor_context})
+        return unconditioned_next_x, conditioned_next_x
+
+
+    def apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
         # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
         # slower non-batched path (20% slower on mac MPS)
         # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
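
Taken together, the pieces this patch adds wire up as follows at inference time: SwapCrossAttnProcessor is installed on the UNet, a SwapCrossAttnContext is built from difflib-style edit opcodes, and the context travels to the processor through cross_attention_kwargs, mirroring what apply_cross_attention_controlled_conditioning__diffusers does above. A hedged end-to-end sketch; unet, latents, t, original_embeddings, edited_embeddings and edit_opcodes are assumed inputs, with the opcodes being (tag, a0, a1, b0, b1) tuples such as difflib.SequenceMatcher.get_opcodes() produces:

from ldm.models.diffusion.cross_attention_control import (
    CrossAttentionType, SwapCrossAttnContext, SwapCrossAttnProcessor)

def predict_with_swap(unet, latents, t, original_embeddings, edited_embeddings, edit_opcodes):
    # build the token index remapping and the keep-original mask from the edit opcodes
    mask, index_map = SwapCrossAttnContext.make_mask_and_index_map(
        edit_opcodes, max_length=original_embeddings.shape[1])
    context = SwapCrossAttnContext(
        modified_text_embeddings=edited_embeddings,
        index_map=index_map.to(latents.device),
        mask=mask.to(latents.device),
        cross_attention_types_to_do=[CrossAttentionType.TOKENS])
    unet.set_attn_processor(SwapCrossAttnProcessor())
    # the original embeddings go in as usual; the processor swaps in the edited ones
    # for the requested cross-attention layers via the kwargs below
    return unet(latents, t, encoder_hidden_states=original_embeddings,
                cross_attention_kwargs={"swap_cross_attn_context": context}).sample
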