SwapCrossAttnProcessor working - tested on mac CPU (MPS doesn't work)

Damian Stewart
2023-01-21 20:54:18 +01:00
parent 0c2a511671
commit bffe199ad7
4 changed files with 226 additions and 181 deletions


@@ -9,6 +9,7 @@ from torch import nn
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from ldm.invoke.devices import torch_dtype
# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl
@@ -304,11 +305,16 @@ class InvokeAICrossAttentionMixin:
def remove_cross_attention_control(model):
remove_attention_function(model)
def remove_cross_attention_control(model, is_running_diffusers: bool):
if is_running_diffusers:
unet = model
print("** need to know what cross attn processor to use by default, None in the following line is wrong")
unet.set_attn_processor(CrossAttnProcessor())
else:
remove_attention_function(model)
def setup_cross_attention_control(model, context: Context):
def setup_cross_attention_control(model, context: Context, is_running_diffusers: bool = False):
"""
Inject attention parameters and functions into the passed in model to enable cross attention editing.
@@ -333,10 +339,16 @@ def setup_cross_attention_control(model, context: Context):
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
#context.register_cross_attention_modules(model)
context.cross_attention_mask = mask.to(device)
context.cross_attention_index_map = indices.to(device)
#inject_attention_function(model, context)
if is_running_diffusers:
unet = model
unet.set_attn_processor(SwapCrossAttnProcessor())
else:
context.register_cross_attention_modules(model)
inject_attention_function(model, context)
def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
@@ -461,3 +473,155 @@ class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention,
hidden_states = self.reshape_batch_dim_to_heads(attention_result)
return hidden_states
## 🧨diffusers implementation follows
"""
# base implementation (kept for reference only; wrapped in a string literal so it is not executed)
class CrossAttnProcessor:
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
"""
import enum
from dataclasses import field, dataclass
import torch
from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor
from ldm.models.diffusion.cross_attention_control import CrossAttentionType
@dataclass
class SwapCrossAttnContext:
modified_text_embeddings: torch.Tensor
index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt
mask: torch.Tensor # in the target space of the index_map
cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)
def __init__(self,
cross_attention_types_to_do: list[CrossAttentionType],
modified_text_embeddings: torch.Tensor,
index_map: torch.Tensor,
mask: torch.Tensor):
self.cross_attention_types_to_do = cross_attention_types_to_do
self.modified_text_embeddings = modified_text_embeddings
self.index_map = index_map
self.mask = mask
def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool:
return attn_type in self.cross_attention_types_to_do
@classmethod
def make_mask_and_index_map(cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int) \
-> tuple[torch.Tensor, torch.Tensor]:
# mask=1 means use original prompt attention, mask=0 means use modified prompt attention
mask = torch.zeros(max_length)
indices_target = torch.arange(max_length, dtype=torch.long)
indices = torch.arange(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in edit_opcodes:
if b0 < max_length:
if name == "equal":
# these tokens remain the same as in the original prompt
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
return mask, indices
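A minimal usage sketch (not part of this commit) showing how edit_opcodes might be produced, assuming they come from difflib.SequenceMatcher over the tokenized original and edited prompts; the token ids and variable names below are illustrative only:

from difflib import SequenceMatcher

original_token_ids = [49406, 320, 2368, 49407]  # e.g. "<bos> a cat <eos>" (made-up ids)
edited_token_ids   = [49406, 320, 1929, 49407]  # e.g. "<bos> a dog <eos>" (made-up ids)

edit_opcodes = SequenceMatcher(None, original_token_ids, edited_token_ids).get_opcodes()
# -> [('equal', 0, 2, 0, 2), ('replace', 2, 3, 2, 3), ('equal', 3, 4, 3, 4)]

mask, index_map = SwapCrossAttnContext.make_mask_and_index_map(edit_opcodes, max_length=77)
# mask is 1 at positions whose tokens are unchanged (use remapped original-prompt attention),
# 0 elsewhere (use modified-prompt attention)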
class SwapCrossAttnProcessor(CrossAttnProcessor):
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
# kwargs
swap_cross_attn_context: SwapCrossAttnContext = None):
attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
# if cross-attention control is not in play, just call through to the base implementation.
if swap_cross_attn_context is None or not swap_cross_attn_context.wants_cross_attention_control(attention_type):
#print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
#else:
# print(f"SwapCrossAttnContext for {attention_type} active")
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)
# helper function
def get_attention_probs(embeddings):
this_key = attn.to_k(embeddings)
this_key = attn.head_to_batch_dim(this_key)
return attn.get_attention_scores(query, this_key, attention_mask)
if attention_type == CrossAttentionType.SELF:
# self attention has no remapping, it just bluntly copies the whole tensor
attention_probs = get_attention_probs(hidden_states)
value = attn.to_v(hidden_states)
else:
# tokens (cross) attention
# first, find attention probabilities for the "original" prompt
original_text_embeddings = encoder_hidden_states
original_attention_probs = get_attention_probs(original_text_embeddings)
# then, find attention probabilities for the "modified" prompt
modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
modified_attention_probs = get_attention_probs(modified_text_embeddings)
# because the prompt modifications may result in token sequences shifted forwards or backwards,
# the original attention probabilities must be remapped to account for token index changes in the
# modified prompt
remapped_original_attention_probs = torch.index_select(original_attention_probs, -1,
swap_cross_attn_context.index_map)
# only some tokens taken from the original attention probabilities. this is controlled by the mask.
mask = swap_cross_attn_context.mask
inverse_mask = 1 - mask
attention_probs = \
remapped_original_attention_probs * mask + \
modified_attention_probs * inverse_mask
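# e.g. for a 77-token prompt where only position 5 was edited, mask is 1 everywhere except
# position 5, so unedited positions keep the (remapped) original prompt's attention and only
# the edited position takes its attention from the modified prompt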
# for the "value" just use the modified text embeddings.
value = attn.to_v(modified_text_embeddings)
value = attn.head_to_batch_dim(value)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
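A minimal wiring sketch (not part of this diff) for how the processor might be attached to a diffusers UNet and how swap_cross_attn_context reaches it, assuming a diffusers version whose UNet2DConditionModel.forward accepts cross_attention_kwargs; unet, latents, timestep, and the conditioning tensors are placeholders:

# install the processor on every attention layer of the UNet
unet.set_attn_processor(SwapCrossAttnProcessor())

swap_context = SwapCrossAttnContext(
    modified_text_embeddings=edited_conditioning,    # embeddings of the edited prompt
    index_map=index_map,
    mask=mask,
    cross_attention_types_to_do=[CrossAttentionType.TOKENS],
)

# the kwargs dict is forwarded to each processor call as extra keyword arguments
noise_pred = unet(
    latents,
    timestep,
    encoder_hidden_states=original_conditioning,     # embeddings of the original prompt
    cross_attention_kwargs={"swap_cross_attn_context": swap_context},
).sample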


@@ -1,14 +1,14 @@
import math
from dataclasses import dataclass
from math import ceil
from typing import Callable, Optional, Union
from typing import Callable, Optional, Union, Any
import numpy as np
import torch
from ldm.models.diffusion.cross_attention_control import Arguments, \
remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, \
CrossAttentionType
CrossAttentionType, SwapCrossAttnContext
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
@@ -30,24 +30,28 @@ class InvokeAIDiffuserComponent:
debug_thresholding = False
@dataclass
class ExtraConditioningInfo:
def __init__(self, tokens_count_including_eos_bos:int, cross_attention_control_args: Optional[Arguments]):
self.tokens_count_including_eos_bos = tokens_count_including_eos_bos
self.cross_attention_control_args = cross_attention_control_args
tokens_count_including_eos_bos: int
cross_attention_control_args: Optional[Arguments] = None
@property
def wants_cross_attention_control(self):
return self.cross_attention_control_args is not None
def __init__(self, model, model_forward_callback:
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
):
Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str,Any]]], torch.Tensor],
is_running_diffusers: bool = False,
):
"""
:param model: the unet model to pass through to cross attention control
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). Will be called repeatedly; most likely it should simply call model.forward(x, sigma, conditioning). For the diffusers path it may also be passed an extra dict of keyword arguments (see the Callable type hint above).
"""
self.conditioning = None
self.model = model
self.is_running_diffusers = is_running_diffusers
self.model_forward_callback = model_forward_callback
self.cross_attention_control_context = None
@@ -57,12 +61,14 @@ class InvokeAIDiffuserComponent:
arguments=self.conditioning.cross_attention_control_args,
step_count=step_count
)
setup_cross_attention_control(self.model, self.cross_attention_control_context)
setup_cross_attention_control(self.model,
self.cross_attention_control_context,
is_running_diffusers=self.is_running_diffusers)
def remove_cross_attention_control(self):
self.conditioning = None
self.cross_attention_control_context = None
remove_cross_attention_control(self.model)
remove_cross_attention_control(self.model, is_running_diffusers=self.is_running_diffusers)
def setup_attention_map_saving(self, saver: AttentionMapSaver):
def callback(slice, dim, offset, slice_size, key):
@@ -168,7 +174,41 @@ class InvokeAIDiffuserComponent:
return unconditioned_next_x, conditioned_next_x
def apply_cross_attention_controlled_conditioning(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
def apply_cross_attention_controlled_conditioning(self,
x: torch.Tensor,
sigma,
unconditioning,
conditioning,
cross_attention_control_types_to_do):
if self.is_running_diffusers:
return self.apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
else:
return self.apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
def apply_cross_attention_controlled_conditioning__diffusers(self,
x: torch.Tensor,
sigma,
unconditioning,
conditioning,
cross_attention_control_types_to_do):
context: Context = self.cross_attention_control_context
cross_attn_processor_context = SwapCrossAttnContext(modified_text_embeddings=context.arguments.edited_conditioning,
index_map=context.cross_attention_index_map,
mask=context.cross_attention_mask,
cross_attention_types_to_do=[])
# no cross attention for unconditioning (negative prompt)
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning,
{"swap_cross_attn_context": cross_attn_processor_context})
# do requested cross attention types for conditioning (positive prompt)
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning,
{"swap_cross_attn_context": cross_attn_processor_context})
return unconditioned_next_x, conditioned_next_x
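A sketch of a model_forward_callback compatible with the diffusers path above, assuming the extra dict should reach the attention processors via diffusers' cross_attention_kwargs; the unet reference and argument names here are illustrative, not part of this commit:

# hypothetical callback matching Callable[[Tensor, Tensor, Tensor, Optional[dict[str, Any]]], Tensor]
def model_forward_callback(x, sigma, conditioning, extra_kwargs=None):
    return unet(
        x,
        sigma,
        encoder_hidden_states=conditioning,
        cross_attention_kwargs=extra_kwargs,
    ).sample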
def apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
# slower non-batched path (20% slower on mac MPS)
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of