SwapCrossAttnProcessor working - tested on mac CPU (MPS doesn't work)

Author: Damian Stewart
Date: 2023-01-21 20:54:18 +01:00
Parent: 0c2a511671
Commit: bffe199ad7
4 changed files with 226 additions and 181 deletions


@@ -24,9 +24,6 @@ from ...models.diffusion import cross_attention_control
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
# monkeypatch diffusers CrossAttention 🙈
# this is to make prompt2prompt and (future) attention maps work
attention.CrossAttention = cross_attention_control.InvokeAIDiffusersCrossAttention
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@@ -295,7 +292,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward, is_running_diffusers=True)
use_full_precision = (precision == 'float32' or precision == 'autocast')
self.textual_inversion_manager = TextualInversionManager(tokenizer=self.tokenizer,
text_encoder=self.text_encoder,
@@ -389,6 +386,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
attention_map_saver: Optional[AttentionMapSaver] = None
self.invokeai_diffuser.remove_attention_map_saving()
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t.fill_(t)
step_output = self.step(batched_t, latents, conditioning_data,
@@ -447,7 +445,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
return step_output
def _unet_forward(self, latents, t, text_embeddings):
def _unet_forward(self, latents, t, text_embeddings, cross_attention_kwargs: Optional[dict[str,Any]] = None):
"""predict the noise residual"""
if is_inpainting_model(self.unet) and latents.size(1) == 4:
# Pad out normal non-inpainting inputs for an inpainting model.
@@ -460,7 +458,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype)
).add_mask_channels(latents)
return self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
return self.unet(sample=latents,
timestep=t,
encoder_hidden_states=text_embeddings,
cross_attention_kwargs=cross_attention_kwargs).sample
def img2img_from_embeddings(self,
init_image: Union[torch.FloatTensor, PIL.Image.Image],
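The `_unet_forward` change above is the plumbing for cross-attention control: diffusers passes whatever is in `cross_attention_kwargs` straight through to each attention processor's `__call__`. Below is a minimal, hedged sketch of how a caller might thread a swap context (defined in the file that follows) through this path; `pipeline`, `swap_context`, `latents`, `t` and `conditioning` are assumed names, not part of this commit.

# hedged sketch only: assumes `swap_context` is a SwapCrossAttnContext (see the next file)
# and that the UNet's attention processors have been replaced with SwapCrossAttnProcessor.
noise_pred = pipeline._unet_forward(
    latents, t, conditioning,
    cross_attention_kwargs={"cross_attention_swap_context_provider": swap_context},
)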


@@ -1,160 +0,0 @@
"""
# base implementation
class CrossAttnProcessor:
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
"""
import enum
from dataclasses import dataclass

import torch
from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor


class AttentionType(enum.Enum):
    SELF = 1
    TOKENS = 2


@dataclass
class SwapCrossAttnContext:
    cross_attention_types_to_do: list[AttentionType]
    modified_text_embeddings: torch.Tensor
    index_map: torch.Tensor  # maps from original prompt token indices to the equivalent tokens in the modified prompt
    mask: torch.Tensor  # in the target space of the index_map

    def __init__(self,
                 cac_types_to_do: list[AttentionType],
                 modified_text_embeddings: torch.Tensor,
                 index_map: torch.Tensor,
                 mask: torch.Tensor):
        self.cross_attention_types_to_do = cac_types_to_do
        self.modified_text_embeddings = modified_text_embeddings
        self.index_map = index_map
        self.mask = mask

    def wants_cross_attention_control(self, attn_type: AttentionType) -> bool:
        return attn_type in self.cross_attention_types_to_do
class SwapCrossAttnProcessor(CrossAttnProcessor):

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
                 # kwargs
                 cross_attention_swap_context_provider: SwapCrossAttnContext = None):

        if cross_attention_swap_context_provider is None:
            raise RuntimeError("a SwapCrossAttnContext instance must be passed via attention processor kwargs")

        attention_type = AttentionType.SELF if encoder_hidden_states is None else AttentionType.TOKENS
        # if cross-attention control is not in play, just call through to the base implementation.
        if not cross_attention_swap_context_provider.wants_cross_attention_control(attention_type):
            return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)

        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        # helper function
        def get_attention_probs(embeddings):
            this_key = attn.to_k(embeddings)
            this_key = attn.head_to_batch_dim(this_key)
            return attn.get_attention_scores(query, this_key, attention_mask)

        if attention_type == AttentionType.SELF:
            # self attention has no remapping, it just bluntly copies the whole tensor
            attention_probs = get_attention_probs(hidden_states)
            value = attn.to_v(hidden_states)
        else:
            # tokens (cross) attention
            # first, find attention probabilities for the "original" prompt
            original_text_embeddings = encoder_hidden_states
            original_attention_probs = get_attention_probs(original_text_embeddings)

            # then, find attention probabilities for the "modified" prompt
            modified_text_embeddings = cross_attention_swap_context_provider.modified_text_embeddings
            modified_attention_probs = get_attention_probs(modified_text_embeddings)

            # because the prompt modifications may result in token sequences shifted forwards or backwards,
            # the original attention probabilities must be remapped to account for token index changes in the
            # modified prompt
            remapped_original_attention_probs = torch.index_select(original_attention_probs, -1,
                                                                   cross_attention_swap_context_provider.index_map)

            # only some tokens are taken from the original attention probabilities. this is controlled by the mask.
            mask = cross_attention_swap_context_provider.mask
            inverse_mask = 1 - mask
            attention_probs = \
                remapped_original_attention_probs * mask + \
                modified_attention_probs * inverse_mask

            # for the "value" just use the modified text embeddings.
            value = attn.to_v(modified_text_embeddings)

        value = attn.head_to_batch_dim(value)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
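The remap-and-blend step is the core of the TOKENS branch above. Here is a tiny, self-contained illustration, with invented shapes and values that are not part of this commit, of what `index_map` and `mask` do to a batch of attention probabilities:

# toy example; tensors and numbers are made up purely for illustration.
import torch

original_probs = torch.tensor([[[0.1, 0.2, 0.3, 0.4]]])  # (batch*heads, queries, original prompt tokens)
modified_probs = torch.tensor([[[0.4, 0.3, 0.2, 0.1]]])  # (batch*heads, queries, modified prompt tokens)

# token 2 of the modified prompt lines up with token 3 of the original prompt, and vice versa
index_map = torch.tensor([0, 1, 3, 2])
# keep the original prompt's attention for the first two tokens, take the modified prompt's for the rest
mask = torch.tensor([1.0, 1.0, 0.0, 0.0])

remapped_original = torch.index_select(original_probs, -1, index_map)
blended = remapped_original * mask + modified_probs * (1 - mask)
print(blended)  # tensor([[[0.1000, 0.2000, 0.2000, 0.1000]]])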
# experimental prompt-to-prompt attention-map re-weighting processor (work in progress)
class P2PCrossAttentionProc:

    def __init__(self, head_size, upcast_attention, attn_maps_reweight):
        self.head_size = head_size
        self.upcast_attention = upcast_attention
        self.attn_maps_reweight = attn_maps_reweight

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states, modified_text_embeddings):
        batch_size, sequence_length, _ = hidden_states.shape
        query = attn.head_to_batch_dim(attn.to_q(hidden_states))

        # compute attention maps for both the original and the modified prompt embeddings
        original_text_embeddings = encoder_hidden_states
        attention_probs = []
        for context in [original_text_embeddings, modified_text_embeddings]:
            key = attn.head_to_batch_dim(attn.to_k(context))
            attention_probs.append(attn.get_attention_scores(query, key))

        # blend the two attention maps; attn_maps_reweight controls how much of the modified map is used
        merged_probs = (1 - self.attn_maps_reweight) * attention_probs[0] \
                       + self.attn_maps_reweight * attention_probs[1]

        value = attn.head_to_batch_dim(attn.to_v(modified_text_embeddings))
        hidden_states = torch.bmm(merged_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        return hidden_states


# example: proc = P2PCrossAttentionProc(unet.config.head_size, unet.config.upcast_attention, 0.6)
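For completeness, a hedged end-to-end sketch (not part of this commit) of how SwapCrossAttnProcessor is meant to be wired up: register it on the UNet, then pass a SwapCrossAttnContext through `cross_attention_kwargs` on every denoising step. `unet`, `original_embeddings`, `modified_embeddings`, `index_map`, `mask`, `latents` and `t` are assumed to exist in the caller, and `set_attn_processor` is assumed to be available in the installed diffusers version.

# assumption: the installed diffusers exposes UNet2DConditionModel.set_attn_processor
unet.set_attn_processor(SwapCrossAttnProcessor())

# positional args match the SwapCrossAttnContext constructor above
swap_context = SwapCrossAttnContext(
    [AttentionType.TOKENS],   # only swap token (cross) attention; leave self-attention untouched
    modified_embeddings,
    index_map,
    mask,
)

noise_pred = unet(
    sample=latents,
    timestep=t,
    encoder_hidden_states=original_embeddings,
    cross_attention_kwargs={"cross_attention_swap_context_provider": swap_context},
).sample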