SwapCrossAttnProcessor working - tested on mac CPU (MPS doesn't work)

Damian Stewart 2023-01-21 20:54:18 +01:00
parent 0c2a511671
commit bffe199ad7
4 changed files with 226 additions and 181 deletions

View File

@@ -24,9 +24,6 @@ from ...models.diffusion import cross_attention_control
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
# monkeypatch diffusers CrossAttention 🙈
# this is to make prompt2prompt and (future) attention maps work
attention.CrossAttention = cross_attention_control.InvokeAIDiffusersCrossAttention
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@@ -295,7 +292,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward, is_running_diffusers=True)
use_full_precision = (precision == 'float32' or precision == 'autocast')
self.textual_inversion_manager = TextualInversionManager(tokenizer=self.tokenizer,
text_encoder=self.text_encoder,
@@ -389,6 +386,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
attention_map_saver: Optional[AttentionMapSaver] = None
self.invokeai_diffuser.remove_attention_map_saving()
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t.fill_(t)
step_output = self.step(batched_t, latents, conditioning_data,
@@ -447,7 +445,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
return step_output
def _unet_forward(self, latents, t, text_embeddings):
def _unet_forward(self, latents, t, text_embeddings, cross_attention_kwargs: Optional[dict[str,Any]] = None):
"""predict the noise residual"""
if is_inpainting_model(self.unet) and latents.size(1) == 4:
# Pad out normal non-inpainting inputs for an inpainting model.
@@ -460,7 +458,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype)
).add_mask_channels(latents)
return self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
return self.unet(sample=latents,
timestep=t,
encoder_hidden_states=text_embeddings,
cross_attention_kwargs=cross_attention_kwargs).sample
def img2img_from_embeddings(self,
init_image: Union[torch.FloatTensor, PIL.Image.Image],

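For context on the cross_attention_kwargs plumbing added above: diffusers forwards whatever dict is passed to UNet2DConditionModel.forward on to each attention block's processor as extra keyword arguments, which is how swap_cross_attn_context reaches SwapCrossAttnProcessor later in this commit. A minimal sketch of that routing (not part of this commit; assumes diffusers ~0.12, where CrossAttnProcessor and set_attn_processor exist, and KwargsEchoProcessor is a made-up name):

from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor

class KwargsEchoProcessor(CrossAttnProcessor):
    # throwaway processor: report the extra kwarg arriving, then fall back to the default attention path
    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
                 swap_cross_attn_context=None):
        if swap_cross_attn_context is not None:
            print("processor received swap_cross_attn_context:", type(swap_cross_attn_context).__name__)
        return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)

# hypothetical wiring against an already-loaded pipeline:
#   pipeline.unet.set_attn_processor(KwargsEchoProcessor())
#   noise_pred = pipeline.unet(sample=latents, timestep=t,
#                              encoder_hidden_states=text_embeddings,
#                              cross_attention_kwargs={"swap_cross_attn_context": ctx}).sample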
View File

@@ -1,160 +0,0 @@
"""
# base implementation
class CrossAttnProcessor:
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
"""
import enum
from dataclasses import field, dataclass
import torch
from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor
class AttentionType(enum.Enum):
SELF = 1
TOKENS = 2
@dataclass
class SwapCrossAttnContext:
cross_attention_types_to_do: list[AttentionType]
modified_text_embeddings: torch.Tensor
index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt
mask: torch.Tensor # in the target space of the index_map
# __init__ is generated by the @dataclass decorator from the fields above
def wants_cross_attention_control(self, attn_type: AttentionType) -> bool:
return attn_type in self.cross_attention_types_to_do
class SwapCrossAttnProcessor(CrossAttnProcessor):
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
# kwargs
cross_attention_swap_context_provider: SwapCrossAttnContext=None):
if cross_attention_swap_context_provider is None:
raise RuntimeError("a SwapCrossAttnContext instance must be passed via attention processor kwargs")
attention_type = AttentionType.SELF if encoder_hidden_states is None else AttentionType.TOKENS
# if cross-attention control is not in play, just call through to the base implementation.
if not cross_attention_swap_context_provider.wants_cross_attention_control(attention_type):
return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)
# helper function
def get_attention_probs(embeddings):
this_key = attn.to_k(embeddings)
this_key = attn.head_to_batch_dim(this_key)
return attn.get_attention_scores(query, this_key, attention_mask)
if attention_type == AttentionType.SELF:
# self attention has no remapping, it just bluntly copies the whole tensor
attention_probs = get_attention_probs(hidden_states)
value = attn.to_v(hidden_states)
else:
# tokens (cross) attention
# first, find attention probabilities for the "original" prompt
original_text_embeddings = encoder_hidden_states
original_attention_probs = get_attention_probs(original_text_embeddings)
# then, find attention probabilities for the "modified" prompt
modified_text_embeddings = cross_attention_swap_context_provider.modified_text_embeddings
modified_attention_probs = get_attention_probs(modified_text_embeddings)
# because the prompt modifications may result in token sequences shifted forwards or backwards,
# the original attention probabilities must be remapped to account for token index changes in the
# modified prompt
remapped_original_attention_probs = torch.index_select(original_attention_probs, -1,
cross_attention_swap_context_provider.index_map)
# only some tokens are taken from the original attention probabilities; the mask controls which ones.
mask = cross_attention_swap_context_provider.mask
inverse_mask = 1 - mask
attention_probs = \
remapped_original_attention_probs * mask + \
modified_attention_probs * inverse_mask
# for the "value" just use the modified text embeddings.
value = attn.to_v(modified_text_embeddings)
value = attn.head_to_batch_dim(value)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class P2PCrossAttentionProc:

    def __init__(self, head_size, upcast_attention, attn_maps_reweight):
        self.head_size = head_size
        self.upcast_attention = upcast_attention
        # weight given to the modified prompt's attention map when blending with the original
        self.attn_maps_reweight = attn_maps_reweight

    def __call__(self, attn, hidden_states, encoder_hidden_states, modified_text_embeddings):
        batch_size, sequence_length, _ = hidden_states.shape
        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        # compute attention probabilities against both the original and the modified prompt embeddings
        original_text_embeddings = encoder_hidden_states
        attention_probs = []
        for context in [original_text_embeddings, modified_text_embeddings]:
            key = attn.to_k(context)
            key = attn.head_to_batch_dim(key)
            attention_probs.append(attn.get_attention_scores(query, key))

        # blend the two attention maps and apply the result to the modified prompt's values
        merged_probs = (1 - self.attn_maps_reweight) * attention_probs[0] \
                       + self.attn_maps_reweight * attention_probs[1]
        value = attn.to_v(modified_text_embeddings)
        value = attn.head_to_batch_dim(value)
        hidden_states = torch.bmm(merged_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        return hidden_states
proc = P2PCrossAttentionProc(unet.config.head_size, unet.config.upcast_attention, 0.6)

View File

@@ -9,6 +9,7 @@ from torch import nn
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from ldm.invoke.devices import torch_dtype
# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl
@@ -304,11 +305,16 @@ class InvokeAICrossAttentionMixin:
def remove_cross_attention_control(model):
remove_attention_function(model)
def remove_cross_attention_control(model, is_running_diffusers: bool):
if is_running_diffusers:
unet = model
print("** need to know what cross attn processor to use by default, None in the following line is wrong")
unet.set_attn_processor(CrossAttnProcessor())
else:
remove_attention_function(model)
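One way to resolve the warning above without hard-coding a default: snapshot the unet's existing processors before swapping them out, and restore that snapshot on removal. This is only a sketch, not what the commit does, and it assumes the unet exposes an attn_processors mapping (present in recent diffusers releases, hence the hasattr guard; the function names are illustrative):

from diffusers.models.cross_attention import CrossAttnProcessor

def install_processor(unet, processor):
    # returns a snapshot that can be handed back to restore_processors() later
    previous = dict(unet.attn_processors) if hasattr(unet, "attn_processors") else None
    unet.set_attn_processor(processor)
    return previous

def restore_processors(unet, previous):
    if previous is not None:
        unet.set_attn_processor(previous)              # per-module restore from the snapshot
    else:
        unet.set_attn_processor(CrossAttnProcessor())  # plain default, as in the code above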
def setup_cross_attention_control(model, context: Context):
def setup_cross_attention_control(model, context: Context, is_running_diffusers = False):
"""
Inject attention parameters and functions into the passed in model to enable cross attention editing.
@@ -333,10 +339,16 @@ def setup_cross_attention_control(model, context: Context):
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
#context.register_cross_attention_modules(model)
context.cross_attention_mask = mask.to(device)
context.cross_attention_index_map = indices.to(device)
#inject_attention_function(model, context)
if is_running_diffusers:
unet = model
unet.set_attn_processor(SwapCrossAttnProcessor())
else:
context.register_cross_attention_modules(model)
inject_attention_function(model, context)
def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
@@ -461,3 +473,155 @@ class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention,
hidden_states = self.reshape_batch_dim_to_heads(attention_result)
return hidden_states
## 🧨diffusers implementation follows
"""
# base implementation
class CrossAttnProcessor:
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
"""
import enum
from dataclasses import field, dataclass
import torch
from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor
from ldm.models.diffusion.cross_attention_control import CrossAttentionType
@dataclass
class SwapCrossAttnContext:
modified_text_embeddings: torch.Tensor
index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt
mask: torch.Tensor # in the target space of the index_map
cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)
# __init__ is generated by the @dataclass decorator from the fields above
def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool:
return attn_type in self.cross_attention_types_to_do
@classmethod
def make_mask_and_index_map(cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int) \
-> tuple[torch.Tensor, torch.Tensor]:
# mask=1 means use original prompt attention, mask=0 means use modified prompt attention
mask = torch.zeros(max_length)
indices_target = torch.arange(max_length, dtype=torch.long)
indices = torch.arange(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in edit_opcodes:
if b0 < max_length:
if name == "equal":
# these tokens remain the same as in the original prompt
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
return mask, indices
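As a concrete illustration of the mask/index-map semantics, here is a small check. It assumes the edit opcodes come from difflib.SequenceMatcher over the original and edited token id sequences (as the cross-attention-control Arguments elsewhere in this codebase produce them); the token ids themselves are invented:

import difflib
# SwapCrossAttnContext is the class defined above

original_tokens = [49406, 320, 2368, 525, 320, 3347, 49407]  # e.g. "a cat on a bench", ids made up
edited_tokens   = [49406, 320, 1929, 525, 320, 3347, 49407]  # e.g. "a dog on a bench", ids made up

opcodes = difflib.SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
# -> [('equal', 0, 2, 0, 2), ('replace', 2, 3, 2, 3), ('equal', 3, 7, 3, 7)]

mask, index_map = SwapCrossAttnContext.make_mask_and_index_map(opcodes, max_length=7)
print(mask)       # tensor([1., 1., 0., 1., 1., 1., 1.])  only the swapped token uses the modified attention
print(index_map)  # tensor([0, 1, 2, 3, 4, 5, 6])         identity here, since no tokens shifted position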
class SwapCrossAttnProcessor(CrossAttnProcessor):
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
# kwargs
swap_cross_attn_context: SwapCrossAttnContext=None):
attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
# if cross-attention control is not in play, just call through to the base implementation.
if swap_cross_attn_context is None or not swap_cross_attn_context.wants_cross_attention_control(attention_type):
#print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
#else:
# print(f"SwapCrossAttnContext for {attention_type} active")
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)
# helper function
def get_attention_probs(embeddings):
this_key = attn.to_k(embeddings)
this_key = attn.head_to_batch_dim(this_key)
return attn.get_attention_scores(query, this_key, attention_mask)
if attention_type == CrossAttentionType.SELF:
# self attention has no remapping, it just bluntly copies the whole tensor
attention_probs = get_attention_probs(hidden_states)
value = attn.to_v(hidden_states)
else:
# tokens (cross) attention
# first, find attention probabilities for the "original" prompt
original_text_embeddings = encoder_hidden_states
original_attention_probs = get_attention_probs(original_text_embeddings)
# then, find attention probabilities for the "modified" prompt
modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
modified_attention_probs = get_attention_probs(modified_text_embeddings)
# because the prompt modifications may result in token sequences shifted forwards or backwards,
# the original attention probabilities must be remapped to account for token index changes in the
# modified prompt
remapped_original_attention_probs = torch.index_select(original_attention_probs, -1,
swap_cross_attn_context.index_map)
# only some tokens are taken from the original attention probabilities; the mask controls which ones.
mask = swap_cross_attn_context.mask
inverse_mask = 1 - mask
attention_probs = \
remapped_original_attention_probs * mask + \
modified_attention_probs * inverse_mask
# for the "value" just use the modified text embeddings.
value = attn.to_v(modified_text_embeddings)
value = attn.head_to_batch_dim(value)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
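Taken together, a rough standalone exercise of the processor (outside the InvokeAI pipeline) could look like the following. The function name and tensor arguments are illustrative; step scheduling and construction of the edited conditioning remain the job of InvokeAIDiffuserComponent in the next file, and the other names come from the module shown above plus diffusers:

import torch
# SwapCrossAttnContext, SwapCrossAttnProcessor and CrossAttentionType as defined/imported above

def run_swapped_forward(unet, latents, timestep, original_embeddings, edited_embeddings, edit_opcodes):
    # build the remapping info from the prompt-edit opcodes
    max_length = original_embeddings.shape[1]
    mask, index_map = SwapCrossAttnContext.make_mask_and_index_map(edit_opcodes, max_length)
    context = SwapCrossAttnContext(modified_text_embeddings=edited_embeddings,
                                   index_map=index_map.to(latents.device),
                                   mask=mask.to(latents.device),
                                   cross_attention_types_to_do=[CrossAttentionType.TOKENS])
    # install the swapping processor on every attention block, then run one denoising forward pass
    unet.set_attn_processor(SwapCrossAttnProcessor())
    with torch.no_grad():
        return unet(sample=latents, timestep=timestep,
                    encoder_hidden_states=original_embeddings,
                    cross_attention_kwargs={"swap_cross_attn_context": context}).sample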

View File

@@ -1,14 +1,14 @@
import math
from dataclasses import dataclass
from math import ceil
from typing import Callable, Optional, Union
from typing import Callable, Optional, Union, Any
import numpy as np
import torch
from ldm.models.diffusion.cross_attention_control import Arguments, \
remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, \
CrossAttentionType
CrossAttentionType, SwapCrossAttnContext
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
@@ -30,24 +30,28 @@ class InvokeAIDiffuserComponent:
debug_thresholding = False
@dataclass
class ExtraConditioningInfo:
def __init__(self, tokens_count_including_eos_bos:int, cross_attention_control_args: Optional[Arguments]):
self.tokens_count_including_eos_bos = tokens_count_including_eos_bos
self.cross_attention_control_args = cross_attention_control_args
tokens_count_including_eos_bos: int
cross_attention_control_args: Optional[Arguments] = None
@property
def wants_cross_attention_control(self):
return self.cross_attention_control_args is not None
def __init__(self, model, model_forward_callback:
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
):
Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str,Any]]], torch.Tensor],
is_running_diffusers: bool=False,
):
"""
:param model: the unet model to pass through to cross attention control
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
"""
self.conditioning = None
self.model = model
self.is_running_diffusers = is_running_diffusers
self.model_forward_callback = model_forward_callback
self.cross_attention_control_context = None
@@ -57,12 +61,14 @@ class InvokeAIDiffuserComponent:
arguments=self.conditioning.cross_attention_control_args,
step_count=step_count
)
setup_cross_attention_control(self.model, self.cross_attention_control_context)
setup_cross_attention_control(self.model,
self.cross_attention_control_context,
is_running_diffusers=self.is_running_diffusers)
def remove_cross_attention_control(self):
self.conditioning = None
self.cross_attention_control_context = None
remove_cross_attention_control(self.model)
remove_cross_attention_control(self.model, is_running_diffusers=self.is_running_diffusers)
def setup_attention_map_saving(self, saver: AttentionMapSaver):
def callback(slice, dim, offset, slice_size, key):
@@ -168,7 +174,41 @@ class InvokeAIDiffuserComponent:
return unconditioned_next_x, conditioned_next_x
def apply_cross_attention_controlled_conditioning(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
def apply_cross_attention_controlled_conditioning(self,
x: torch.Tensor,
sigma,
unconditioning,
conditioning,
cross_attention_control_types_to_do):
if self.is_running_diffusers:
return self.apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
else:
return self.apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
def apply_cross_attention_controlled_conditioning__diffusers(self,
x: torch.Tensor,
sigma,
unconditioning,
conditioning,
cross_attention_control_types_to_do):
context: Context = self.cross_attention_control_context
cross_attn_processor_context = SwapCrossAttnContext(modified_text_embeddings=context.arguments.edited_conditioning,
index_map=context.cross_attention_index_map,
mask=context.cross_attention_mask,
cross_attention_types_to_do=[])
# no cross attention for unconditioning (negative prompt)
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning,
{"swap_cross_attn_context": cross_attn_processor_context})
# do requested cross attention types for conditioning (positive prompt)
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning,
{"swap_cross_attn_context": cross_attn_processor_context})
return unconditioned_next_x, conditioned_next_x
def apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
# slower non-batched path (20% slower on mac MPS)
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of