cross attention control options

Damian at mba 2022-10-23 14:58:25 +02:00
parent 8273c04575
commit 7d677a63b8
12 changed files with 318 additions and 299 deletions

View File

@@ -400,7 +400,7 @@ class Generate:
            mask_image = None
            try:
-               uc, c, ec, ec_index_map = get_uc_and_c_and_ec(
+               uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
                    prompt, model =self.model,
                    skip_normalize=skip_normalize,
                    log_tokens =self.log_tokenization
@@ -438,7 +438,7 @@ class Generate:
                sampler=self.sampler,
                steps=steps,
                cfg_scale=cfg_scale,
-               conditioning=(uc, c, ec, ec_index_map),
+               conditioning=(uc, c, extra_conditioning_info),
                ddim_eta=ddim_eta,
                image_callback=image_callback,  # called after the final image is generated
                step_callback=step_callback,  # called after each intermediate image is generated
@@ -541,8 +541,8 @@ class Generate:
            image = Image.open(image_path)

            # used by multiple postfixers
-           # todo: cross-attention
-           uc, c, _, _ = get_uc_and_c_and_ec(
+           # todo: cross-attention control
+           uc, c, _ = get_uc_and_c_and_ec(
                prompt, model =self.model,
                skip_normalize=opt.skip_normalize,
                log_tokens =opt.log_tokenization

View File

@@ -17,6 +17,7 @@ import torch
from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
    CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend, Fragment
+from ..models.diffusion.cross_attention_control import CrossAttentionControl
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
@@ -46,8 +47,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
    print("parsed prompt to", parsed_prompt)

    conditioning = None
-   edited_conditioning = None
-   edit_opcodes = None
+   cac_args:CrossAttentionControl.Arguments = None

    if type(parsed_prompt) is Blend:
        blend: Blend = parsed_prompt
@@ -98,21 +98,31 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
                original_token_count += count
                edited_token_count += count

        original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt)
+       # naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
+       # subsequent tokens when there is >1 edit and earlier edits change the total token count.
+       # eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
+       # 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
+       # token 'smiling' in the inactive 'cat' edit.
+       # todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
        edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt)

        conditioning = original_embeddings
        edited_conditioning = edited_embeddings
        print('got edit_opcodes', edit_opcodes, 'options', edit_options)
+       cac_args = CrossAttentionControl.Arguments(
+           edited_conditioning = edited_conditioning,
+           edit_opcodes = edit_opcodes,
+           edit_options = edit_options
+       )
    else:
        conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)

    unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt)
    return (
-       unconditioning, conditioning, edited_conditioning, edit_opcodes
-       #InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=edited_conditioning,
-       #                                                edit_opcodes=edit_opcodes,
-       #                                                edit_options=edit_options)
+       unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
+           cross_attention_control_args=cac_args
+       )
    )
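For orientation, here is a small self-contained sketch (not code from this commit) of what the new third return value carries. It shows how difflib.SequenceMatcher-style opcodes relate an original token sequence to an edited one, and how they are packed into the CrossAttentionControl.Arguments / ExtraConditioningInfo structures used above. In the real pipeline the opcodes and options come from the prompt parser's .swap() syntax; the token ids, tensor shape and option values here are made up for illustration.

import difflib
import torch

from ldm.models.diffusion.cross_attention_control import CrossAttentionControl
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent

# toy token id sequences standing in for the tokenized original / edited prompts
original_tokens = [49406, 320, 2368, 2761, 320, 31217, 49407]   # "a cat eating a hotdog"
edited_tokens   = [49406, 320, 2368, 2761, 320, 13555, 49407]   # "a cat eating a pizza"

# SequenceMatcher opcodes map original spans to edited spans; only the 'equal'
# spans matter to CrossAttentionControl (those tokens keep the original
# prompt's attention).
edit_opcodes = difflib.SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
# -> [('equal', 0, 5, 0, 5), ('replace', 5, 6, 5, 6), ('equal', 6, 7, 6, 7)]

# one entry per opcode; only the edited span carries an options dict
edit_options = [None,
                {'s_start': 0.0, 's_end': 0.3, 't_start': 0.0, 't_end': 1.0},
                None]

edited_conditioning = torch.zeros(1, 77, 768)   # placeholder for the real CLIP embedding

extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(
    cross_attention_control_args=CrossAttentionControl.Arguments(
        edited_conditioning=edited_conditioning,
        edit_opcodes=edit_opcodes,
        edit_options=edit_options,
    )
)
assert extra_conditioning_info.wants_cross_attention_control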

View File

@@ -33,8 +33,7 @@ class Img2Img(Generator):
        ) # move to latent space

        t_enc = int(strength * steps)
-       uc, c, ec, edit_opcodes = conditioning
-       extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes)
+       uc, c, extra_conditioning_info = conditioning

        def make_image(x_T):
            # encode (scaled latent)

View File

@@ -46,7 +46,7 @@ class Inpaint(Img2Img):
        t_enc = int(strength * steps)
        # todo: support cross-attention control
-       uc, c, _, _ = conditioning
+       uc, c, _ = conditioning

        print(f">> target t_enc is {t_enc} steps")

View File

@@ -21,8 +21,7 @@ class Txt2Img(Generator):
        kwargs are 'width' and 'height'
        """
        self.perlin = perlin
-       uc, c, ec, edit_opcodes = conditioning
-       extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes)
+       uc, c, extra_conditioning_info = conditioning

        @torch.no_grad()
        def make_image(x_T):

View File

@@ -23,8 +23,7 @@ class Txt2Img2Img(Generator):
        Return value depends on the seed at the time you call it
        kwargs are 'width' and 'height'
        """
-       uc, c, ec, edit_opcodes = conditioning
-       extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes)
+       uc, c, extra_conditioning_info = conditioning

        @torch.no_grad()
        def make_image(x_T):

View File

@@ -0,0 +1,236 @@
from enum import Enum

import torch

# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl


class CrossAttentionControl:

    class Arguments:
        def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
            """
            :param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768]
            :param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required)
            :param edit_options: if doing cross-attention control, per-edit options. there should be 1 item in edit_options for each item in edit_opcodes.
            """
            # todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector
            self.edited_conditioning = edited_conditioning
            self.edit_opcodes = edit_opcodes

            if edited_conditioning is not None:
                assert len(edit_opcodes) == len(edit_options), \
                    "there must be 1 edit_options dict for each edit_opcodes tuple"
                non_none_edit_options = [x for x in edit_options if x is not None]
                assert len(non_none_edit_options) > 0, "missing edit_options"
                if len(non_none_edit_options) > 1:
                    print('warning: cross-attention control options are not working properly for >1 edit')
                self.edit_options = non_none_edit_options[0]

    class Context:
        def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int):
            self.arguments = arguments
            self.step_count = step_count

    @classmethod
    def remove_cross_attention_control(cls, model):
        cls.remove_attention_function(model)

    @classmethod
    def setup_cross_attention_control(cls, model,
                                      cross_attention_control_args: Arguments
                                      ):
        """
        Inject attention parameters and functions into the passed-in model to enable cross attention editing.

        :param model: The unet model to inject into.
        :param cross_attention_control_args: Arguments passed to the CrossAttentionControl implementations
        :return: None
        """

        # adapted from init_attention_edit
        device = cross_attention_control_args.edited_conditioning.device

        # urgh. should this be hardcoded?
        max_length = 77
        # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
        mask = torch.zeros(max_length)
        indices_target = torch.arange(max_length, dtype=torch.long)
        indices = torch.zeros(max_length, dtype=torch.long)
        for name, a0, a1, b0, b1 in cross_attention_control_args.edit_opcodes:
            if b0 < max_length:
                if name == "equal":  # or (name == "replace" and a1 - a0 == b1 - b0):
                    # these tokens have not been edited
                    indices[b0:b1] = indices_target[a0:a1]
                    mask[b0:b1] = 1

        for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF):
            m.last_attn_slice_mask = None
            m.last_attn_slice_indices = None

        for m in cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS):
            m.last_attn_slice_mask = mask.to(device)
            m.last_attn_slice_indices = indices.to(device)

        cls.inject_attention_function(model)
    class CrossAttentionType(Enum):
        SELF = 1
        TOKENS = 2

    @classmethod
    def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', step_index: int = None) \
            -> list['CrossAttentionControl.CrossAttentionType']:
        """
        Should cross-attention control be applied on the given step?
        :param step_index: The step index (counts upwards from 0), or None if unknown.
        :return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
        """
        if step_index is None:
            return [cls.CrossAttentionType.SELF, cls.CrossAttentionType.TOKENS]

        opts = context.arguments.edit_options
        # percent_through will never reach 1.0 (but this is intended)
        percent_through = float(step_index) / float(context.step_count)
        to_control = []
        if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
            to_control.append(cls.CrossAttentionType.SELF)
        if opts['t_start'] <= percent_through and percent_through < opts['t_end']:
            to_control.append(cls.CrossAttentionType.TOKENS)
        return to_control

    @classmethod
    def get_attention_modules(cls, model, which: CrossAttentionType):
        which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2"
        return [module for name, module in model.named_modules() if
                type(module).__name__ == "CrossAttention" and which_attn in name]

    @classmethod
    def clear_requests(cls, model):
        self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF)
        tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS)
        for m in self_attention_modules + tokens_attention_modules:
            m.save_last_attn_slice = False
            m.use_last_attn_slice = False

    @classmethod
    def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
        modules = cls.get_attention_modules(model, cross_attention_type)
        for m in modules:
            # clear out the saved slice in case the outermost dim changes
            m.last_attn_slice = None
            m.save_last_attn_slice = True

    @classmethod
    def request_apply_saved_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
        modules = cls.get_attention_modules(model, cross_attention_type)
        for m in modules:
            m.use_last_attn_slice = True
    @classmethod
    def inject_attention_function(cls, unet):
        # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276

        def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):

            #print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim)

            attn_slice = suggested_attention_slice
            if dim is not None:
                start = offset
                end = start + slice_size
                #print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
            #else:
            #    print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")

            if self.use_last_attn_slice:
                this_attn_slice = attn_slice
                if self.last_attn_slice_mask is not None:
                    # indices and mask operate on dim=2, no need to slice
                    base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
                    base_attn_slice_mask = self.last_attn_slice_mask
                    if dim is None:
                        base_attn_slice = base_attn_slice_full
                        #print("using whole base slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
                    elif dim == 0:
                        base_attn_slice = base_attn_slice_full[start:end]
                        #print("using base dim 0 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
                    elif dim == 1:
                        base_attn_slice = base_attn_slice_full[:, start:end]
                        #print("using base dim 1 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)

                    attn_slice = this_attn_slice * (1 - base_attn_slice_mask) + \
                                 base_attn_slice * base_attn_slice_mask
                else:
                    if dim is None:
                        attn_slice = self.last_attn_slice
                        #print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
                    elif dim == 0:
                        attn_slice = self.last_attn_slice[start:end]
                        #print("took dim 0 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
                    elif dim == 1:
                        attn_slice = self.last_attn_slice[:, start:end]
                        #print("took dim 1 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)

            if self.save_last_attn_slice:
                if dim is None:
                    self.last_attn_slice = attn_slice
                elif dim == 0:
                    # dynamically grow last_attn_slice if needed
                    if self.last_attn_slice is None:
                        self.last_attn_slice = attn_slice
                        #print("no last_attn_slice: shape now", self.last_attn_slice.shape)
                    elif self.last_attn_slice.shape[0] == start:
                        self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=0)
                        assert(self.last_attn_slice.shape[0] == end)
                        #print("last_attn_slice too small, appended dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
                    else:
                        # no need to grow
                        self.last_attn_slice[start:end] = attn_slice
                        #print("last_attn_slice shape is fine, setting dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
                elif dim == 1:
                    # dynamically grow last_attn_slice if needed
                    if self.last_attn_slice is None:
                        self.last_attn_slice = attn_slice
                    elif self.last_attn_slice.shape[1] == start:
                        self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=1)
                        assert(self.last_attn_slice.shape[1] == end)
                    else:
                        # no need to grow
                        self.last_attn_slice[:, start:end] = attn_slice

            if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
                if dim is None:
                    weights = self.last_attn_slice_weights
                elif dim == 0:
                    weights = self.last_attn_slice_weights[start:end]
                elif dim == 1:
                    weights = self.last_attn_slice_weights[:, start:end]
                attn_slice = attn_slice * weights

            return attn_slice

        for name, module in unet.named_modules():
            module_name = type(module).__name__
            if module_name == "CrossAttention":
                module.last_attn_slice = None
                module.last_attn_slice_indices = None
                module.last_attn_slice_mask = None
                module.use_last_attn_weights = False
                module.use_last_attn_slice = False
                module.save_last_attn_slice = False
                module.set_attention_slice_wrangler(attention_slice_wrangler)

    @classmethod
    def remove_attention_function(cls, unet):
        for name, module in unet.named_modules():
            module_name = type(module).__name__
            if module_name == "CrossAttention":
                module.set_attention_slice_wrangler(None)
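Two pieces of the machinery above can be illustrated in isolation. First, the mask/indices bookkeeping in setup_cross_attention_control: a standalone sketch of the same loop (opcode values made up), showing that mask is 1 exactly where a token survived the edit:

import torch

max_length = 77
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
mask = torch.zeros(max_length)
indices_target = torch.arange(max_length, dtype=torch.long)
indices = torch.zeros(max_length, dtype=torch.long)

# made-up opcodes: tokens 0-4 and 6-9 are unchanged, token 5 was swapped
edit_opcodes = [('equal', 0, 5, 0, 5), ('replace', 5, 6, 5, 6), ('equal', 6, 10, 6, 10)]

for name, a0, a1, b0, b1 in edit_opcodes:
    if b0 < max_length and name == "equal":
        # unchanged tokens: remember which original-prompt position they came
        # from, and flag them to keep the original prompt's attention
        indices[b0:b1] = indices_target[a0:a1]
        mask[b0:b1] = 1

print(mask[:12].tolist())
# -> [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]

Second, the masked blend performed in attention_slice_wrangler when use_last_attn_slice is set: a toy version of just that blend, with small tensors standing in for the real attention slices (which have the 77 text-token positions on the last dimension):

import torch

tokens = 6                               # stand-in for the 77 CLIP token positions
saved_attn = torch.rand(2, 4, tokens)    # attention saved while running the ORIGINAL prompt
edited_attn = torch.rand(2, 4, tokens)   # attention just computed for the EDITED prompt

# from setup_cross_attention_control: mask==1 marks tokens that were not edited,
# and indices says which original-prompt column each of them corresponds to
mask = torch.tensor([1, 1, 0, 1, 1, 0], dtype=torch.float)
indices = torch.tensor([0, 1, 0, 3, 4, 0])

# same operation as in the wrangler: re-order the saved attention into the
# edited prompt's token positions, then keep it only where mask==1
base = torch.index_select(saved_attn, -1, indices)
blended = edited_attn * (1 - mask) + base * mask

assert blended.shape == edited_attn.shape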

View File

@@ -18,7 +18,7 @@ class DDIMSampler(Sampler):
        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-           self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info)
+           self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc)
        else:
            self.invokeai_diffuser.remove_cross_attention_control()
@@ -40,6 +40,7 @@ class DDIMSampler(Sampler):
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
+       step_count:int=1000,  # total number of steps
        **kwargs,
    ):
        b, *_, device = *x.shape, x.device
@@ -51,7 +52,11 @@ class DDIMSampler(Sampler):
            # damian0815 would like to know when/if this code path is used
            e_t = self.model.apply_model(x, t, c)
        else:
-           e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale)
+           step_index = step_count-(index+1)
+           e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
+                                                          unconditional_conditioning, c,
+                                                          unconditional_guidance_scale,
+                                                          step_index=step_index)

        if score_corrector is not None:
            assert self.model.parameterization == 'eps'
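A small worked example of the index bookkeeping introduced here (values illustrative): the DDIM/PLMS loops hand each step an index that counts down, while cross-attention control wants a step_index that counts up, and the edit window is then evaluated as a fraction of the total step count:

step_count = 10   # e.g. t_enc, as passed to setup_cross_attention_control

for index in reversed(range(step_count)):       # samplers iterate index = 9, 8, ..., 0
    step_index = step_count - (index + 1)       # counts up: 0, 1, ..., 9
    percent_through = step_index / step_count   # 0.0, 0.1, ..., 0.9 (never reaches 1.0)
    # an edit with s_start=0.0, s_end=0.3 controls self-attention on step_index 0, 1 and 2 only
    self_control_active = 0.0 <= percent_through < 0.3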

View File

@@ -37,14 +37,14 @@ class CFGDenoiser(nn.Module):
        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-           self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info)
+           self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc)
        else:
            self.invokeai_diffuser.remove_cross_attention_control()

-   def forward(self, x, sigma, uncond, cond, cond_scale):
-       next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
+   def forward(self, x, sigma, uncond, cond, cond_scale, step_index):
+       next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale, step_index)

        # apply threshold
        if self.warmup < self.warmup_max:
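With the new forward signature, whatever drives CFGDenoiser has to supply the step index as well. A hypothetical driver loop (not the k-diffusion sampling code from this commit) showing the threading pattern:

def run_denoising(denoiser, x, sigmas, uncond, cond, cond_scale):
    # 'denoiser' is assumed to behave like CFGDenoiser above:
    # forward(x, sigma, uncond, cond, cond_scale, step_index)
    for step_index, sigma in enumerate(sigmas):
        x = denoiser(x, sigma, uncond, cond, cond_scale, step_index)
    return x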

View File

@@ -23,7 +23,7 @@ class PLMSSampler(Sampler):
        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-           self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info)
+           self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc)
        else:
            self.invokeai_diffuser.remove_cross_attention_control()
@@ -47,6 +47,7 @@ class PLMSSampler(Sampler):
        unconditional_conditioning=None,
        old_eps=[],
        t_next=None,
+       step_count:int=1000,  # total number of steps
        **kwargs,
    ):
        b, *_, device = *x.shape, x.device
@@ -59,7 +60,13 @@ class PLMSSampler(Sampler):
            # damian0815 would like to know when/if this code path is used
            e_t = self.model.apply_model(x, t, c)
        else:
-           e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale)
+           # step_index is expected to count up while index counts down
+           step_index = step_count-(index+1)
+           # note that step_index == 0 is evaluated twice with different x
+           e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
+                                                          unconditional_conditioning, c,
+                                                          unconditional_guidance_scale,
+                                                          step_index=step_index)

        if score_corrector is not None:
            assert self.model.parameterization == 'eps'

View File

@@ -278,6 +278,7 @@ class Sampler(object):
                unconditional_conditioning=unconditional_conditioning,
                old_eps=old_eps,
                t_next=ts_next,
+               step_count=steps
            )
            img, pred_x0, e_t = outs

View File

@@ -1,9 +1,11 @@
from enum import Enum
from math import ceil
-from typing import Callable
+from typing import Callable, Optional

import torch

+from ldm.models.diffusion.cross_attention_control import CrossAttentionControl


class InvokeAIDiffuserComponent:
    '''
@@ -16,19 +18,16 @@ class InvokeAIDiffuserComponent:
    class ExtraConditioningInfo:
-       def __init__(self, edited_conditioning: torch.Tensor = None, edit_opcodes: list[tuple] = None):
-           """
-           :param edited_conditioning: if doing cross-attention control, the edited conditioning (1 x 77 x 768)
-           :param edit_opcodes: if doing cross-attention control, opcodes from a SequenceMatcher describing how to map original conditioning tokens to edited conditioning tokens
-           """
-           self.edited_conditioning = edited_conditioning
-           self.edit_opcodes = edit_opcodes
+       def __init__(self, cross_attention_control_args: Optional[CrossAttentionControl.Arguments]):
+           self.cross_attention_control_args = cross_attention_control_args

        @property
        def wants_cross_attention_control(self):
-           return self.edited_conditioning is not None
+           return self.cross_attention_control_args is not None

-   def __init__(self, model, model_forward_callback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]):
+   def __init__(self, model, model_forward_callback:
+                Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
+                ):
        """
        :param model: the unet model to pass through to cross attention control
        :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
@@ -37,44 +36,53 @@ class InvokeAIDiffuserComponent:
        self.model_forward_callback = model_forward_callback

-   def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo):
+   def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
        self.conditioning = conditioning
-       CrossAttentionControl.setup_cross_attention_control(self.model, conditioning.edited_conditioning, conditioning.edit_opcodes)
-       #todo: refactor edited_conditioning, edit_opcodes, edit_options into a struct
-       #todo: apply edit_options using step_count
+       self.cross_attention_control_context = CrossAttentionControl.Context(
+           arguments=self.conditioning.cross_attention_control_args,
+           step_count=step_count
+       )
+       CrossAttentionControl.setup_cross_attention_control(self.model,
+           cross_attention_control_args=self.conditioning.cross_attention_control_args
+       )

    def remove_cross_attention_control(self):
        self.conditioning = None
+       self.cross_attention_control_context = None
        CrossAttentionControl.remove_cross_attention_control(self.model)

-   @property
-   def edited_conditioning(self):
-       if self.conditioning is None:
-           return None
-       else:
-           return self.conditioning.edited_conditioning

    def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
                          unconditioning: torch.Tensor, conditioning: torch.Tensor,
-                         unconditional_guidance_scale: float):
+                         unconditional_guidance_scale: float,
+                         step_index: int = None):
        """
        :param x: Current latents
        :param sigma: aka t, passed to the internal model to control how much denoising will occur
        :param unconditioning: [B x 77 x 768] embeddings for unconditioned output
        :param conditioning: [B x 77 x 768] embeddings for conditioned output
        :param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
-       :param model: the unet model to pass through to cross attention control
-       :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
-       :return: the new latents after applying the model to x using unconditioning and CFG-scaled conditioning.
+       :param step_index: Counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotonically increase.
+       :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
        """

        CrossAttentionControl.clear_requests(self.model)
+       cross_attention_control_types_to_do = []

-       if self.edited_conditioning is None:
+       if self.cross_attention_control_context is not None:
+           cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, step_index)
+       if len(cross_attention_control_types_to_do)==0:
+           print('step', step_index, ': not doing cross attention control')
            # faster batched path
            x_twice = torch.cat([x]*2)
            sigma_twice = torch.cat([sigma]*2)
            both_conditionings = torch.cat([unconditioning, conditioning])
            unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
        else:
+           print('step', step_index, ': doing cross attention control on', cross_attention_control_types_to_do)
            # slower non-batched path (20% slower on mac MPS)
            # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
            # unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
@@ -86,13 +94,16 @@ class InvokeAIDiffuserComponent:
            unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)

            # process x using the original prompt, saving the attention maps
-           CrossAttentionControl.request_save_attention_maps(self.model)
+           for type in cross_attention_control_types_to_do:
+               CrossAttentionControl.request_save_attention_maps(self.model, type)
            _ = self.model_forward_callback(x, sigma, conditioning)
            CrossAttentionControl.clear_requests(self.model)

            # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
-           CrossAttentionControl.request_apply_saved_attention_maps(self.model)
-           conditioned_next_x = self.model_forward_callback(x, sigma, self.edited_conditioning)
+           for type in cross_attention_control_types_to_do:
+               CrossAttentionControl.request_apply_saved_attention_maps(self.model, type)
+           edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
+           conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)

            CrossAttentionControl.clear_requests(self.model)
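Both branches above end by merging the two predictions; the merge itself is outside this hunk, so the sketch below shows the standard classifier-free-guidance combination rather than quoting the code:

import torch

def combine(unconditioned_next_x: torch.Tensor,
            conditioned_next_x: torch.Tensor,
            unconditional_guidance_scale: float) -> torch.Tensor:
    # standard classifier-free guidance: start from the unconditioned
    # prediction and scale the difference towards the conditioned one
    return unconditioned_next_x + unconditional_guidance_scale * (
        conditioned_next_x - unconditioned_next_x
    )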
@@ -102,7 +113,6 @@ class InvokeAIDiffuserComponent:
        return combined_next_x

-   # todo: make this work
    @classmethod
    def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
@@ -153,250 +163,3 @@ class InvokeAIDiffuserComponent:
        return uncond_latents + deltas_merged * global_guidance_scale


# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl


class CrossAttentionControl:

    @classmethod
    def remove_cross_attention_control(cls, model):
        cls.remove_attention_function(model)

    @classmethod
    def setup_cross_attention_control(cls, model,
                                      substitute_conditioning: torch.Tensor,
                                      edit_opcodes: list):
        """
        Inject attention parameters and functions into the passed in model to enable cross attention editing.

        :param model: The unet model to inject into.
        :param substitute_conditioning: The "edited" conditioning vector, [Bx77x768]
        :param edit_opcodes: Opcodes from difflib.SequenceMatcher describing how the base
                             conditionings map to the "edited" conditionings.
        :return:
        """

        # adapted from init_attention_edit
        device = substitute_conditioning.device

        # urgh. should this be hardcoded?
        max_length = 77
        # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
        mask = torch.zeros(max_length)
        indices_target = torch.arange(max_length, dtype=torch.long)
        indices = torch.zeros(max_length, dtype=torch.long)
        for name, a0, a1, b0, b1 in edit_opcodes:
            if b0 < max_length:
                if name == "equal":  # or (name == "replace" and a1 - a0 == b1 - b0):
                    # these tokens have not been edited
                    indices[b0:b1] = indices_target[a0:a1]
                    mask[b0:b1] = 1

        for m in cls.get_attention_modules(model, cls.AttentionType.SELF):
            m.last_attn_slice_mask = None
            m.last_attn_slice_indices = None

        for m in cls.get_attention_modules(model, cls.AttentionType.TOKENS):
            m.last_attn_slice_mask = mask.to(device)
            m.last_attn_slice_indices = indices.to(device)

        cls.inject_attention_function(model)

    class AttentionType(Enum):
        SELF = 1
        TOKENS = 2

    @classmethod
    def get_attention_modules(cls, model, which: AttentionType):
        which_attn = "attn1" if which is cls.AttentionType.SELF else "attn2"
        return [module for name, module in model.named_modules() if
                type(module).__name__ == "CrossAttention" and which_attn in name]

    @classmethod
    def clear_requests(cls, model):
        self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF)
        tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS)
        for m in self_attention_modules + tokens_attention_modules:
            m.save_last_attn_slice = False
            m.use_last_attn_slice = False

    @classmethod
    def request_save_attention_maps(cls, model):
        self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF)
        tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS)
        for m in self_attention_modules + tokens_attention_modules:
            # clear out the saved slice in case the outermost dim changes
            m.last_attn_slice = None
            m.save_last_attn_slice = True

    @classmethod
    def request_apply_saved_attention_maps(cls, model):
        self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF)
        tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS)
        for m in self_attention_modules + tokens_attention_modules:
            m.use_last_attn_slice = True
    @classmethod
    def inject_attention_function(cls, unet):
        # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276

        def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):

            #print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim)

            attn_slice = suggested_attention_slice
            if dim is not None:
                start = offset
                end = start + slice_size
                #print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
            #else:
            #    print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")

            if self.use_last_attn_slice:
                this_attn_slice = attn_slice
                if self.last_attn_slice_mask is not None:
                    # indices and mask operate on dim=2, no need to slice
                    base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
                    base_attn_slice_mask = self.last_attn_slice_mask
                    if dim is None:
                        base_attn_slice = base_attn_slice_full
                        #print("using whole base slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
                    elif dim == 0:
                        base_attn_slice = base_attn_slice_full[start:end]
                        #print("using base dim 0 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
                    elif dim == 1:
                        base_attn_slice = base_attn_slice_full[:, start:end]
                        #print("using base dim 1 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)

                    attn_slice = this_attn_slice * (1 - base_attn_slice_mask) + \
                                 base_attn_slice * base_attn_slice_mask
                else:
                    if dim is None:
                        attn_slice = self.last_attn_slice
                        #print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
                    elif dim == 0:
                        attn_slice = self.last_attn_slice[start:end]
                        #print("took dim 0 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
                    elif dim == 1:
                        attn_slice = self.last_attn_slice[:, start:end]
                        #print("took dim 1 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)

            if self.save_last_attn_slice:
                if dim is None:
                    self.last_attn_slice = attn_slice
                elif dim == 0:
                    # dynamically grow last_attn_slice if needed
                    if self.last_attn_slice is None:
                        self.last_attn_slice = attn_slice
                        #print("no last_attn_slice: shape now", self.last_attn_slice.shape)
                    elif self.last_attn_slice.shape[0] == start:
                        self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=0)
                        assert(self.last_attn_slice.shape[0] == end)
                        #print("last_attn_slice too small, appended dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
                    else:
                        # no need to grow
                        self.last_attn_slice[start:end] = attn_slice
                        #print("last_attn_slice shape is fine, setting dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
                elif dim == 1:
                    # dynamically grow last_attn_slice if needed
                    if self.last_attn_slice is None:
                        self.last_attn_slice = attn_slice
                    elif self.last_attn_slice.shape[1] == start:
                        self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=1)
                        assert(self.last_attn_slice.shape[1] == end)
                    else:
                        # no need to grow
                        self.last_attn_slice[:, start:end] = attn_slice

            if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
                if dim is None:
                    weights = self.last_attn_slice_weights
                elif dim == 0:
                    weights = self.last_attn_slice_weights[start:end]
                elif dim == 1:
                    weights = self.last_attn_slice_weights[:, start:end]
                attn_slice = attn_slice * weights

            return attn_slice

        for name, module in unet.named_modules():
            module_name = type(module).__name__
            if module_name == "CrossAttention":
                module.last_attn_slice = None
                module.last_attn_slice_indices = None
                module.last_attn_slice_mask = None
                module.use_last_attn_weights = False
                module.use_last_attn_slice = False
                module.save_last_attn_slice = False
                module.set_attention_slice_wrangler(attention_slice_wrangler)

    @classmethod
    def remove_attention_function(cls, unet):
        for name, module in unet.named_modules():
            module_name = type(module).__name__
            if module_name == "CrossAttention":
                module.set_attention_slice_wrangler(None)
# original code below

# Functions supporting Cross-Attention Control
# Copied from https://github.com/bloc97/CrossAttentionControl

from difflib import SequenceMatcher

import torch


def prompt_token(prompt, index, clip_tokenizer):
    tokens = clip_tokenizer(prompt,
                            padding='max_length',
                            max_length=clip_tokenizer.model_max_length,
                            truncation=True,
                            return_tensors='pt',
                            return_overflowing_tokens=True
                            ).input_ids[0]
    return clip_tokenizer.decode(tokens[index:index + 1])


def use_last_tokens_attention(unet, use=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.use_last_attn_slice = use


def use_last_tokens_attention_weights(unet, use=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.use_last_attn_weights = use


def use_last_self_attention(unet, use=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn1' in name:
            module.use_last_attn_slice = use


def save_last_tokens_attention(unet, save=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.save_last_attn_slice = save


def save_last_self_attention(unet, save=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn1' in name:
            module.save_last_attn_slice = save