mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
403 lines
19 KiB
Python
403 lines
19 KiB
Python
from enum import Enum
|
|
from math import ceil
|
|
from typing import Callable
|
|
|
|
import torch
|
|
|
|
|
|
class InvokeAIDiffuserComponent:
    '''
    The aim of this component is to provide a single place for code that can be applied identically to
    all InvokeAI diffusion procedures.

    At the moment it includes the following features:
    * Cross Attention Control ("prompt2prompt")
    '''

    class ExtraConditioningInfo:
        def __init__(self, edited_conditioning: torch.Tensor = None, edit_opcodes: list[tuple] = None):
            """
            :param edited_conditioning: if doing cross-attention control, the edited conditioning (1 x 77 x 768)
            :param edit_opcodes: if doing cross-attention control, opcodes from a SequenceMatcher describing how
                to map original conditioning tokens to edited conditioning tokens
            """
            self.edited_conditioning = edited_conditioning
            self.edit_opcodes = edit_opcodes

        @property
        def wants_cross_attention_control(self):
            """True when an edited conditioning was supplied, i.e. cross-attention control is requested."""
            return self.edited_conditioning is not None

    def __init__(self, model, model_forward_callback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]):
        """
        :param model: the unet model to pass through to cross attention control
        :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called
            repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
        """
        self.model = model
        self.model_forward_callback = model_forward_callback
        # No cross-attention control is active until setup_cross_attention_control() is called. Initialising
        # this here lets `edited_conditioning` (and therefore do_diffusion_step) be used before any setup call
        # instead of raising AttributeError.
        self.conditioning = None

    def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo):
        """
        Store ``conditioning`` and install cross-attention control on ``self.model`` using its edited
        conditioning and edit opcodes.
        """
        self.conditioning = conditioning
        CrossAttentionControl.setup_cross_attention_control(self.model, conditioning.edited_conditioning, conditioning.edit_opcodes)

    def remove_cross_attention_control(self):
        """Forget the stored conditioning and remove cross-attention control from ``self.model``."""
        self.conditioning = None
        CrossAttentionControl.remove_cross_attention_control(self.model)

    @property
    def edited_conditioning(self):
        """The edited conditioning of the active setup, or None when no cross-attention control is set up."""
        if self.conditioning is None:
            return None
        return self.conditioning.edited_conditioning

    def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
                          unconditioning: torch.Tensor, conditioning: torch.Tensor,
                          unconditional_guidance_scale: float):
        """
        Run one classifier-free-guided model evaluation, using cross-attention control when it is set up.

        :param x: Current latents
        :param sigma: aka t, passed to the internal model to control how much denoising will occur
        :param unconditioning: [B x 77 x 768] embeddings for unconditioned output
        :param conditioning: [B x 77 x 768] embeddings for conditioned output
        :param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
        :return: the new latents after applying the model to x using unconditioning and CFG-scaled conditioning.
        """
        CrossAttentionControl.clear_requests(self.model)

        if self.edited_conditioning is None:
            # faster batched path
            x_twice = torch.cat([x] * 2)
            sigma_twice = torch.cat([sigma] * 2)
            both_conditionings = torch.cat([unconditioning, conditioning])
            unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
        else:
            # slower non-batched path (20% slower on mac MPS)
            # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
            # unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
            # This messes up their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8)
            # (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16,
            # representing batched uncond + cond, but then when it comes to applying the saved attention, the
            # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
            # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
            unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)

            # process x using the original prompt, saving the attention maps
            CrossAttentionControl.request_save_attention_maps(self.model)
            _ = self.model_forward_callback(x, sigma, conditioning)
            CrossAttentionControl.clear_requests(self.model)

            # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
            CrossAttentionControl.request_apply_saved_attention_maps(self.model)
            conditioned_next_x = self.model_forward_callback(x, sigma, self.edited_conditioning)
            CrossAttentionControl.clear_requests(self.model)

        # to scale how much effect conditioning has, calculate the changes it does and then scale that
        scaled_delta = (conditioned_next_x - unconditioned_next_x) * unconditional_guidance_scale
        combined_next_x = unconditioned_next_x + scaled_delta

        return combined_next_x

    # todo: not used by do_diffusion_step yet
    @classmethod
    def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale, *, normalize=False):
        """
        Classifier-free guidance over a weighted combination of several conditionings.

        The model is evaluated on the unconditioning and every conditioning, at most two per ``forward_func``
        call (batched with ``torch.cat``). Each conditioning's delta from the unconditioned prediction is
        multiplied by its weight, the weighted deltas are summed, and the sum is scaled by
        ``global_guidance_scale`` and added back to the unconditioned prediction.

        :param x: current latents
        :param t: timestep/sigma tensor passed alongside ``x`` (aka sigmas)
        :param forward_func: callable ``(x, t, conditioning) -> latents``
        :param uc: the unconditioned conditioning
        :param c_or_weighted_c_list: a single conditioning tensor (weight 1) or a list of ``(conditioning, weight)``
        :param global_guidance_scale: CFG scale applied to the merged delta
        :param normalize: if True, divide the weights by their sum before merging
        :return: ``uncond + global_guidance_scale * sum_i(weight_i * (cond_i - uncond))``, batched like ``x``.
            With no conditionings at all, the unconditioned prediction is returned unchanged.
        """
        weighted_cond_list = c_or_weighted_c_list if isinstance(c_or_weighted_c_list, list) else [(c_or_weighted_c_list, 1)]
        conditionings = [uc] + [c for c, _ in weighted_cond_list]
        weights = [weight for _, weight in weighted_cond_list]

        # inputs for a two-conditioning batched call
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)  # aka sigmas

        # latents[0] is the unconditioned prediction, latents[i] the prediction for conditionings[i]
        latents = []
        for offset in range(0, len(conditionings), 2):
            chunk = conditionings[offset:offset + 2]
            if len(chunk) == 1:
                # odd one out: evaluate on the un-duplicated inputs so the batch size matches x
                latents.append(forward_func(x, t, chunk[0]))
            else:
                latents.extend(forward_func(x_in, t_in, torch.cat(chunk)).chunk(2))

        uncond_latents = latents[0]
        if not weights:
            return uncond_latents

        # stack rather than cat so that dim 0 indexes conditionings independently of the batch size
        deltas = torch.stack([cond_latents - uncond_latents for cond_latents in latents[1:]])
        per_delta_weights = torch.tensor(weights, dtype=deltas.dtype, device=deltas.device)
        if normalize:
            per_delta_weights = per_delta_weights / torch.sum(per_delta_weights)
        # shape (N, 1, ..., 1) so each weight broadcasts over its whole delta
        reshaped_weights = per_delta_weights.reshape((-1,) + (1,) * (deltas.dim() - 1))
        deltas_merged = torch.sum(deltas * reshaped_weights, dim=0)

        return uncond_latents + deltas_merged * global_guidance_scale
|
|
|
|
|
|
# adapted from bloc97's CrossAttentionControl colab
|
|
# https://github.com/bloc97/CrossAttentionControl
|
|
|
|
class CrossAttentionControl:
    """
    Cross-attention editing ("prompt2prompt") for a unet whose ``CrossAttention`` modules accept an
    attention-slice wrangler through ``set_attention_slice_wrangler``.

    All state lives as plain attributes on each ``CrossAttention`` module (``last_attn_slice``,
    ``last_attn_slice_mask``, ``last_attn_slice_indices``, ``last_attn_slice_weights`` and the
    ``use_*``/``save_*`` flags). Modules whose name contains ``attn1`` are treated as self-attention, those
    containing ``attn2`` as token (cross-)attention.
    """

    class AttentionType(Enum):
        SELF = 1
        TOKENS = 2

    @classmethod
    def remove_cross_attention_control(cls, model):
        """Detach the attention-slice wrangler from every CrossAttention module of ``model``."""
        cls.remove_attention_function(model)

    @classmethod
    def setup_cross_attention_control(cls, model,
                                      substitute_conditioning: torch.Tensor,
                                      edit_opcodes: list):
        """
        Inject attention parameters and functions into the passed in model to enable cross attention editing.

        :param model: The unet model to inject into.
        :param substitute_conditioning: The "edited" conditioning vector, [Bx77x768]
        :param edit_opcodes: Opcodes from difflib.SequenceMatcher describing how the base
                             conditionings map to the "edited" conditionings.
        :return:
        """
        # adapted from init_attention_edit
        device = substitute_conditioning.device

        # urgh. should this be hardcoded?
        max_length = 77
        # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
        mask = torch.zeros(max_length)
        indices_target = torch.arange(max_length, dtype=torch.long)
        indices = torch.zeros(max_length, dtype=torch.long)
        for name, a0, a1, b0, b1 in edit_opcodes:
            if b0 < max_length and name == "equal":  # or (name == "replace" and a1 - a0 == b1 - b0):
                # these tokens have not been edited. Clamp the span so both sides stay inside max_length;
                # otherwise differently-truncated slices would raise a shape mismatch.
                span = min(b1 - b0, max_length - b0, max_length - a0)
                if span > 0:
                    indices[b0:b0 + span] = indices_target[a0:a0 + span]
                    mask[b0:b0 + span] = 1
        mask = mask.to(device)
        indices = indices.to(device)

        # Install the wrangler (which also resets all per-module state) BEFORE assigning the mask and indices:
        # inject_attention_function() sets last_attn_slice_mask / last_attn_slice_indices to None, so calling
        # it afterwards would silently discard the token mapping computed above.
        cls.inject_attention_function(model)

        for m in cls.get_attention_modules(model, cls.AttentionType.SELF):
            m.last_attn_slice_mask = None
            m.last_attn_slice_indices = None

        for m in cls.get_attention_modules(model, cls.AttentionType.TOKENS):
            m.last_attn_slice_mask = mask
            m.last_attn_slice_indices = indices

    @classmethod
    def get_attention_modules(cls, model, which: AttentionType):
        """Return the CrossAttention modules of ``model`` whose name marks them as ``which`` attention."""
        which_attn = "attn1" if which is cls.AttentionType.SELF else "attn2"
        return [module for name, module in model.named_modules() if
                type(module).__name__ == "CrossAttention" and which_attn in name]

    @classmethod
    def _self_and_tokens_attention_modules(cls, model):
        """Self-attention modules of ``model`` followed by its token-attention modules."""
        return (cls.get_attention_modules(model, cls.AttentionType.SELF)
                + cls.get_attention_modules(model, cls.AttentionType.TOKENS))

    @classmethod
    def clear_requests(cls, model):
        """Turn off both saving and applying of attention slices on every attention module."""
        for m in cls._self_and_tokens_attention_modules(model):
            m.save_last_attn_slice = False
            m.use_last_attn_slice = False

    @classmethod
    def request_save_attention_maps(cls, model):
        """Make the next forward pass record its attention slices on every attention module."""
        for m in cls._self_and_tokens_attention_modules(model):
            # clear out the saved slice in case the outermost dim changes
            m.last_attn_slice = None
            m.save_last_attn_slice = True

    @classmethod
    def request_apply_saved_attention_maps(cls, model):
        """Make the next forward pass reuse the previously saved attention slices."""
        for m in cls._self_and_tokens_attention_modules(model):
            m.use_last_attn_slice = True

    @staticmethod
    def _slice_index(dim: int, start: int, end: int) -> tuple:
        """Index tuple selecting ``start:end`` along axis ``dim`` and everything along earlier axes."""
        return (slice(None),) * dim + (slice(start, end),)

    @classmethod
    def inject_attention_function(cls, unet):
        """Reset per-module state and attach the attention-slice wrangler to every CrossAttention module."""
        # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276

        def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):
            # `self` is the CrossAttention module this wrangler is attached to. `dim` is None when the whole
            # attention tensor is passed at once; otherwise this call covers [offset, offset + slice_size)
            # along axis `dim`.
            index = None if dim is None else cls._slice_index(dim, offset, offset + slice_size)
            attn_slice = suggested_attention_slice

            if self.use_last_attn_slice:
                if self.last_attn_slice_mask is not None:
                    # indices and mask operate on the last axis: remap the full saved tensor, then take this slice
                    base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
                    base_attn_slice = base_attn_slice_full if index is None else base_attn_slice_full[index]
                    # the mask is built as float32; match the incoming slice so e.g. a half-precision slice is
                    # not promoted to float32 by the blend below
                    base_attn_slice_mask = self.last_attn_slice_mask.to(attn_slice.dtype)
                    attn_slice = attn_slice * (1 - base_attn_slice_mask) + \
                        base_attn_slice * base_attn_slice_mask
                else:
                    attn_slice = self.last_attn_slice if index is None else self.last_attn_slice[index]

            if self.save_last_attn_slice:
                if index is None or self.last_attn_slice is None:
                    self.last_attn_slice = attn_slice
                elif self.last_attn_slice.shape[dim] == offset:
                    # slices arrive in order: dynamically grow last_attn_slice along `dim`
                    self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=dim)
                    assert self.last_attn_slice.shape[dim] == offset + slice_size
                else:
                    # no need to grow
                    self.last_attn_slice[index] = attn_slice

            if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
                weights = self.last_attn_slice_weights if index is None else self.last_attn_slice_weights[index]
                attn_slice = attn_slice * weights

            return attn_slice

        for name, module in unet.named_modules():
            if type(module).__name__ == "CrossAttention":
                module.last_attn_slice = None
                module.last_attn_slice_indices = None
                module.last_attn_slice_mask = None
                # initialised so the wrangler can read it when use_last_attn_weights is switched on
                module.last_attn_slice_weights = None
                module.use_last_attn_weights = False
                module.use_last_attn_slice = False
                module.save_last_attn_slice = False
                module.set_attention_slice_wrangler(attention_slice_wrangler)

    @classmethod
    def remove_attention_function(cls, unet):
        """Detach the wrangler from every CrossAttention module of ``unet``."""
        for name, module in unet.named_modules():
            if type(module).__name__ == "CrossAttention":
                module.set_attention_slice_wrangler(None)
|
|
|
|
|
|
# original code below
|
|
|
|
# Functions supporting Cross-Attention Control
|
|
# Copied from https://github.com/bloc97/CrossAttentionControl
|
|
|
|
from difflib import SequenceMatcher
|
|
|
|
import torch
|
|
|
|
|
|
def prompt_token(prompt, index, clip_tokenizer):
    """
    Decode the single token found at position ``index`` of ``prompt``.

    The prompt is tokenized with ``clip_tokenizer`` (padded/truncated to the tokenizer's
    ``model_max_length``); the first returned sequence is indexed and that one token decoded.
    """
    encoded = clip_tokenizer(
        prompt,
        padding='max_length',
        max_length=clip_tokenizer.model_max_length,
        truncation=True,
        return_tensors='pt',
        return_overflowing_tokens=True,
    )
    first_sequence = encoded.input_ids[0]
    one_token = first_sequence[index:index + 1]
    return clip_tokenizer.decode(one_token)
|
|
|
|
|
|
def use_last_tokens_attention(unet, use=True):
    """Set ``use_last_attn_slice`` to ``use`` on every token-attention (``attn2``) CrossAttention module."""
    token_attention_modules = (
        module for name, module in unet.named_modules()
        if type(module).__name__ == 'CrossAttention' and 'attn2' in name
    )
    for module in token_attention_modules:
        module.use_last_attn_slice = use
|
|
|
|
|
|
def use_last_tokens_attention_weights(unet, use=True):
    """Set ``use_last_attn_weights`` to ``use`` on every token-attention (``attn2``) CrossAttention module."""
    token_attention_modules = (
        module for name, module in unet.named_modules()
        if type(module).__name__ == 'CrossAttention' and 'attn2' in name
    )
    for module in token_attention_modules:
        module.use_last_attn_weights = use
|
|
|
|
|
|
def use_last_self_attention(unet, use=True):
    """Set ``use_last_attn_slice`` to ``use`` on every self-attention (``attn1``) CrossAttention module."""
    self_attention_modules = (
        module for name, module in unet.named_modules()
        if type(module).__name__ == 'CrossAttention' and 'attn1' in name
    )
    for module in self_attention_modules:
        module.use_last_attn_slice = use
|
|
|
|
|
|
def save_last_tokens_attention(unet, save=True):
    """Set ``save_last_attn_slice`` to ``save`` on every token-attention (``attn2``) CrossAttention module."""
    token_attention_modules = (
        module for name, module in unet.named_modules()
        if type(module).__name__ == 'CrossAttention' and 'attn2' in name
    )
    for module in token_attention_modules:
        module.save_last_attn_slice = save
|
|
|
|
|
|
def save_last_self_attention(unet, save=True):
    """Set ``save_last_attn_slice`` to ``save`` on every self-attention (``attn1``) CrossAttention module."""
    self_attention_modules = (
        module for name, module in unet.named_modules()
        if type(module).__name__ == 'CrossAttention' and 'attn1' in name
    )
    for module in self_attention_modules:
        module.save_last_attn_slice = save
|