mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

cross attention control options

This commit is contained in:
parent 8273c04575
commit 7d677a63b8
@@ -400,7 +400,7 @@ class Generate:
     mask_image = None

     try:
-        uc, c, ec, ec_index_map = get_uc_and_c_and_ec(
+        uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
             prompt, model =self.model,
             skip_normalize=skip_normalize,
             log_tokens    =self.log_tokenization
@@ -438,7 +438,7 @@ class Generate:
     sampler=self.sampler,
     steps=steps,
     cfg_scale=cfg_scale,
-    conditioning=(uc, c, ec, ec_index_map),
+    conditioning=(uc, c, extra_conditioning_info),
     ddim_eta=ddim_eta,
     image_callback=image_callback,  # called after the final image is generated
     step_callback=step_callback,    # called after each intermediate image is generated
@@ -541,8 +541,8 @@ class Generate:
     image = Image.open(image_path)

     # used by multiple postfixers
-    # todo: cross-attention
-    uc, c, _, _ = get_uc_and_c_and_ec(
+    # todo: cross-attention control
+    uc, c, _ = get_uc_and_c_and_ec(
         prompt, model =self.model,
         skip_normalize=opt.skip_normalize,
         log_tokens    =opt.log_tokenization
@@ -17,6 +17,7 @@ import torch

 from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
     CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend, Fragment
+from ..models.diffusion.cross_attention_control import CrossAttentionControl
 from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder

@@ -46,8 +47,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
     print("parsed prompt to", parsed_prompt)

     conditioning = None
-    edited_conditioning = None
-    edit_opcodes = None
+    cac_args: CrossAttentionControl.Arguments = None

     if type(parsed_prompt) is Blend:
         blend: Blend = parsed_prompt
@@ -98,21 +98,31 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
             original_token_count += count
             edited_token_count += count
         original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt)
+        # naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
+        # subsequent tokens when there is >1 edit and earlier edits change the total token count.
+        # eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
+        # 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
+        # token 'smiling' in the inactive 'cat' edit.
+        # todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
         edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt)

         conditioning = original_embeddings
         edited_conditioning = edited_embeddings
         print('got edit_opcodes', edit_opcodes, 'options', edit_options)
+        cac_args = CrossAttentionControl.Arguments(
+            edited_conditioning = edited_conditioning,
+            edit_opcodes = edit_opcodes,
+            edit_options = edit_options
+        )
     else:
         conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)


     unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt)
     return (
-        unconditioning, conditioning, edited_conditioning, edit_opcodes
-        #InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=edited_conditioning,
-        #                                                edit_opcodes=edit_opcodes,
-        #                                                edit_options=edit_options)
+        unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
+            cross_attention_control_args=cac_args
+        )
     )

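To make the new return contract concrete, here is a hedged sketch of what a caller now receives (dummy tensors; in InvokeAI they come from the CLIP embedder, and `cross_attention_control_args` is only non-None when the prompt uses `.swap()` edits):

```python
import torch
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent

# shapes as documented in this diff: [1 x 77 x 768] embeddings
uc = torch.zeros(1, 77, 768)    # unconditioned (negative-prompt) embeddings
c = torch.zeros(1, 77, 768)     # conditioned (prompt) embeddings

# plain prompt, no .swap() edits -> no cross-attention-control arguments
extra = InvokeAIDiffuserComponent.ExtraConditioningInfo(cross_attention_control_args=None)

conditioning = (uc, c, extra)

# generators now unpack three values instead of four...
uc, c, extra_conditioning_info = conditioning
# ...and samplers check the flag before wiring up attention editing:
assert extra_conditioning_info.wants_cross_attention_control is False
```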
@@ -33,8 +33,7 @@ class Img2Img(Generator):
     )  # move to latent space

     t_enc = int(strength * steps)
-    uc, c, ec, edit_opcodes = conditioning
-    extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes)
+    uc, c, extra_conditioning_info = conditioning

     def make_image(x_T):
         # encode (scaled latent)
@@ -46,7 +46,7 @@ class Inpaint(Img2Img):

     t_enc = int(strength * steps)
     # todo: support cross-attention control
-    uc, c, _, _ = conditioning
+    uc, c, _ = conditioning

     print(f">> target t_enc is {t_enc} steps")

@@ -21,8 +21,7 @@ class Txt2Img(Generator):
     kwargs are 'width' and 'height'
     """
     self.perlin = perlin
-    uc, c, ec, edit_opcodes = conditioning
-    extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes)
+    uc, c, extra_conditioning_info = conditioning

     @torch.no_grad()
     def make_image(x_T):
@@ -23,8 +23,7 @@ class Txt2Img2Img(Generator):
     Return value depends on the seed at the time you call it
     kwargs are 'width' and 'height'
     """
-    uc, c, ec, edit_opcodes = conditioning
-    extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes)
+    uc, c, extra_conditioning_info = conditioning

     @torch.no_grad()
     def make_image(x_T):
ldm/models/diffusion/cross_attention_control.py (new file, 236 lines)
@@ -0,0 +1,236 @@
from enum import Enum

import torch

# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl

class CrossAttentionControl:

    class Arguments:
        def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
            """
            :param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768]
            :param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required)
            :param edit_options: if doing cross-attention control, per-edit options. there should be 1 item in edit_options for each item in edit_opcodes.
            """
            # todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector
            self.edited_conditioning = edited_conditioning
            self.edit_opcodes = edit_opcodes

            if edited_conditioning is not None:
                assert len(edit_opcodes) == len(edit_options), \
                    "there must be 1 edit_options dict for each edit_opcodes tuple"
                non_none_edit_options = [x for x in edit_options if x is not None]
                assert len(non_none_edit_options) > 0, "missing edit_options"
                if len(non_none_edit_options) > 1:
                    print('warning: cross-attention control options are not working properly for >1 edit')
                self.edit_options = non_none_edit_options[0]

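Since `edit_opcodes` follows difflib's `SequenceMatcher` opcode convention, a quick illustration may help (the token strings are invented for readability; the real code computes opcodes over CLIP token ids upstream in the prompt-conditioning module):

```python
from difflib import SequenceMatcher

# Toy "token" sequences for an original and an edited prompt.
original = ['<start>', 'a', 'cat', 'eating', 'a', 'hotdog', '<end>']
edited   = ['<start>', 'a', 'smiling', 'dog', 'eating', 'a', 'hotdog', '<end>']

# Each opcode is (tag, a0, a1, b0, b1): original[a0:a1] maps to edited[b0:b1].
# setup_cross_attention_control() only looks at the 'equal' spans, which mark
# the token positions whose attention should be copied from the original pass.
for opcode in SequenceMatcher(None, original, edited).get_opcodes():
    print(opcode)
# ('equal', 0, 2, 0, 2)     '<start> a' is unchanged
# ('replace', 2, 3, 2, 4)   'cat' -> 'smiling dog'
# ('equal', 3, 7, 4, 8)     'eating a hotdog <end>' is unchanged
```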
    class Context:
        def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int):
            self.arguments = arguments
            self.step_count = step_count

    @classmethod
    def remove_cross_attention_control(cls, model):
        cls.remove_attention_function(model)

    @classmethod
    def setup_cross_attention_control(cls, model,
                                      cross_attention_control_args: Arguments
                                      ):
        """
        Inject attention parameters and functions into the passed in model to enable cross attention editing.

        :param model: The unet model to inject into.
        :param cross_attention_control_args: Arguments passed to the CrossAttentionControl implementations
        :return: None
        """

        # adapted from init_attention_edit
        device = cross_attention_control_args.edited_conditioning.device

        # urgh. should this be hardcoded?
        max_length = 77
        # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
        mask = torch.zeros(max_length)
        indices_target = torch.arange(max_length, dtype=torch.long)
        indices = torch.zeros(max_length, dtype=torch.long)
        for name, a0, a1, b0, b1 in cross_attention_control_args.edit_opcodes:
            if b0 < max_length:
                if name == "equal":  # or (name == "replace" and a1 - a0 == b1 - b0):
                    # these tokens have not been edited
                    indices[b0:b1] = indices_target[a0:a1]
                    mask[b0:b1] = 1

        for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF):
            m.last_attn_slice_mask = None
            m.last_attn_slice_indices = None

        for m in cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS):
            m.last_attn_slice_mask = mask.to(device)
            m.last_attn_slice_indices = indices.to(device)

        cls.inject_attention_function(model)

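To see what the mask and indices end up holding, here is a standalone re-run of the loop above on the toy opcodes from the earlier illustration, with max_length shortened from 77 to 10 purely for readability:

```python
import torch

max_length = 10                      # stand-in for the hardcoded 77
mask = torch.zeros(max_length)
indices_target = torch.arange(max_length, dtype=torch.long)
indices = torch.zeros(max_length, dtype=torch.long)

# opcodes for "a cat ..." -> "a smiling dog ...": only the 'equal' spans matter
edit_opcodes = [('equal', 0, 2, 0, 2), ('replace', 2, 3, 2, 4), ('equal', 3, 7, 4, 8)]
for name, a0, a1, b0, b1 in edit_opcodes:
    if b0 < max_length and name == "equal":
        indices[b0:b1] = indices_target[a0:a1]   # edited position b maps back to original position a
        mask[b0:b1] = 1                          # 1 = keep the original prompt's attention here

print(mask.tolist())     # [1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]
print(indices.tolist())  # [0, 1, 0, 0, 3, 4, 5, 6, 0, 0]
```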
    class CrossAttentionType(Enum):
        SELF = 1
        TOKENS = 2

    @classmethod
    def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', step_index: int = None) \
            -> list['CrossAttentionControl.CrossAttentionType']:
        """
        Should cross-attention control be applied on the given step?
        :param step_index: The step index (counts upwards from 0), or None if unknown.
        :return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
        """
        if step_index is None:
            return [cls.CrossAttentionType.SELF, cls.CrossAttentionType.TOKENS]

        opts = context.arguments.edit_options
        # percent_through will never reach 1.0 (but this is intended)
        percent_through = float(step_index) / float(context.step_count)
        to_control = []
        if opts['s_start'] <= percent_through < opts['s_end']:
            to_control.append(cls.CrossAttentionType.SELF)
        if opts['t_start'] <= percent_through < opts['t_end']:
            to_control.append(cls.CrossAttentionType.TOKENS)
        return to_control

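A worked example of the step gating (the window values are invented; in practice they come from the `.swap(..., s_start=..., t_start=...)` prompt options parsed upstream):

```python
edit_options = {'s_start': 0.0, 's_end': 0.3, 't_start': 0.1, 't_end': 1.0}
step_count = 10

for step_index in (0, 2, 5, 9):
    percent_through = step_index / step_count          # never reaches 1.0
    active = []
    if edit_options['s_start'] <= percent_through < edit_options['s_end']:
        active.append('SELF')
    if edit_options['t_start'] <= percent_through < edit_options['t_end']:
        active.append('TOKENS')
    print(step_index, active)
# 0 ['SELF']
# 2 ['SELF', 'TOKENS']
# 5 ['TOKENS']
# 9 ['TOKENS']
```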
    @classmethod
    def get_attention_modules(cls, model, which: CrossAttentionType):
        which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2"
        return [module for name, module in model.named_modules() if
                type(module).__name__ == "CrossAttention" and which_attn in name]

    @classmethod
    def clear_requests(cls, model):
        self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF)
        tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS)
        for m in self_attention_modules + tokens_attention_modules:
            m.save_last_attn_slice = False
            m.use_last_attn_slice = False

    @classmethod
    def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
        modules = cls.get_attention_modules(model, cross_attention_type)
        for m in modules:
            # clear out the saved slice in case the outermost dim changes
            m.last_attn_slice = None
            m.save_last_attn_slice = True

    @classmethod
    def request_apply_saved_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
        modules = cls.get_attention_modules(model, cross_attention_type)
        for m in modules:
            m.use_last_attn_slice = True

    @classmethod
    def inject_attention_function(cls, unet):
        # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276

        def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):

            #print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim)

            attn_slice = suggested_attention_slice
            if dim is not None:
                start = offset
                end = start + slice_size
                #print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
            #else:
            #    print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")

            if self.use_last_attn_slice:
                this_attn_slice = attn_slice
                if self.last_attn_slice_mask is not None:
                    # indices and mask operate on dim=2, no need to slice
                    base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
                    base_attn_slice_mask = self.last_attn_slice_mask
                    if dim is None:
                        base_attn_slice = base_attn_slice_full
                        #print("using whole base slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
                    elif dim == 0:
                        base_attn_slice = base_attn_slice_full[start:end]
                        #print("using base dim 0 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
                    elif dim == 1:
                        base_attn_slice = base_attn_slice_full[:, start:end]
                        #print("using base dim 1 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)

                    attn_slice = this_attn_slice * (1 - base_attn_slice_mask) + \
                                 base_attn_slice * base_attn_slice_mask
                else:
                    if dim is None:
                        attn_slice = self.last_attn_slice
                        #print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
                    elif dim == 0:
                        attn_slice = self.last_attn_slice[start:end]
                        #print("took dim 0 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
                    elif dim == 1:
                        attn_slice = self.last_attn_slice[:, start:end]
                        #print("took dim 1 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)

            if self.save_last_attn_slice:
                if dim is None:
                    self.last_attn_slice = attn_slice
                elif dim == 0:
                    # dynamically grow last_attn_slice if needed
                    if self.last_attn_slice is None:
                        self.last_attn_slice = attn_slice
                        #print("no last_attn_slice: shape now", self.last_attn_slice.shape)
                    elif self.last_attn_slice.shape[0] == start:
                        self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=0)
                        assert(self.last_attn_slice.shape[0] == end)
                        #print("last_attn_slice too small, appended dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
                    else:
                        # no need to grow
                        self.last_attn_slice[start:end] = attn_slice
                        #print("last_attn_slice shape is fine, setting dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)

                elif dim == 1:
                    # dynamically grow last_attn_slice if needed
                    if self.last_attn_slice is None:
                        self.last_attn_slice = attn_slice
                    elif self.last_attn_slice.shape[1] == start:
                        self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=1)
                        assert(self.last_attn_slice.shape[1] == end)
                    else:
                        # no need to grow
                        self.last_attn_slice[:, start:end] = attn_slice

            if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
                if dim is None:
                    weights = self.last_attn_slice_weights
                elif dim == 0:
                    weights = self.last_attn_slice_weights[start:end]
                elif dim == 1:
                    weights = self.last_attn_slice_weights[:, start:end]
                attn_slice = attn_slice * weights

            return attn_slice

        for name, module in unet.named_modules():
            module_name = type(module).__name__
            if module_name == "CrossAttention":
                module.last_attn_slice = None
                module.last_attn_slice_indices = None
                module.last_attn_slice_mask = None
                module.use_last_attn_weights = False
                module.use_last_attn_slice = False
                module.save_last_attn_slice = False
                module.set_attention_slice_wrangler(attention_slice_wrangler)

    @classmethod
    def remove_attention_function(cls, unet):
        for name, module in unet.named_modules():
            module_name = type(module).__name__
            if module_name == "CrossAttention":
                module.set_attention_slice_wrangler(None)

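`set_attention_slice_wrangler` itself is not part of this diff; it belongs to InvokeAI's patched attention module (presumably ldm/modules/attention.py). The following toy, with a stand-in class deliberately named `CrossAttention` so that the class-name checks above find it, sketches the contract that is assumed here; it is an illustration, not the real module, and the real implementation also slices along dims 0/1 to save memory:

```python
import torch
from ldm.models.diffusion.cross_attention_control import CrossAttentionControl

class CrossAttention(torch.nn.Module):
    """Toy stand-in named like the real module so CrossAttentionControl can find it."""

    def __init__(self):
        super().__init__()
        self.wrangler = None

    def set_attention_slice_wrangler(self, wrangler):
        self.wrangler = wrangler

    def attention_probs(self, q, k):
        scores = torch.einsum('bid,bjd->bij', q, k)
        probs = scores.softmax(dim=-1)
        if self.wrangler is not None:
            # contract assumed from the wrangler signature above:
            # (module, attention_scores, suggested_attention_slice, dim, offset, slice_size)
            probs = self.wrangler(self, scores, probs, None, 0, None)   # dim=None -> whole map
        return probs

class ToyUnet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn1 = CrossAttention()   # "self" attention
        self.attn2 = CrossAttention()   # "tokens" (cross) attention

unet = ToyUnet()
CrossAttentionControl.inject_attention_function(unet)   # sets the flags and the wrangler on both modules
CrossAttentionControl.request_save_attention_maps(unet, CrossAttentionControl.CrossAttentionType.TOKENS)

q = torch.randn(1, 4, 8)
k = torch.randn(1, 6, 8)
_ = unet.attn2.attention_probs(q, k)        # the wrangler stashes this attention map
print(unet.attn2.last_attn_slice.shape)     # torch.Size([1, 4, 6])
```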
@@ -18,7 +18,7 @@ class DDIMSampler(Sampler):
     extra_conditioning_info = kwargs.get('extra_conditioning_info', None)

     if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-        self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info)
+        self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count=t_enc)
     else:
         self.invokeai_diffuser.remove_cross_attention_control()

@@ -40,6 +40,7 @@ class DDIMSampler(Sampler):
     corrector_kwargs=None,
     unconditional_guidance_scale=1.0,
     unconditional_conditioning=None,
+    step_count: int = 1000,  # total number of steps
     **kwargs,
 ):
     b, *_, device = *x.shape, x.device

@@ -51,7 +52,11 @@ class DDIMSampler(Sampler):
         # damian0815 would like to know when/if this code path is used
         e_t = self.model.apply_model(x, t, c)
     else:
-        e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale)
+        step_index = step_count - (index + 1)
+        e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
+                                                       unconditional_conditioning, c,
+                                                       unconditional_guidance_scale,
+                                                       step_index=step_index)

     if score_corrector is not None:
         assert self.model.parameterization == 'eps'
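The samplers' loop variable `index` counts down over timesteps; `step_index` is derived so that it counts up from 0, which is what `get_active_cross_attention_control_types_for_step` expects:

```python
step_count = 4
for index in reversed(range(step_count)):     # sampler loop: index counts down 3, 2, 1, 0
    step_index = step_count - (index + 1)     # counts up: 0, 1, 2, 3
    percent_through = step_index / step_count # 0.0, 0.25, 0.5, 0.75 -- never reaches 1.0
    print(index, step_index, percent_through)
```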
@@ -37,14 +37,14 @@ class CFGDenoiser(nn.Module):
     extra_conditioning_info = kwargs.get('extra_conditioning_info', None)

     if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-        self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info)
+        self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count=t_enc)
     else:
         self.invokeai_diffuser.remove_cross_attention_control()


-    def forward(self, x, sigma, uncond, cond, cond_scale):
+    def forward(self, x, sigma, uncond, cond, cond_scale, step_index):

-        next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
+        next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale, step_index)

     # apply threshold
     if self.warmup < self.warmup_max:
@@ -23,7 +23,7 @@ class PLMSSampler(Sampler):
     extra_conditioning_info = kwargs.get('extra_conditioning_info', None)

     if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-        self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info)
+        self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count=t_enc)
     else:
         self.invokeai_diffuser.remove_cross_attention_control()

@@ -47,6 +47,7 @@ class PLMSSampler(Sampler):
     unconditional_conditioning=None,
     old_eps=[],
     t_next=None,
+    step_count: int = 1000,  # total number of steps
     **kwargs,
 ):
     b, *_, device = *x.shape, x.device

@@ -59,7 +60,13 @@ class PLMSSampler(Sampler):
         # damian0815 would like to know when/if this code path is used
         e_t = self.model.apply_model(x, t, c)
     else:
-        e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale)
+        # step_index is expected to count up while index counts down
+        step_index = step_count - (index + 1)
+        # note that step_index == 0 is evaluated twice with different x
+        e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
+                                                       unconditional_conditioning, c,
+                                                       unconditional_guidance_scale,
+                                                       step_index=step_index)

     if score_corrector is not None:
         assert self.model.parameterization == 'eps'

@@ -278,6 +278,7 @@ class Sampler(object):
     unconditional_conditioning=unconditional_conditioning,
     old_eps=old_eps,
     t_next=ts_next,
+    step_count=steps
     )
     img, pred_x0, e_t = outs

@@ -1,9 +1,11 @@
 from enum import Enum
 from math import ceil
-from typing import Callable
+from typing import Callable, Optional

 import torch

+from ldm.models.diffusion.cross_attention_control import CrossAttentionControl


 class InvokeAIDiffuserComponent:
     '''
@@ -16,19 +18,16 @@ class InvokeAIDiffuserComponent:


     class ExtraConditioningInfo:
-        def __init__(self, edited_conditioning: torch.Tensor = None, edit_opcodes: list[tuple] = None):
-            """
-            :param edited_conditioning: if doing cross-attention control, the edited conditioning (1 x 77 x 768)
-            :param edit_opcodes: if doing cross-attention control, opcodes from a SequenceMatcher describing how to map original conditioning tokens to edited conditioning tokens
-            """
-            self.edited_conditioning = edited_conditioning
-            self.edit_opcodes = edit_opcodes
+        def __init__(self, cross_attention_control_args: Optional[CrossAttentionControl.Arguments]):
+            self.cross_attention_control_args = cross_attention_control_args

         @property
         def wants_cross_attention_control(self):
-            return self.edited_conditioning is not None
+            return self.cross_attention_control_args is not None

-    def __init__(self, model, model_forward_callback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]):
+    def __init__(self, model, model_forward_callback:
+                 Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
+                 ):
         """
         :param model: the unet model to pass through to cross attention control
         :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
@@ -37,44 +36,53 @@ class InvokeAIDiffuserComponent:
     self.model_forward_callback = model_forward_callback


-    def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo):
+    def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
         self.conditioning = conditioning
-        CrossAttentionControl.setup_cross_attention_control(self.model, conditioning.edited_conditioning, conditioning.edit_opcodes)
+        self.cross_attention_control_context = CrossAttentionControl.Context(
+            arguments=self.conditioning.cross_attention_control_args,
+            step_count=step_count
+        )
+        CrossAttentionControl.setup_cross_attention_control(self.model,
+            cross_attention_control_args=self.conditioning.cross_attention_control_args
+        )
+        # todo: refactor edited_conditioning, edit_opcodes, edit_options into a struct
+        # todo: apply edit_options using step_count

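Putting the pieces together, a hedged end-to-end sketch of how the new objects are wired up (the tensor, opcodes and options are invented; real values come from `get_uc_and_c_and_ec()` and the loaded model, and the commented sampler lines assume a `model` and `t_enc` from the surrounding sampler code):

```python
import torch
from ldm.models.diffusion.cross_attention_control import CrossAttentionControl
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent

edited_conditioning = torch.zeros(1, 77, 768)                 # dummy edited embeddings
edit_opcodes = [('equal', 0, 2, 0, 2), ('replace', 2, 3, 2, 4), ('equal', 3, 7, 4, 8)]
# one edit_options entry per opcode; only the edited opcode carries options
edit_options = [None, {'s_start': 0.0, 's_end': 0.3, 't_start': 0.1, 't_end': 1.0}, None]

cac_args = CrossAttentionControl.Arguments(
    edited_conditioning=edited_conditioning,
    edit_opcodes=edit_opcodes,
    edit_options=edit_options,
)
extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(
    cross_attention_control_args=cac_args,
)

# inside a sampler, roughly as the hunks above do it:
# diffuser = InvokeAIDiffuserComponent(model, lambda x, sigma, cond: model.apply_model(x, sigma, cond))
# if extra_conditioning_info.wants_cross_attention_control:
#     diffuser.setup_cross_attention_control(extra_conditioning_info, step_count=t_enc)
# else:
#     diffuser.remove_cross_attention_control()
```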
     def remove_cross_attention_control(self):
         self.conditioning = None
+        self.cross_attention_control_context = None
         CrossAttentionControl.remove_cross_attention_control(self.model)

-    @property
-    def edited_conditioning(self):
-        if self.conditioning is None:
-            return None
-        else:
-            return self.conditioning.edited_conditioning

     def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
                           unconditioning: torch.Tensor, conditioning: torch.Tensor,
-                          unconditional_guidance_scale: float):
+                          unconditional_guidance_scale: float,
+                          step_index: int = None):
         """
         :param x: Current latents
         :param sigma: aka t, passed to the internal model to control how much denoising will occur
         :param unconditioning: [B x 77 x 768] embeddings for unconditioned output
         :param conditioning: [B x 77 x 768] embeddings for conditioned output
         :param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
-        :param model: the unet model to pass through to cross attention control
-        :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
-        :return: the new latents after applying the model to x using unconditioning and CFG-scaled conditioning.
+        :param step_index: Counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotonically increase.
+        :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
         """

-        CrossAttentionControl.clear_requests(self.model)
+        cross_attention_control_types_to_do = []

-        if self.edited_conditioning is None:
+        if self.cross_attention_control_context is not None:
+            cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, step_index)
+
+        if len(cross_attention_control_types_to_do) == 0:
+            print('step', step_index, ': not doing cross attention control')
             # faster batched path
             x_twice = torch.cat([x] * 2)
             sigma_twice = torch.cat([sigma] * 2)
             both_conditionings = torch.cat([unconditioning, conditioning])
             unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
         else:
+            print('step', step_index, ': doing cross attention control on', cross_attention_control_types_to_do)
             # slower non-batched path (20% slower on mac MPS)
             # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
             # unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
@@ -86,13 +94,16 @@ class InvokeAIDiffuserComponent:
             unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)

             # process x using the original prompt, saving the attention maps
-            CrossAttentionControl.request_save_attention_maps(self.model)
+            for type in cross_attention_control_types_to_do:
+                CrossAttentionControl.request_save_attention_maps(self.model, type)
             _ = self.model_forward_callback(x, sigma, conditioning)
             CrossAttentionControl.clear_requests(self.model)

             # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
-            CrossAttentionControl.request_apply_saved_attention_maps(self.model)
-            conditioned_next_x = self.model_forward_callback(x, sigma, self.edited_conditioning)
+            for type in cross_attention_control_types_to_do:
+                CrossAttentionControl.request_apply_saved_attention_maps(self.model, type)
+            edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
+            conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
             CrossAttentionControl.clear_requests(self.model)

@@ -102,7 +113,6 @@ class InvokeAIDiffuserComponent:

         return combined_next_x


     # todo: make this work
     @classmethod
     def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
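For context on the `combined_next_x` returned above: the unconditioned and conditioned passes are presumably merged with the conventional classifier-free guidance formula, which is not itself shown in these hunks:

```python
def cfg_combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale):
    # conventional classifier-free guidance: move away from the unconditioned prediction
    # and towards the conditioned one, scaled by the guidance factor
    return unconditioned_next_x + unconditional_guidance_scale * (conditioned_next_x - unconditioned_next_x)
```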
@@ -153,250 +163,3 @@ class InvokeAIDiffuserComponent:
         return uncond_latents + deltas_merged * global_guidance_scale

(the 250 removed lines follow: the old in-file CrossAttentionControl class and the original bloc97 helper functions)

# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl

class CrossAttentionControl:


    @classmethod
    def remove_cross_attention_control(cls, model):
        cls.remove_attention_function(model)

    @classmethod
    def setup_cross_attention_control(cls, model,
                                      substitute_conditioning: torch.Tensor,
                                      edit_opcodes: list):
        """
        Inject attention parameters and functions into the passed in model to enable cross attention editing.

        :param model: The unet model to inject into.
        :param substitute_conditioning: The "edited" conditioning vector, [Bx77x768]
        :param edit_opcodes: Opcodes from difflib.SequenceMatcher describing how the base
                             conditionings map to the "edited" conditionings.
        :return:
        """

        # adapted from init_attention_edit
        device = substitute_conditioning.device

        # urgh. should this be hardcoded?
        max_length = 77
        # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
        mask = torch.zeros(max_length)
        indices_target = torch.arange(max_length, dtype=torch.long)
        indices = torch.zeros(max_length, dtype=torch.long)
        for name, a0, a1, b0, b1 in edit_opcodes:
            if b0 < max_length:
                if name == "equal":  # or (name == "replace" and a1 - a0 == b1 - b0):
                    # these tokens have not been edited
                    indices[b0:b1] = indices_target[a0:a1]
                    mask[b0:b1] = 1

        for m in cls.get_attention_modules(model, cls.AttentionType.SELF):
            m.last_attn_slice_mask = None
            m.last_attn_slice_indices = None

        for m in cls.get_attention_modules(model, cls.AttentionType.TOKENS):
            m.last_attn_slice_mask = mask.to(device)
            m.last_attn_slice_indices = indices.to(device)

        cls.inject_attention_function(model)


    class AttentionType(Enum):
        SELF = 1
        TOKENS = 2


    @classmethod
    def get_attention_modules(cls, model, which: AttentionType):
        which_attn = "attn1" if which is cls.AttentionType.SELF else "attn2"
        return [module for name, module in model.named_modules() if
                type(module).__name__ == "CrossAttention" and which_attn in name]

    @classmethod
    def clear_requests(cls, model):
        self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF)
        tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS)
        for m in self_attention_modules + tokens_attention_modules:
            m.save_last_attn_slice = False
            m.use_last_attn_slice = False

    @classmethod
    def request_save_attention_maps(cls, model):
        self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF)
        tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS)
        for m in self_attention_modules + tokens_attention_modules:
            # clear out the saved slice in case the outermost dim changes
            m.last_attn_slice = None
            m.save_last_attn_slice = True

    @classmethod
    def request_apply_saved_attention_maps(cls, model):
        self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF)
        tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS)
        for m in self_attention_modules + tokens_attention_modules:
            m.use_last_attn_slice = True

    # (the removed inject_attention_function() and remove_attention_function() bodies are omitted here;
    # they are identical, line for line, to the versions shown above in the new
    # ldm/models/diffusion/cross_attention_control.py)

# original code below

# Functions supporting Cross-Attention Control
# Copied from https://github.com/bloc97/CrossAttentionControl

from difflib import SequenceMatcher

import torch


def prompt_token(prompt, index, clip_tokenizer):
    tokens = clip_tokenizer(prompt,
                            padding='max_length',
                            max_length=clip_tokenizer.model_max_length,
                            truncation=True,
                            return_tensors='pt',
                            return_overflowing_tokens=True
                            ).input_ids[0]
    return clip_tokenizer.decode(tokens[index:index + 1])


def use_last_tokens_attention(unet, use=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.use_last_attn_slice = use


def use_last_tokens_attention_weights(unet, use=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.use_last_attn_weights = use


def use_last_self_attention(unet, use=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn1' in name:
            module.use_last_attn_slice = use


def save_last_tokens_attention(unet, save=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.save_last_attn_slice = save


def save_last_self_attention(unet, save=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn1' in name:
            module.save_last_attn_slice = save