Mirror of https://github.com/invoke-ai/InvokeAI
SwapCrossAttnProcessor working - tested on mac CPU (MPS doesn't work)
parent 0c2a511671
commit bffe199ad7
@@ -24,9 +24,6 @@ from ...models.diffusion import cross_attention_control
 from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
 from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter

-# monkeypatch diffusers CrossAttention 🙈
-# this is to make prompt2prompt and (future) attention maps work
-attention.CrossAttention = cross_attention_control.InvokeAIDiffusersCrossAttention

 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@@ -295,7 +292,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
-        self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
+        self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward, is_running_diffusers=True)
         use_full_precision = (precision == 'float32' or precision == 'autocast')
         self.textual_inversion_manager = TextualInversionManager(tokenizer=self.tokenizer,
                                                                  text_encoder=self.text_encoder,
@@ -389,6 +386,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

         attention_map_saver: Optional[AttentionMapSaver] = None
         self.invokeai_diffuser.remove_attention_map_saving()

         for i, t in enumerate(self.progress_bar(timesteps)):
             batched_t.fill_(t)
             step_output = self.step(batched_t, latents, conditioning_data,
@@ -447,7 +445,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

         return step_output

-    def _unet_forward(self, latents, t, text_embeddings):
+    def _unet_forward(self, latents, t, text_embeddings, cross_attention_kwargs: Optional[dict[str,Any]] = None):
         """predict the noise residual"""
         if is_inpainting_model(self.unet) and latents.size(1) == 4:
             # Pad out normal non-inpainting inputs for an inpainting model.
@@ -460,7 +458,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype)
             ).add_mask_channels(latents)

-        return self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
+        return self.unet(sample=latents,
+                         timestep=t,
+                         encoder_hidden_states=text_embeddings,
+                         cross_attention_kwargs=cross_attention_kwargs).sample

     def img2img_from_embeddings(self,
                                 init_image: Union[torch.FloatTensor, PIL.Image.Image],
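The `cross_attention_kwargs` parameter added to `_unet_forward` above is the hook this commit uses to reach its attention processor: diffusers forwards whatever dict is passed there to every attention block's processor as extra keyword arguments, and `unet.set_attn_processor(...)` decides which processor that is. A minimal sketch of that mechanism follows. It is not part of this commit; the `PrintingCrossAttnProcessor` class, the `note` kwarg and the checkpoint id are illustrative placeholders, and the import paths assume the diffusers version this commit targets (the one that provides `diffusers.models.cross_attention.CrossAttnProcessor`).

import torch
from diffusers.models import UNet2DConditionModel
from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor

class PrintingCrossAttnProcessor(CrossAttnProcessor):
    # hypothetical processor, only to show that per-call kwargs arrive here
    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, note: str = None):
        if note is not None:
            print("processor received:", note)   # prints once per attention layer
        return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)

# any Stable Diffusion 1.x UNet works here; this checkpoint is just an example
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
unet.set_attn_processor(PrintingCrossAttnProcessor())

latents = torch.randn(1, 4, 64, 64)
text_embeddings = torch.randn(1, 77, 768)
noise_pred = unet(sample=latents,
                  timestep=torch.tensor(10),
                  encoder_hidden_states=text_embeddings,
                  cross_attention_kwargs={"note": "a swap context would travel this route"}).sample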
@@ -1,160 +0,0 @@
-"""
-# base implementation
-
-class CrossAttnProcessor:
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
-
-        query = attn.to_q(hidden_states)
-        query = attn.head_to_batch_dim(query)
-
-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
-
-        attention_probs = attn.get_attention_scores(query, key, attention_mask)
-        hidden_states = torch.bmm(attention_probs, value)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        return hidden_states
-
-"""
-import enum
-from dataclasses import field, dataclass
-
-import torch
-
-from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor
-
-class AttentionType(enum.Enum):
-    SELF = 1
-    TOKENS = 2
-
-@dataclass
-class SwapCrossAttnContext:
-
-    cross_attention_types_to_do: list[AttentionType]
-    modified_text_embeddings: torch.Tensor
-    index_map: torch.Tensor  # maps from original prompt token indices to the equivalent tokens in the modified prompt
-    mask: torch.Tensor  # in the target space of the index_map
-
-    def __int__(self,
-                cac_types_to_do: [AttentionType],
-                modified_text_embeddings: torch.Tensor,
-                index_map: torch.Tensor,
-                mask: torch.Tensor):
-        self.cross_attention_types_to_do = cac_types_to_do
-        self.modified_text_embeddings = modified_text_embeddings
-        self.index_map = index_map
-        self.mask = mask
-
-    def wants_cross_attention_control(self, attn_type: AttentionType) -> bool:
-        return attn_type in self.cross_attention_types_to_do
-
-
-class SwapCrossAttnProcessor(CrossAttnProcessor):
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
-                 # kwargs
-                 cross_attention_swap_context_provider: SwapCrossAttnContext=None):
-
-        if cross_attention_swap_context_provider is None:
-            raise RuntimeError("a SwapCrossAttnContext instance must be passed via attention processor kwargs")
-
-        attention_type = AttentionType.SELF if encoder_hidden_states is None else AttentionType.TOKENS
-        # if cross-attention control is not in play, just call through to the base implementation.
-        if not cross_attention_swap_context_provider.wants_cross_attention_control(attention_type):
-            return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
-
-        batch_size, sequence_length, _ = hidden_states.shape
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
-
-        query = attn.to_q(hidden_states)
-        query = attn.head_to_batch_dim(query)
-
-        # helper function
-        def get_attention_probs(embeddings):
-            this_key = attn.to_k(embeddings)
-            this_key = attn.head_to_batch_dim(this_key)
-            return attn.get_attention_scores(query, this_key, attention_mask)
-
-        if attention_type == AttentionType.SELF:
-            # self attention has no remapping, it just bluntly copies the whole tensor
-            attention_probs = get_attention_probs(hidden_states)
-            value = attn.to_v(hidden_states)
-        else:
-            # tokens (cross) attention
-            # first, find attention probabilities for the "original" prompt
-            original_text_embeddings = encoder_hidden_states
-            original_attention_probs = get_attention_probs(original_text_embeddings)
-
-            # then, find attention probabilities for the "modified" prompt
-            modified_text_embeddings = cross_attention_swap_context_provider.modified_text_embeddings
-            modified_attention_probs = get_attention_probs(modified_text_embeddings)
-
-            # because the prompt modifications may result in token sequences shifted forwards or backwards,
-            # the original attention probabilities must be remapped to account for token index changes in the
-            # modified prompt
-            remapped_original_attention_probs = torch.index_select(original_attention_probs, -1,
-                                                                   cross_attention_swap_context_provider.index_map)
-
-            # only some tokens taken from the original attention probabilities. this is controlled by the mask.
-            mask = cross_attention_swap_context_provider.mask
-            inverse_mask = 1 - mask
-            attention_probs = \
-                remapped_original_attention_probs * mask + \
-                modified_attention_probs * inverse_mask
-
-            # for the "value" just use the modified text embeddings.
-            value = attn.to_v(modified_text_embeddings)
-
-        value = attn.head_to_batch_dim(value)
-
-        hidden_states = torch.bmm(attention_probs, value)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        return hidden_states
-
-
-class P2PCrossAttentionProc:
-
-    def __init__(self, head_size, upcast_attention, attn_maps_reweight):
-        super().__init__(head_size=head_size, upcast_attention=upcast_attention)
-        self.attn_maps_reweight = attn_maps_reweight
-
-    def __call__(self, hidden_states, query_proj, key_proj, value_proj, encoder_hidden_states, modified_text_embeddings):
-        batch_size, sequence_length, _ = hidden_states.shape
-        query = query_proj(hidden_states)
-
-        context = context if context is not None else hidden_states
-        attention_probs = []
-        original_text_embeddings = encoder_hidden_states
-        for context in [original_text_embeddings, modified_text_embeddings]:
-            key = key_proj(original_text_embeddings)
-            value = self.value_proj(original_text_embeddings)
-
-            query = self.head_to_batch_dim(query, self.head_size)
-            key = self.head_to_batch_dim(key, self.head_size)
-            value = self.head_to_batch_dim(value, self.head_size)
-
-            attention_probs.append(self.get_attention_scores(query, key))
-
-        merged_probs = self.attn_maps_reweight * torch.cat(attention_probs)
-        hidden_states = torch.bmm(attention_probs, value)
-        hidden_states = self.batch_to_head_dim(hidden_states)
-        return hidden_states
-
-proc = P2PCrossAttentionProc(unet.config.head_size, unet.config.upcast_attention, 0.6)
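The remap-and-blend step that the deleted processor above performs on its attention probabilities can be seen in isolation with plain tensors: `torch.index_select` re-reads the original prompt's per-token attention at the positions given by `index_map`, and the mask then chooses, per token of the modified prompt, whether to keep the original attention or the freshly computed one. A small, self-contained illustration (not from the repository; shapes, index map and mask are made up for an imagined one-token insertion):

import torch

tokens = 6                                      # illustrative key-token count
original_probs = torch.rand(1, 4, tokens)       # (batch*heads, query_tokens, key_tokens), made up
modified_probs = torch.rand(1, 4, tokens)

# one token was inserted at position 2 of the modified prompt, so modified positions 3..5
# map back to original positions 2..4
index_map = torch.tensor([0, 1, 2, 2, 3, 4])
mask = torch.tensor([1., 1., 0., 1., 1., 1.])   # 0 where the token is new or edited

remapped = torch.index_select(original_probs, -1, index_map)
blended = remapped * mask + modified_probs * (1 - mask)
print(blended.shape)  # torch.Size([1, 4, 6])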
@@ -9,6 +9,7 @@ from torch import nn
 from diffusers.models.unet_2d_condition import UNet2DConditionModel
 from ldm.invoke.devices import torch_dtype


 # adapted from bloc97's CrossAttentionControl colab
 # https://github.com/bloc97/CrossAttentionControl

@@ -304,11 +305,16 @@ class InvokeAICrossAttentionMixin:


-def remove_cross_attention_control(model):
-    remove_attention_function(model)
+def remove_cross_attention_control(model, is_running_diffusers: bool):
+    if is_running_diffusers:
+        unet = model
+        print("** need to know what cross attn processor to use by default, None in the following line is wrong")
+        unet.set_attn_processor(CrossAttnProcessor())
+    else:
+        remove_attention_function(model)


-def setup_cross_attention_control(model, context: Context):
+def setup_cross_attention_control(model, context: Context, is_running_diffusers = False):
     """
     Inject attention parameters and functions into the passed in model to enable cross attention editing.

@@ -333,10 +339,16 @@ def setup_cross_attention_control(model, context: Context):
                 indices[b0:b1] = indices_target[a0:a1]
                 mask[b0:b1] = 1

-    #context.register_cross_attention_modules(model)
     context.cross_attention_mask = mask.to(device)
     context.cross_attention_index_map = indices.to(device)
-    #inject_attention_function(model, context)
+    if is_running_diffusers:
+        unet = model
+        unet.set_attn_processor(SwapCrossAttnProcessor())
+    else:
+        context.register_cross_attention_modules(model)
+        inject_attention_function(model, context)



 def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
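The `(name, a0, a1, b0, b1)` tuples consumed by the loop above have the shape of `difflib.SequenceMatcher` opcodes computed over the original and edited token sequences. A sketch, assuming that is how the edit opcodes are produced upstream (the token ids below are illustrative):

from difflib import SequenceMatcher

original_tokens = [49406, 320, 2368, 525, 320, 4919, 49407]   # "<bos> a cat on a bench <eos>" (illustrative ids)
edited_tokens   = [49406, 320, 1929, 525, 320, 4919, 49407]   # "<bos> a dog on a bench <eos>"

opcodes = SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
print(opcodes)
# e.g. [('equal', 0, 2, 0, 2), ('replace', 2, 3, 2, 3), ('equal', 3, 7, 3, 7)]
# only the 'equal' spans keep mask=1 and have their indices copied across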
@@ -461,3 +473,155 @@ class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention,
         hidden_states = self.reshape_batch_dim_to_heads(attention_result)
         return hidden_states

+
+
+## 🧨diffusers implementation follows
+
+
+"""
+# base implementation
+
+class CrossAttnProcessor:
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query)
+
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+"""
+import enum
+from dataclasses import field, dataclass
+
+import torch
+
+from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor
+from ldm.models.diffusion.cross_attention_control import CrossAttentionType
+
+
+@dataclass
+class SwapCrossAttnContext:
+    modified_text_embeddings: torch.Tensor
+    index_map: torch.Tensor  # maps from original prompt token indices to the equivalent tokens in the modified prompt
+    mask: torch.Tensor  # in the target space of the index_map
+    cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=[])
+
+    def __int__(self,
+                cac_types_to_do: [CrossAttentionType],
+                modified_text_embeddings: torch.Tensor,
+                index_map: torch.Tensor,
+                mask: torch.Tensor):
+        self.cross_attention_types_to_do = cac_types_to_do
+        self.modified_text_embeddings = modified_text_embeddings
+        self.index_map = index_map
+        self.mask = mask
+
+    def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool:
+        return attn_type in self.cross_attention_types_to_do
+
+    @classmethod
+    def make_mask_and_index_map(cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int) \
+            -> tuple[torch.Tensor, torch.Tensor]:
+
+        # mask=1 means use original prompt attention, mask=0 means use modified prompt attention
+        mask = torch.zeros(max_length)
+        indices_target = torch.arange(max_length, dtype=torch.long)
+        indices = torch.arange(max_length, dtype=torch.long)
+        for name, a0, a1, b0, b1 in edit_opcodes:
+            if b0 < max_length:
+                if name == "equal":
+                    # these tokens remain the same as in the original prompt
+                    indices[b0:b1] = indices_target[a0:a1]
+                    mask[b0:b1] = 1
+
+        return mask, indices
+
+
+class SwapCrossAttnProcessor(CrossAttnProcessor):
+
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
+                 # kwargs
+                 swap_cross_attn_context: SwapCrossAttnContext=None):
+
+        attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
+
+        # if cross-attention control is not in play, just call through to the base implementation.
+        if swap_cross_attn_context is None or not swap_cross_attn_context.wants_cross_attention_control(attention_type):
+            #print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
+            return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
+        #else:
+        #    print(f"SwapCrossAttnContext for {attention_type} active")
+
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query)
+
+        # helper function
+        def get_attention_probs(embeddings):
+            this_key = attn.to_k(embeddings)
+            this_key = attn.head_to_batch_dim(this_key)
+            return attn.get_attention_scores(query, this_key, attention_mask)
+
+        if attention_type == CrossAttentionType.SELF:
+            # self attention has no remapping, it just bluntly copies the whole tensor
+            attention_probs = get_attention_probs(hidden_states)
+            value = attn.to_v(hidden_states)
+        else:
+            # tokens (cross) attention
+            # first, find attention probabilities for the "original" prompt
+            original_text_embeddings = encoder_hidden_states
+            original_attention_probs = get_attention_probs(original_text_embeddings)
+
+            # then, find attention probabilities for the "modified" prompt
+            modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
+            modified_attention_probs = get_attention_probs(modified_text_embeddings)
+
+            # because the prompt modifications may result in token sequences shifted forwards or backwards,
+            # the original attention probabilities must be remapped to account for token index changes in the
+            # modified prompt
+            remapped_original_attention_probs = torch.index_select(original_attention_probs, -1,
+                                                                   swap_cross_attn_context.index_map)
+
+            # only some tokens taken from the original attention probabilities. this is controlled by the mask.
+            mask = swap_cross_attn_context.mask
+            inverse_mask = 1 - mask
+            attention_probs = \
+                remapped_original_attention_probs * mask + \
+                modified_attention_probs * inverse_mask
+
+            # for the "value" just use the modified text embeddings.
+            value = attn.to_v(modified_text_embeddings)
+
+        value = attn.head_to_batch_dim(value)
+
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
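To see what `make_mask_and_index_map` produces, here is a standalone sketch that mirrors its logic (it does not import the class above) for an edit that inserts one token and shifts the rest of the prompt:

import torch

def make_mask_and_index_map(edit_opcodes, max_length):
    # mask=1 -> keep the original prompt's attention; mask=0 -> use the modified prompt's attention
    mask = torch.zeros(max_length)
    indices_target = torch.arange(max_length, dtype=torch.long)
    indices = torch.arange(max_length, dtype=torch.long)
    for name, a0, a1, b0, b1 in edit_opcodes:
        if b0 < max_length and name == "equal":
            indices[b0:b1] = indices_target[a0:a1]
            mask[b0:b1] = 1
    return mask, indices

# "a cat on a bench" -> "a fluffy cat on a bench": one token inserted at position 1 (illustrative)
opcodes = [("equal", 0, 1, 0, 1), ("insert", 1, 1, 1, 2), ("equal", 1, 5, 2, 6)]
mask, index_map = make_mask_and_index_map(opcodes, max_length=8)
print(mask)       # tensor([1., 0., 1., 1., 1., 1., 0., 0.])
print(index_map)  # tensor([0, 1, 1, 2, 3, 4, 6, 7])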
@@ -1,14 +1,14 @@
 import math
 from dataclasses import dataclass
 from math import ceil
-from typing import Callable, Optional, Union
+from typing import Callable, Optional, Union, Any

 import numpy as np
 import torch

 from ldm.models.diffusion.cross_attention_control import Arguments, \
     remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, \
-    CrossAttentionType
+    CrossAttentionType, SwapCrossAttnContext
 from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver

@@ -30,24 +30,28 @@ class InvokeAIDiffuserComponent:
     debug_thresholding = False


+    @dataclass
     class ExtraConditioningInfo:
-        def __init__(self, tokens_count_including_eos_bos:int, cross_attention_control_args: Optional[Arguments]):
-            self.tokens_count_including_eos_bos = tokens_count_including_eos_bos
-            self.cross_attention_control_args = cross_attention_control_args
+        tokens_count_including_eos_bos: int
+        cross_attention_control_args: Optional[Arguments] = None

         @property
         def wants_cross_attention_control(self):
             return self.cross_attention_control_args is not None


     def __init__(self, model, model_forward_callback:
-                 Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
-                 ):
+                 Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str,Any]]], torch.Tensor],
+                 is_running_diffusers: bool=False,
+                 ):
         """
         :param model: the unet model to pass through to cross attention control
         :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
         """
         self.conditioning = None
         self.model = model
+        self.is_running_diffusers = is_running_diffusers
         self.model_forward_callback = model_forward_callback
         self.cross_attention_control_context = None

@@ -57,12 +61,14 @@ class InvokeAIDiffuserComponent:
             arguments=self.conditioning.cross_attention_control_args,
             step_count=step_count
         )
-        setup_cross_attention_control(self.model, self.cross_attention_control_context)
+        setup_cross_attention_control(self.model,
+                                      self.cross_attention_control_context,
+                                      is_running_diffusers=self.is_running_diffusers)

     def remove_cross_attention_control(self):
         self.conditioning = None
         self.cross_attention_control_context = None
-        remove_cross_attention_control(self.model)
+        remove_cross_attention_control(self.model, is_running_diffusers=self.is_running_diffusers)

     def setup_attention_map_saving(self, saver: AttentionMapSaver):
         def callback(slice, dim, offset, slice_size, key):
@@ -168,7 +174,41 @@ class InvokeAIDiffuserComponent:
         return unconditioned_next_x, conditioned_next_x


-    def apply_cross_attention_controlled_conditioning(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
+    def apply_cross_attention_controlled_conditioning(self,
+                                                      x: torch.Tensor,
+                                                      sigma,
+                                                      unconditioning,
+                                                      conditioning,
+                                                      cross_attention_control_types_to_do):
+        if self.is_running_diffusers:
+            return self.apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
+        else:
+            return self.apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
+
+    def apply_cross_attention_controlled_conditioning__diffusers(self,
+                                                                 x: torch.Tensor,
+                                                                 sigma,
+                                                                 unconditioning,
+                                                                 conditioning,
+                                                                 cross_attention_control_types_to_do):
+        context: Context = self.cross_attention_control_context
+
+        cross_attn_processor_context = SwapCrossAttnContext(modified_text_embeddings=context.arguments.edited_conditioning,
+                                                            index_map=context.cross_attention_index_map,
+                                                            mask=context.cross_attention_mask,
+                                                            cross_attention_types_to_do=[])
+        # no cross attention for unconditioning (negative prompt)
+        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning,
+                                                           {"swap_cross_attn_context": cross_attn_processor_context})
+
+        # do requested cross attention types for conditioning (positive prompt)
+        cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
+        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning,
+                                                         {"swap_cross_attn_context": cross_attn_processor_context})
+        return unconditioned_next_x, conditioned_next_x
+
+
+    def apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
         # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
         # slower non-batched path (20% slower on mac MPS)
         # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
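The diffusers path above reuses a single swap-context object for both passes, sending it through the fourth callback argument and only changing `cross_attention_types_to_do` between the unconditioned and the conditioned call. A self-contained sketch of that calling convention (stub types and a stub callback, not the repository classes):

import torch
from dataclasses import dataclass, field

@dataclass
class FakeSwapContext:                      # stands in for SwapCrossAttnContext
    cross_attention_types_to_do: list = field(default_factory=list)

def fake_forward_callback(x, sigma, conditioning, cross_attention_kwargs=None):
    # a real callback would run the UNet; here we just show what the processor would see
    ctx = cross_attention_kwargs["swap_cross_attn_context"]
    print("types to do during this pass:", ctx.cross_attention_types_to_do)
    return x

ctx = FakeSwapContext()
x, sigma = torch.zeros(1, 4, 8, 8), torch.tensor(1.0)

fake_forward_callback(x, sigma, None, {"swap_cross_attn_context": ctx})   # unconditioned pass: []
ctx.cross_attention_types_to_do = ["TOKENS"]                              # e.g. CrossAttentionType.TOKENS
fake_forward_callback(x, sigma, None, {"swap_cross_attn_context": ctx})   # conditioned pass: ['TOKENS']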