Merge branch 'main' into build/no-actions-on-draft

psychedelicious 2023-01-31 12:00:38 +11:00 committed by GitHub
commit 6a0e1c8673
8 changed files with 354 additions and 84 deletions

View File

@@ -239,28 +239,24 @@ Generate an image with a given prompt, record the seed of the image, and then
 use the `prompt2prompt` syntax to substitute words in the original prompt for
 words in a new prompt. This works for `img2img` as well.

-- `a ("fluffy cat").swap("smiling dog") eating a hotdog`.
-  - quotes optional: `a (fluffy cat).swap(smiling dog) eating a hotdog`.
-  - for single word substitutions parentheses are also optional:
-    `a cat.swap(dog) eating a hotdog`.
-- Supports options `s_start`, `s_end`, `t_start`, `t_end` (each 0-1) loosely
-  corresponding to bloc97's `prompt_edit_spatial_start/_end` and
-  `prompt_edit_tokens_start/_end` but with the math swapped to make it easier to
-  intuitively understand.
-- Example usage: `a (cat).swap(dog, s_end=0.3) eating a hotdog` - the `s_end`
-  argument means that the "spatial" (self-attention) edit will stop having any
-  effect after 30% (=0.3) of the steps have been done, leaving Stable
-  Diffusion with 70% of the steps where it is free to decide for itself how to
-  reshape the cat-form into a dog form.
-- The numbers represent a percentage through the step sequence where the edits
-  should happen. 0 means the start (noisy starting image), 1 is the end (final
-  image).
-- For img2img, the step sequence does not start at 0 but instead at
-  (1-strength) - so if strength is 0.7, s_start and s_end must both be
-  greater than 0.3 (1-0.7) to have any effect.
-- Convenience option `shape_freedom` (0-1) to specify how much "freedom" Stable
-  Diffusion should have to change the shape of the subject being swapped.
-  - `a (cat).swap(dog, shape_freedom=0.5) eating a hotdog`.
+For example, consider the prompt `a cat.swap(dog) playing with a ball in the forest`. Normally, because of the way words interact with each other during stable diffusion image generation, these two prompts would generate different compositions:
+
+- `a cat playing with a ball in the forest`
+- `a dog playing with a ball in the forest`
+
+| `a cat playing with a ball in the forest` | `a dog playing with a ball in the forest` |
+| --- | --- |
+| img | img |
+
+- For multiple word swaps, use parentheses: `a (fluffy cat).swap(barking dog) playing with a ball in the forest`.
+- To swap a comma, use quotes: `a ("fluffy, grey cat").swap("big, barking dog") playing with a ball in the forest`.
+- Supports options `t_start` and `t_end` (each 0-1) loosely corresponding to bloc97's `prompt_edit_tokens_start/_end`, but with the math swapped to make it easier to understand intuitively. `t_start` and `t_end` control on which steps cross-attention control runs. With the default values `t_start=0` and `t_end=1`, cross-attention control is active on every step of image generation. Other values can be used to turn cross-attention control off for part of the image generation process (see the sketch below).
+  - For example, for a diffusion with 10 steps and the prompt `a cat.swap(dog, t_start=0.3, t_end=1.0) playing with a ball in the forest`, the first 3 steps will run as `a cat playing with a ball in the forest`, while the last 7 steps will run as `a dog playing with a ball in the forest`, but the pixels that represent `dog` will be locked to the pixels that would have represented `cat` if the `cat` prompt had been used instead.
+  - Conversely, for `a cat.swap(dog, t_start=0, t_end=0.7) playing with a ball in the forest`, the first 7 steps will run as `a dog playing with a ball in the forest`, with the pixels that represent `dog` locked to the same pixels that would have represented `cat` if the `cat` prompt were being used instead. The final 3 steps will just run `a cat playing with a ball in the forest`.
+
+> For img2img, the step sequence does not start at 0 but instead at `(1.0-strength)` - so if the img2img `strength` is `0.7`, `t_start` and `t_end` must both be greater than `0.3` (`1.0-0.7`) to have any effect.
+
+Prompt2prompt `.swap()` is not compatible with xformers, which will be temporarily disabled when doing a `.swap()` - so you should expect to use more VRAM and run slower than with xformers enabled.

 The `prompt2prompt` code is based off
 [bloc97's colab](https://github.com/bloc97/CrossAttentionControl).
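A minimal sketch of the `t_start`/`t_end` step arithmetic described above. This is illustrative only and is not code added by this commit; the helper name, and treating "percent through" as `step_index / total_steps` with the img2img offset applied as `(1.0 - strength)`, are assumptions.

```python
from typing import Optional

def swap_active_on_step(step_index: int, total_steps: int,
                        t_start: float, t_end: float,
                        img2img_strength: Optional[float] = None) -> bool:
    """Illustrative only: report whether `.swap()` cross-attention control runs on a given step."""
    percent_through = step_index / total_steps
    if img2img_strength is not None:
        # img2img skips the first (1.0 - strength) of the schedule, per the note above
        percent_through = (1.0 - img2img_strength) + percent_through * img2img_strength
    return t_start <= percent_through < t_end

# 10 steps with t_start=0.3, t_end=1.0: the first 3 steps render the original prompt,
# the remaining 7 run with the swapped prompt under cross-attention control.
print([swap_active_on_step(i, 10, 0.3, 1.0) for i in range(10)])
# [False, False, False, True, True, True, True, True, True, True]
```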

View File

@@ -24,9 +24,6 @@ from ...models.diffusion import cross_attention_control
 from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
 from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
-# monkeypatch diffusers CrossAttention 🙈
-# this is to make prompt2prompt and (future) attention maps work
-attention.CrossAttention = cross_attention_control.InvokeAIDiffusersCrossAttention
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@@ -295,7 +292,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
-        self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
+        self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward, is_running_diffusers=True)
         use_full_precision = (precision == 'float32' or precision == 'autocast')
         self.textual_inversion_manager = TextualInversionManager(tokenizer=self.tokenizer,
                                                                  text_encoder=self.text_encoder,
@@ -307,8 +304,23 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             textual_inversion_manager=self.textual_inversion_manager
         )
+        self._enable_memory_efficient_attention()
+
+    def _enable_memory_efficient_attention(self):
+        """
+        if xformers is available, use it, otherwise use sliced attention.
+        """
         if is_xformers_available() and not Globals.disable_xformers:
             self.enable_xformers_memory_efficient_attention()
+        else:
+            if torch.backends.mps.is_available():
+                # until pytorch #91617 is fixed, slicing is borked on MPS
+                # https://github.com/pytorch/pytorch/issues/91617
+                # fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
+                pass
+            else:
+                self.enable_attention_slicing(slice_size='auto')

     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                               conditioning_data: ConditioningData,
@@ -373,42 +385,40 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if additional_guidance is None:
             additional_guidance = []
         extra_conditioning_info = conditioning_data.extra
-        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
-                                                                 step_count=len(self.scheduler.timesteps))
-        else:
-            self.invokeai_diffuser.remove_cross_attention_control()
+        with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
+                                                             step_count=len(self.scheduler.timesteps)
+                                                             ):

             yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
                                             latents=latents)

             batch_size = latents.shape[0]
             batched_t = torch.full((batch_size,), timesteps[0],
                                    dtype=timesteps.dtype, device=self.unet.device)
             latents = self.scheduler.add_noise(latents, noise, batched_t)

             attention_map_saver: Optional[AttentionMapSaver] = None
-        self.invokeai_diffuser.remove_attention_map_saving()

             for i, t in enumerate(self.progress_bar(timesteps)):
                 batched_t.fill_(t)
                 step_output = self.step(batched_t, latents, conditioning_data,
                                         step_index=i,
                                         total_step_count=len(timesteps),
                                         additional_guidance=additional_guidance)
                 latents = step_output.prev_sample
                 predicted_original = getattr(step_output, 'pred_original_sample', None)

-            if i == len(timesteps)-1 and extra_conditioning_info is not None:
-                eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
-                attention_map_token_ids = range(1, eos_token_index)
-                attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
-                self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
+                # TODO resuscitate attention map saving
+                #if i == len(timesteps)-1 and extra_conditioning_info is not None:
+                #    eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
+                #    attention_map_token_ids = range(1, eos_token_index)
+                #    attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
+                #    self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)

                 yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
                                                 predicted_original=predicted_original, attention_map_saver=attention_map_saver)

-        self.invokeai_diffuser.remove_attention_map_saving()
-        return latents, attention_map_saver
+            return latents, attention_map_saver

     @torch.inference_mode()
     def step(self, t: torch.Tensor, latents: torch.Tensor,
@@ -447,7 +457,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         return step_output

-    def _unet_forward(self, latents, t, text_embeddings):
+    def _unet_forward(self, latents, t, text_embeddings, cross_attention_kwargs: Optional[dict[str,Any]] = None):
         """predict the noise residual"""
         if is_inpainting_model(self.unet) and latents.size(1) == 4:
             # Pad out normal non-inpainting inputs for an inpainting model.
@@ -460,7 +470,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype)
             ).add_mask_channels(latents)

-        return self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
+        return self.unet(sample=latents,
+                         timestep=t,
+                         encoder_hidden_states=text_embeddings,
+                         cross_attention_kwargs=cross_attention_kwargs).sample

     def img2img_from_embeddings(self,
                                 init_image: Union[torch.FloatTensor, PIL.Image.Image],
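The new `cross_attention_kwargs` argument relies on diffusers forwarding that dict as keyword arguments to whatever attention processor is installed on the UNet. A rough sketch of that plumbing under the diffusers version this branch targets; the processor below is a made-up example for illustration, not part of this commit.

```python
from diffusers.models.cross_attention import CrossAttnProcessor

class ExampleKwargAwareProcessor(CrossAttnProcessor):
    # hypothetical: any key passed via cross_attention_kwargs arrives here as a keyword argument
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None,
                 swap_cross_attn_context=None):
        if swap_cross_attn_context is not None:
            pass  # a processor like SlicedSwapCrossAttnProcesser would act on it here
        return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)

# usage sketch (unet, latents, t, text_embeddings, ctx assumed to exist):
# unet.set_attn_processor(ExampleKwargAwareProcessor())
# noise_pred = unet(sample=latents, timestep=t, encoder_hidden_states=text_embeddings,
#                   cross_attention_kwargs={"swap_cross_attn_context": ctx}).sample
```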

View File

@@ -155,7 +155,7 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
         default_options = {
             's_start': 0.0,
             's_end': 0.2062994740159002,  # ~= shape_freedom=0.5
-            't_start': 0.0,
+            't_start': 0.1,
             't_end': 1.0
         }
         merged_options = default_options
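The only change in this file is the default `t_start` moving from `0.0` to `0.1`, i.e. skipping cross-attention control for roughly the first 10% of steps unless the prompt says otherwise. A hedged sketch of how an explicit `.swap(..., t_start=...)` value would still win over the new default; the dict-merge shown here is illustrative, the real merging happens inside `CrossAttentionControlSubstitute`.

```python
default_options = {'s_start': 0.0, 's_end': 0.2062994740159002, 't_start': 0.1, 't_end': 1.0}
user_options = {'t_start': 0.0}                  # e.g. from `.swap(dog, t_start=0)`
merged_options = {**default_options, **user_options}
assert merged_options['t_start'] == 0.0          # explicit value overrides the new 0.1 default
```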

View File

@@ -7,8 +7,10 @@ import torch
 import diffusers
 from torch import nn
 from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from diffusers.models.cross_attention import AttnProcessor
 from ldm.invoke.devices import torch_dtype

 # adapted from bloc97's CrossAttentionControl colab
 # https://github.com/bloc97/CrossAttentionControl
@@ -304,11 +306,15 @@ class InvokeAICrossAttentionMixin:

-def remove_cross_attention_control(model):
-    remove_attention_function(model)
+def restore_default_cross_attention(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None):
+    if is_running_diffusers:
+        unet = model
+        unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
+    else:
+        remove_attention_function(model)


-def setup_cross_attention_control(model, context: Context):
+def override_cross_attention(model, context: Context, is_running_diffusers = False):
     """
     Inject attention parameters and functions into the passed in model to enable cross attention editing.
@@ -323,7 +329,7 @@ def setup_cross_attention_control(model, context: Context):
     # urgh. should this be hardcoded?
     max_length = 77
     # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
-    mask = torch.zeros(max_length)
+    mask = torch.zeros(max_length, dtype=torch_dtype(device))
     indices_target = torch.arange(max_length, dtype=torch.long)
     indices = torch.arange(max_length, dtype=torch.long)
     for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
@@ -333,10 +339,26 @@ def setup_cross_attention_control(model, context: Context):
             indices[b0:b1] = indices_target[a0:a1]
             mask[b0:b1] = 1

-    context.register_cross_attention_modules(model)
     context.cross_attention_mask = mask.to(device)
     context.cross_attention_index_map = indices.to(device)
-    inject_attention_function(model, context)
+    if is_running_diffusers:
+        unet = model
+        old_attn_processors = unet.attn_processors
+        if torch.backends.mps.is_available():
+            # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
+            unet.set_attn_processor(SwapCrossAttnProcessor())
+        else:
+            # try to re-use an existing slice size
+            default_slice_size = 4
+            slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
+            unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
+        return old_attn_processors
+    else:
+        context.register_cross_attention_modules(model)
+        inject_attention_function(model, context)
+        return None


 def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
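On the diffusers path, `.swap()` is enabled by swapping attention processors rather than monkeypatching module forwards. A condensed sketch of the capture-override-restore round trip built from the calls that appear in this diff (`unet.attn_processors`, `unet.set_attn_processor`); the wrapper name and the default slice size below are illustrative, not part of the commit.

```python
from contextlib import contextmanager
from diffusers.models import UNet2DConditionModel
from ldm.models.diffusion.cross_attention_control import SlicedSwapCrossAttnProcesser  # added by this commit

@contextmanager
def swapped_attention(unet: UNet2DConditionModel, slice_size: int = 4):
    """Sketch only: install the swap processor, then restore whatever was there before."""
    old_attn_processors = unet.attn_processors          # dict: layer name -> processor
    unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
    try:
        yield
    finally:
        unet.set_attn_processor(old_attn_processors)     # put the originals back
```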
@@ -445,6 +467,7 @@ def get_mem_free_total(device):
     return mem_free_total


 class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin):

     def __init__(self, **kwargs):
@@ -460,3 +483,176 @@ class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention,
         hidden_states = self.reshape_batch_dim_to_heads(attention_result)
         return hidden_states
+
+
+## 🧨diffusers implementation follows
+
+"""
+# base implementation
+
+class CrossAttnProcessor:
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query)
+
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+"""
+
+from dataclasses import field, dataclass
+
+import torch
+
+from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor, AttnProcessor
+
+
+@dataclass
+class SwapCrossAttnContext:
+    modified_text_embeddings: torch.Tensor
+    index_map: torch.Tensor  # maps from original prompt token indices to the equivalent tokens in the modified prompt
+    mask: torch.Tensor  # in the target space of the index_map
+    cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)
+
+    def __int__(self,
+                cac_types_to_do: [CrossAttentionType],
+                modified_text_embeddings: torch.Tensor,
+                index_map: torch.Tensor,
+                mask: torch.Tensor):
+        self.cross_attention_types_to_do = cac_types_to_do
+        self.modified_text_embeddings = modified_text_embeddings
+        self.index_map = index_map
+        self.mask = mask
+
+    def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool:
+        return attn_type in self.cross_attention_types_to_do
+
+    @classmethod
+    def make_mask_and_index_map(cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int) \
+            -> tuple[torch.Tensor, torch.Tensor]:
+
+        # mask=1 means use original prompt attention, mask=0 means use modified prompt attention
+        mask = torch.zeros(max_length)
+        indices_target = torch.arange(max_length, dtype=torch.long)
+        indices = torch.arange(max_length, dtype=torch.long)
+        for name, a0, a1, b0, b1 in edit_opcodes:
+            if b0 < max_length:
+                if name == "equal":
+                    # these tokens remain the same as in the original prompt
+                    indices[b0:b1] = indices_target[a0:a1]
+                    mask[b0:b1] = 1
+
+        return mask, indices
+
+
+class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
+
+    # TODO: dynamically pick slice size based on memory conditions
+
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
+                 # kwargs
+                 swap_cross_attn_context: SwapCrossAttnContext=None):
+
+        attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
+
+        # if cross-attention control is not in play, just call through to the base implementation.
+        if attention_type is CrossAttentionType.SELF or \
+                swap_cross_attn_context is None or \
+                not swap_cross_attn_context.wants_cross_attention_control(attention_type):
+            #print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
+            return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
+        #else:
+        #    print(f"SwapCrossAttnContext for {attention_type} active")
+
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+
+        query = attn.to_q(hidden_states)
+        dim = query.shape[-1]
+        query = attn.head_to_batch_dim(query)
+
+        original_text_embeddings = encoder_hidden_states
+        modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
+        original_text_key = attn.to_k(original_text_embeddings)
+        modified_text_key = attn.to_k(modified_text_embeddings)
+        original_value = attn.to_v(original_text_embeddings)
+        modified_value = attn.to_v(modified_text_embeddings)
+
+        original_text_key = attn.head_to_batch_dim(original_text_key)
+        modified_text_key = attn.head_to_batch_dim(modified_text_key)
+        original_value = attn.head_to_batch_dim(original_value)
+        modified_value = attn.head_to_batch_dim(modified_value)
+
+        # compute slices and prepare output tensor
+        batch_size_attention = query.shape[0]
+        hidden_states = torch.zeros(
+            (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
+        )
+
+        # do slices
+        for i in range(max(1, hidden_states.shape[0] // self.slice_size)):
+            start_idx = i * self.slice_size
+            end_idx = (i + 1) * self.slice_size
+
+            query_slice = query[start_idx:end_idx]
+            original_key_slice = original_text_key[start_idx:end_idx]
+            modified_key_slice = modified_text_key[start_idx:end_idx]
+            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
+
+            original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice)
+            modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice)
+
+            # because the prompt modifications may result in token sequences shifted forwards or backwards,
+            # the original attention probabilities must be remapped to account for token index changes in the
+            # modified prompt
+            remapped_original_attn_slice = torch.index_select(original_attn_slice, -1,
+                                                              swap_cross_attn_context.index_map)
+
+            # only some tokens taken from the original attention probabilities. this is controlled by the mask.
+            mask = swap_cross_attn_context.mask
+            inverse_mask = 1 - mask
+            attn_slice = \
+                remapped_original_attn_slice * mask + \
+                modified_attn_slice * inverse_mask
+
+            del remapped_original_attn_slice, modified_attn_slice
+
+            attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx])
+            hidden_states[start_idx:end_idx] = attn_slice
+
+        # done
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
+class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser):
+
+    def __init__(self):
+        super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9))  # massive slice size = don't slice
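The heart of the new processors is the remap-and-blend of attention probabilities between the original and edited prompts. A toy, self-contained example of that arithmetic (made-up opcodes, random tensors, and a tiny `max_length`; the real code runs on 77-token prompts inside the sliced loop above):

```python
import torch

max_length = 6
# difflib-style opcodes for something like "a cat runs" -> "a big dog runs" (token indices are made up):
#   equal   original[0:1] == edited[0:1]   ("a")
#   replace original[1:2] -> edited[1:3]   ("cat" -> "big dog")
#   equal   original[2:3] == edited[3:4]   ("runs")
edit_opcodes = [("equal", 0, 1, 0, 1), ("replace", 1, 2, 1, 3), ("equal", 2, 3, 3, 4)]

mask = torch.zeros(max_length)
indices = torch.arange(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in edit_opcodes:
    if b0 < max_length and name == "equal":
        indices[b0:b1] = torch.arange(a0, a1, dtype=torch.long)
        mask[b0:b1] = 1

# attention computed over the *original* prompt's tokens, remapped into the edited prompt's positions
original_attn = torch.rand(1, 4, max_length)   # (batch*heads, query_len, original prompt tokens)
modified_attn = torch.rand(1, 4, max_length)   # same shape, over the edited prompt's tokens
remapped_original = torch.index_select(original_attn, -1, indices)

# where mask==1 keep the original prompt's attention, where mask==0 use the edited prompt's
blended = remapped_original * mask + modified_attn * (1 - mask)
print(blended.shape)  # torch.Size([1, 4, 6])
```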

View File

@@ -19,9 +19,9 @@ class DDIMSampler(Sampler):
         all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)

         if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
+            self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count)
         else:
-            self.invokeai_diffuser.remove_cross_attention_control()
+            self.invokeai_diffuser.restore_default_cross_attention()

     # This is the central routine

View File

@@ -43,9 +43,9 @@ class CFGDenoiser(nn.Module):
         extra_conditioning_info = kwargs.get('extra_conditioning_info', None)

         if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc)
+            self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = t_enc)
         else:
-            self.invokeai_diffuser.remove_cross_attention_control()
+            self.invokeai_diffuser.restore_default_cross_attention()

     def forward(self, x, sigma, uncond, cond, cond_scale):

View File

@@ -21,9 +21,9 @@ class PLMSSampler(Sampler):
         all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)

         if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
+            self.invokeai_diffuser.override_cross_attention(extra_conditioning_info, step_count = all_timesteps_count)
         else:
-            self.invokeai_diffuser.remove_cross_attention_control()
+            self.invokeai_diffuser.restore_default_cross_attention()

     # this is the essential routine

View File

@ -1,14 +1,16 @@
import math import math
from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from math import ceil from math import ceil
from typing import Callable, Optional, Union from typing import Callable, Optional, Union, Any, Dict
import numpy as np import numpy as np
import torch import torch
from diffusers.models.cross_attention import AttnProcessor
from ldm.models.diffusion.cross_attention_control import Arguments, \ from ldm.models.diffusion.cross_attention_control import Arguments, \
remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, \ restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
CrossAttentionType CrossAttentionType, SwapCrossAttnContext
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
@@ -30,39 +32,68 @@ class InvokeAIDiffuserComponent:
     debug_thresholding = False

+    @dataclass
     class ExtraConditioningInfo:
-        def __init__(self, tokens_count_including_eos_bos:int, cross_attention_control_args: Optional[Arguments]):
-            self.tokens_count_including_eos_bos = tokens_count_including_eos_bos
-            self.cross_attention_control_args = cross_attention_control_args
+
+        tokens_count_including_eos_bos: int
+        cross_attention_control_args: Optional[Arguments] = None

         @property
         def wants_cross_attention_control(self):
             return self.cross_attention_control_args is not None

     def __init__(self, model, model_forward_callback:
-                 Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
-                 ):
+                 Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str,Any]]], torch.Tensor],
+                 is_running_diffusers: bool=False,
+                 ):
         """
         :param model: the unet model to pass through to cross attention control
        :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
         """
         self.conditioning = None
         self.model = model
+        self.is_running_diffusers = is_running_diffusers
         self.model_forward_callback = model_forward_callback
         self.cross_attention_control_context = None

-    def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
+    @contextmanager
+    def custom_attention_context(self,
+                                 extra_conditioning_info: Optional[ExtraConditioningInfo],
+                                 step_count: int):
+        do_swap = extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
+        old_attn_processor = None
+        if do_swap:
+            old_attn_processor = self.override_cross_attention(extra_conditioning_info,
+                                                               step_count=step_count)
+        try:
+            yield None
+        finally:
+            if old_attn_processor is not None:
+                self.restore_default_cross_attention(old_attn_processor)
+            # TODO resuscitate attention map saving
+            #self.remove_attention_map_saving()
+
+    def override_cross_attention(self, conditioning: ExtraConditioningInfo, step_count: int) -> Dict[str, AttnProcessor]:
+        """
+        setup cross attention .swap control. for diffusers this replaces the attention processor, so
+        the previous attention processor is returned so that the caller can restore it later.
+        """
         self.conditioning = conditioning
         self.cross_attention_control_context = Context(
             arguments=self.conditioning.cross_attention_control_args,
             step_count=step_count
         )
-        setup_cross_attention_control(self.model, self.cross_attention_control_context)
+        return override_cross_attention(self.model,
+                                        self.cross_attention_control_context,
+                                        is_running_diffusers=self.is_running_diffusers)

-    def remove_cross_attention_control(self):
+    def restore_default_cross_attention(self, restore_attention_processor: Optional['AttnProcessor']=None):
         self.conditioning = None
         self.cross_attention_control_context = None
-        remove_cross_attention_control(self.model)
+        restore_default_cross_attention(self.model,
+                                        is_running_diffusers=self.is_running_diffusers,
+                                        restore_attention_processor=restore_attention_processor)

     def setup_attention_map_saving(self, saver: AttentionMapSaver):
         def callback(slice, dim, offset, slice_size, key):
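For orientation, this is how the new context manager is meant to be driven from a caller, mirroring the pipeline diff earlier; `invokeai_diffuser`, `scheduler`, `info`, and `denoise` are placeholder names, not code from this commit.

```python
# usage sketch: cross-attention control (if requested by `info`) is active inside the block
with invokeai_diffuser.custom_attention_context(extra_conditioning_info=info,
                                                step_count=len(scheduler.timesteps)):
    latents = denoise(latents)  # hypothetical denoising loop
# on exit - normal return or exception - the previous attention processors are restored
```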
@@ -168,7 +199,41 @@ class InvokeAIDiffuserComponent:

         return unconditioned_next_x, conditioned_next_x

-    def apply_cross_attention_controlled_conditioning(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
+    def apply_cross_attention_controlled_conditioning(self,
+                                                      x: torch.Tensor,
+                                                      sigma,
+                                                      unconditioning,
+                                                      conditioning,
+                                                      cross_attention_control_types_to_do):
+        if self.is_running_diffusers:
+            return self.apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
+        else:
+            return self.apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
+
+    def apply_cross_attention_controlled_conditioning__diffusers(self,
+                                                                 x: torch.Tensor,
+                                                                 sigma,
+                                                                 unconditioning,
+                                                                 conditioning,
+                                                                 cross_attention_control_types_to_do):
+        context: Context = self.cross_attention_control_context
+
+        cross_attn_processor_context = SwapCrossAttnContext(modified_text_embeddings=context.arguments.edited_conditioning,
+                                                            index_map=context.cross_attention_index_map,
+                                                            mask=context.cross_attention_mask,
+                                                            cross_attention_types_to_do=[])
+        # no cross attention for unconditioning (negative prompt)
+        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning,
+                                                           {"swap_cross_attn_context": cross_attn_processor_context})
+
+        # do requested cross attention types for conditioning (positive prompt)
+        cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
+        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning,
+                                                         {"swap_cross_attn_context": cross_attn_processor_context})
+        return unconditioned_next_x, conditioned_next_x
+
+    def apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
         # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
         # slower non-batched path (20% slower on mac MPS)
         # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of