diff --git a/ldm/generate.py b/ldm/generate.py index f83a732816..39f0b06759 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -400,7 +400,7 @@ class Generate: mask_image = None try: - uc, c, ec, ec_index_map = get_uc_and_c_and_ec( + uc, c, extra_conditioning_info = get_uc_and_c_and_ec( prompt, model =self.model, skip_normalize=skip_normalize, log_tokens =self.log_tokenization @@ -438,7 +438,7 @@ class Generate: sampler=self.sampler, steps=steps, cfg_scale=cfg_scale, - conditioning=(uc, c, ec, ec_index_map), + conditioning=(uc, c, extra_conditioning_info), ddim_eta=ddim_eta, image_callback=image_callback, # called after the final image is generated step_callback=step_callback, # called after each intermediate image is generated @@ -541,8 +541,8 @@ class Generate: image = Image.open(image_path) # used by multiple postfixers - # todo: cross-attention - uc, c, _, _ = get_uc_and_c_and_ec( + # todo: cross-attention control + uc, c, _ = get_uc_and_c_and_ec( prompt, model =self.model, skip_normalize=opt.skip_normalize, log_tokens =opt.log_tokenization diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index 924ea39c77..52d40312ac 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -17,6 +17,7 @@ import torch from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \ CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend, Fragment +from ..models.diffusion.cross_attention_control import CrossAttentionControl from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder @@ -46,8 +47,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n print("parsed prompt to", parsed_prompt) conditioning = None - edited_conditioning = None - edit_opcodes = None + cac_args:CrossAttentionControl.Arguments = None if type(parsed_prompt) is Blend: blend: Blend = parsed_prompt @@ -98,21 +98,31 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n original_token_count += count edited_token_count += count original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt) + # naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of + # subsequent tokens when there is >1 edit and earlier edits change the total token count. + # eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the + # 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra + # token 'smiling' in the inactive 'cat' edit. 
+            # todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
             edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt)
             conditioning = original_embeddings
             edited_conditioning = edited_embeddings
             print('got edit_opcodes', edit_opcodes, 'options', edit_options)
+            cac_args = CrossAttentionControl.Arguments(
+                edited_conditioning = edited_conditioning,
+                edit_opcodes = edit_opcodes,
+                edit_options = edit_options
+            )
         else:
             conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)
     unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt)
     return (
-        unconditioning, conditioning, edited_conditioning, edit_opcodes
-        #InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=edited_conditioning,
-        #                                                edit_opcodes=edit_opcodes,
-        #                                                edit_options=edit_options)
+        unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
+            cross_attention_control_args=cac_args
+        )
     )
diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py
index 2f5e6e61d0..4942bcc0c3 100644
--- a/ldm/invoke/generator/img2img.py
+++ b/ldm/invoke/generator/img2img.py
@@ -33,8 +33,7 @@ class Img2Img(Generator):
         ) # move to latent space
         t_enc = int(strength * steps)
-        uc, c, ec, edit_opcodes = conditioning
-        extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes)
+        uc, c, extra_conditioning_info = conditioning
         def make_image(x_T):
             # encode (scaled latent)
diff --git a/ldm/invoke/generator/inpaint.py b/ldm/invoke/generator/inpaint.py
index 8f01b4ad2d..25bbc7e017 100644
--- a/ldm/invoke/generator/inpaint.py
+++ b/ldm/invoke/generator/inpaint.py
@@ -46,7 +46,7 @@ class Inpaint(Img2Img):
         t_enc = int(strength * steps)
         # todo: support cross-attention control
-        uc, c, _, _ = conditioning
+        uc, c, _ = conditioning
         print(f">> target t_enc is {t_enc} steps")
diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py
index 7e739860c3..696cc06f78 100644
--- a/ldm/invoke/generator/txt2img.py
+++ b/ldm/invoke/generator/txt2img.py
@@ -21,8 +21,7 @@ class Txt2Img(Generator):
         kwargs are 'width' and 'height'
         """
         self.perlin = perlin
-        uc, c, ec, edit_opcodes = conditioning
-        extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes)
+        uc, c, extra_conditioning_info = conditioning
         @torch.no_grad()
         def make_image(x_T):
diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py
index 2d67a44346..5808f7bdb2 100644
--- a/ldm/invoke/generator/txt2img2img.py
+++ b/ldm/invoke/generator/txt2img2img.py
@@ -23,8 +23,7 @@ class Txt2Img2Img(Generator):
         Return value depends on the seed at the time you call it
         kwargs are 'width' and 'height'
         """
-        uc, c, ec, edit_opcodes = conditioning
-        extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes)
+        uc, c, extra_conditioning_info = conditioning
         @torch.no_grad()
         def make_image(x_T):
diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py
new file mode 100644
index 0000000000..905803ccfa
--- /dev/null
+++ b/ldm/models/diffusion/cross_attention_control.py
@@ -0,0 +1,236 @@
+from enum import Enum
+
+import torch
+
+# adapted from bloc97's CrossAttentionControl colab
+# https://github.com/bloc97/CrossAttentionControl
+
+class CrossAttentionControl:
+
+    class Arguments:
+        def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
+            """
+            :param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768]
+            :param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required)
+            :param edit_options: if doing cross-attention control, per-edit options. there should be 1 item in edit_options for each item in edit_opcodes.
+            """
+            # todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector
+            self.edited_conditioning = edited_conditioning
+            self.edit_opcodes = edit_opcodes
+
+            if edited_conditioning is not None:
+                assert len(edit_opcodes) == len(edit_options), \
+                    "there must be 1 edit_options dict for each edit_opcodes tuple"
+                non_none_edit_options = [x for x in edit_options if x is not None]
+                assert len(non_none_edit_options)>0, "missing edit_options"
+                if len(non_none_edit_options)>1:
+                    print('warning: cross-attention control options are not working properly for >1 edit')
+                self.edit_options = non_none_edit_options[0]
+
+    class Context:
+        def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int):
+            self.arguments = arguments
+            self.step_count = step_count
+
+    @classmethod
+    def remove_cross_attention_control(cls, model):
+        cls.remove_attention_function(model)
+
+    @classmethod
+    def setup_cross_attention_control(cls, model,
+                                      cross_attention_control_args: Arguments
+                                      ):
+        """
+        Inject attention parameters and functions into the passed in model to enable cross attention editing.
+
+        :param model: The unet model to inject into.
+        :param cross_attention_control_args: Arguments passed to the CrossAttentionControl implementations
+        :return: None
+        """
+
+        # adapted from init_attention_edit
+        device = cross_attention_control_args.edited_conditioning.device
+
+        # urgh. should this be hardcoded?
+        max_length = 77
+        # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
+        mask = torch.zeros(max_length)
+        indices_target = torch.arange(max_length, dtype=torch.long)
+        indices = torch.zeros(max_length, dtype=torch.long)
+        for name, a0, a1, b0, b1 in cross_attention_control_args.edit_opcodes:
+            if b0 < max_length:
+                if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
+                    # these tokens have not been edited
+                    indices[b0:b1] = indices_target[a0:a1]
+                    mask[b0:b1] = 1
+
+        for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF):
+            m.last_attn_slice_mask = None
+            m.last_attn_slice_indices = None
+
+        for m in cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS):
+            m.last_attn_slice_mask = mask.to(device)
+            m.last_attn_slice_indices = indices.to(device)
+
+        cls.inject_attention_function(model)
+
+
+    class CrossAttentionType(Enum):
+        SELF = 1
+        TOKENS = 2
+
+    @classmethod
+    def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', step_index:int=None)\
+            -> list['CrossAttentionControl.CrossAttentionType']:
+        """
+        Should cross-attention control be applied on the given step?
+        :param step_index: The step index (counts upwards from 0), or None if unknown.
+        :return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
+ """ + if step_index is None: + return [cls.CrossAttentionType.SELF, cls.CrossAttentionType.TOKENS] + + opts = context.arguments.edit_options + # percent_through will never reach 1.0 (but this is intended) + percent_through = float(step_index)/float(context.step_count) + to_control = [] + if opts['s_start'] <= percent_through and percent_through < opts['s_end']: + to_control.append(cls.CrossAttentionType.SELF) + if opts['t_start'] <= percent_through and percent_through < opts['t_end']: + to_control.append(cls.CrossAttentionType.TOKENS) + return to_control + + + @classmethod + def get_attention_modules(cls, model, which: CrossAttentionType): + which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2" + return [module for name, module in model.named_modules() if + type(module).__name__ == "CrossAttention" and which_attn in name] + + @classmethod + def clear_requests(cls, model): + self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF) + tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS) + for m in self_attention_modules+tokens_attention_modules: + m.save_last_attn_slice = False + m.use_last_attn_slice = False + + @classmethod + def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType): + modules = cls.get_attention_modules(model, cross_attention_type) + for m in modules: + # clear out the saved slice in case the outermost dim changes + m.last_attn_slice = None + m.save_last_attn_slice = True + + @classmethod + def request_apply_saved_attention_maps(cls, model, cross_attention_type: CrossAttentionType): + modules = cls.get_attention_modules(model, cross_attention_type) + for m in modules: + m.use_last_attn_slice = True + + + + @classmethod + def inject_attention_function(cls, unet): + # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 + + def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size): + + #print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim) + + attn_slice = suggested_attention_slice + if dim is not None: + start = offset + end = start+slice_size + #print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}") + #else: + # print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}") + + + if self.use_last_attn_slice: + this_attn_slice = attn_slice + if self.last_attn_slice_mask is not None: + # indices and mask operate on dim=2, no need to slice + base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices) + base_attn_slice_mask = self.last_attn_slice_mask + if dim is None: + base_attn_slice = base_attn_slice_full + #print("using whole base slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) + elif dim == 0: + base_attn_slice = base_attn_slice_full[start:end] + #print("using base dim 0 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) + elif dim == 1: + base_attn_slice = base_attn_slice_full[:, start:end] + #print("using base dim 1 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) + + attn_slice = this_attn_slice * (1 - base_attn_slice_mask) + \ + 
base_attn_slice * base_attn_slice_mask + else: + if dim is None: + attn_slice = self.last_attn_slice + #print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) + elif dim == 0: + attn_slice = self.last_attn_slice[start:end] + #print("took dim 0 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) + elif dim == 1: + attn_slice = self.last_attn_slice[:, start:end] + #print("took dim 1 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) + + if self.save_last_attn_slice: + if dim is None: + self.last_attn_slice = attn_slice + elif dim == 0: + # dynamically grow last_attn_slice if needed + if self.last_attn_slice is None: + self.last_attn_slice = attn_slice + #print("no last_attn_slice: shape now", self.last_attn_slice.shape) + elif self.last_attn_slice.shape[0] == start: + self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=0) + assert(self.last_attn_slice.shape[0] == end) + #print("last_attn_slice too small, appended dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape) + else: + # no need to grow + self.last_attn_slice[start:end] = attn_slice + #print("last_attn_slice shape is fine, setting dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape) + + elif dim == 1: + # dynamically grow last_attn_slice if needed + if self.last_attn_slice is None: + self.last_attn_slice = attn_slice + elif self.last_attn_slice.shape[1] == start: + self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=1) + assert(self.last_attn_slice.shape[1] == end) + else: + # no need to grow + self.last_attn_slice[:, start:end] = attn_slice + + if self.use_last_attn_weights and self.last_attn_slice_weights is not None: + if dim is None: + weights = self.last_attn_slice_weights + elif dim == 0: + weights = self.last_attn_slice_weights[start:end] + elif dim == 1: + weights = self.last_attn_slice_weights[:, start:end] + attn_slice = attn_slice * weights + + return attn_slice + + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == "CrossAttention": + module.last_attn_slice = None + module.last_attn_slice_indices = None + module.last_attn_slice_mask = None + module.use_last_attn_weights = False + module.use_last_attn_slice = False + module.save_last_attn_slice = False + module.set_attention_slice_wrangler(attention_slice_wrangler) + + @classmethod + def remove_attention_function(cls, unet): + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == "CrossAttention": + module.set_attention_slice_wrangler(None) + diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index 71944a9b7e..5b5dfaf4af 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -18,7 +18,7 @@ class DDIMSampler(Sampler): extra_conditioning_info = kwargs.get('extra_conditioning_info', None) if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: - self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info) + self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc) else: self.invokeai_diffuser.remove_cross_attention_control() @@ -40,6 +40,7 @@ class DDIMSampler(Sampler): corrector_kwargs=None, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + step_count:int=1000, # total number of steps **kwargs, ): b, *_, device = 
*x.shape, x.device @@ -51,7 +52,11 @@ class DDIMSampler(Sampler): # damian0815 would like to know when/if this code path is used e_t = self.model.apply_model(x, t, c) else: - e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale) + step_index = step_count-(index+1) + e_t = self.invokeai_diffuser.do_diffusion_step(x, t, + unconditional_conditioning, c, + unconditional_guidance_scale, + step_index=step_index) if score_corrector is not None: assert self.model.parameterization == 'eps' diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index 44e418acb1..7bf48c62e8 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -37,14 +37,14 @@ class CFGDenoiser(nn.Module): extra_conditioning_info = kwargs.get('extra_conditioning_info', None) if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: - self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info) + self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc) else: self.invokeai_diffuser.remove_cross_attention_control() - def forward(self, x, sigma, uncond, cond, cond_scale): + def forward(self, x, sigma, uncond, cond, cond_scale, step_index): - next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale) + next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale, step_index) # apply threshold if self.warmup < self.warmup_max: diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index f58e2c3220..5b4674f28d 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -23,7 +23,7 @@ class PLMSSampler(Sampler): extra_conditioning_info = kwargs.get('extra_conditioning_info', None) if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: - self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info) + self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc) else: self.invokeai_diffuser.remove_cross_attention_control() @@ -47,6 +47,7 @@ class PLMSSampler(Sampler): unconditional_conditioning=None, old_eps=[], t_next=None, + step_count:int=1000, # total number of steps **kwargs, ): b, *_, device = *x.shape, x.device @@ -59,7 +60,13 @@ class PLMSSampler(Sampler): # damian0815 would like to know when/if this code path is used e_t = self.model.apply_model(x, t, c) else: - e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale) + # step_index is expected to count up while index counts down + step_index = step_count-(index+1) + # note that step_index == 0 is evaluated twice with different x + e_t = self.invokeai_diffuser.do_diffusion_step(x, t, + unconditional_conditioning, c, + unconditional_guidance_scale, + step_index=step_index) if score_corrector is not None: assert self.model.parameterization == 'eps' diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index e33d57fe31..8099997bb3 100644 --- a/ldm/models/diffusion/sampler.py +++ b/ldm/models/diffusion/sampler.py @@ -278,6 +278,7 @@ class Sampler(object): unconditional_conditioning=unconditional_conditioning, old_eps=old_eps, t_next=ts_next, + step_count=steps ) img, pred_x0, e_t = outs diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index 
290925fc8c..507feacaa9 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -1,9 +1,11 @@ from enum import Enum from math import ceil -from typing import Callable +from typing import Callable, Optional import torch +from ldm.models.diffusion.cross_attention_control import CrossAttentionControl + class InvokeAIDiffuserComponent: ''' @@ -16,19 +18,16 @@ class InvokeAIDiffuserComponent: class ExtraConditioningInfo: - def __init__(self, edited_conditioning: torch.Tensor = None, edit_opcodes: list[tuple] = None): - """ - :param edited_conditioning: if doing cross-attention control, the edited conditioning (1 x 77 x 768) - :param edit_opcodes: if doing cross-attention control, opcodes from a SequenceMatcher describing how to map original conditioning tokens to edited conditioning tokens - """ - self.edited_conditioning = edited_conditioning - self.edit_opcodes = edit_opcodes + def __init__(self, cross_attention_control_args: Optional[CrossAttentionControl.Arguments]): + self.cross_attention_control_args = cross_attention_control_args @property def wants_cross_attention_control(self): - return self.edited_conditioning is not None + return self.cross_attention_control_args is not None - def __init__(self, model, model_forward_callback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]): + def __init__(self, model, model_forward_callback: + Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] + ): """ :param model: the unet model to pass through to cross attention control :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning) @@ -37,44 +36,53 @@ class InvokeAIDiffuserComponent: self.model_forward_callback = model_forward_callback - def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo): + def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int): self.conditioning = conditioning - CrossAttentionControl.setup_cross_attention_control(self.model, conditioning.edited_conditioning, conditioning.edit_opcodes) + self.cross_attention_control_context = CrossAttentionControl.Context( + arguments=self.conditioning.cross_attention_control_args, + step_count=step_count + ) + CrossAttentionControl.setup_cross_attention_control(self.model, + cross_attention_control_args=self.conditioning.cross_attention_control_args + ) + #todo: refactor edited_conditioning, edit_opcodes, edit_options into a struct + #todo: apply edit_options using step_count + def remove_cross_attention_control(self): self.conditioning = None + self.cross_attention_control_context = None CrossAttentionControl.remove_cross_attention_control(self.model) - @property - def edited_conditioning(self): - if self.conditioning is None: - return None - else: - return self.conditioning.edited_conditioning - def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor, unconditioning: torch.Tensor, conditioning: torch.Tensor, - unconditional_guidance_scale: float): + unconditional_guidance_scale: float, + step_index: int=None): """ :param x: Current latents :param sigma: aka t, passed to the internal model to control how much denoising will occur :param unconditioning: [B x 77 x 768] embeddings for unconditioned output :param conditioning: [B x 77 x 768] embeddings for conditioned output :param unconditional_guidance_scale: aka CFG scale, controls how much effect 
the conditioning tensor has
-        :param model: the unet model to pass through to cross attention control
-        :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
-        :return: the new latents after applying the model to x using unconditioning and CFG-scaled conditioning.
+        :param step_index: Counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotonically increase.
+        :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
         """
         CrossAttentionControl.clear_requests(self.model)
+        cross_attention_control_types_to_do = []
-        if self.edited_conditioning is None:
+        if self.cross_attention_control_context is not None:
+            cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, step_index)
+
+        if len(cross_attention_control_types_to_do)==0:
+            print('step', step_index, ': not doing cross attention control')
             # faster batched path
             x_twice = torch.cat([x]*2)
             sigma_twice = torch.cat([sigma]*2)
             both_conditionings = torch.cat([unconditioning, conditioning])
             unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
         else:
+            print('step', step_index, ': doing cross attention control on', cross_attention_control_types_to_do)
            # slower non-batched path (20% slower on mac MPS)
            # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
            # unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
@@ -86,13 +94,16 @@ class InvokeAIDiffuserComponent: unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning) # process x using the original prompt, saving the attention maps - CrossAttentionControl.request_save_attention_maps(self.model) + for type in cross_attention_control_types_to_do: + CrossAttentionControl.request_save_attention_maps(self.model, type) _ = self.model_forward_callback(x, sigma, conditioning) CrossAttentionControl.clear_requests(self.model) # process x again, using the saved attention maps to control where self.edited_conditioning will be applied - CrossAttentionControl.request_apply_saved_attention_maps(self.model) - conditioned_next_x = self.model_forward_callback(x, sigma, self.edited_conditioning) + for type in cross_attention_control_types_to_do: + CrossAttentionControl.request_apply_saved_attention_maps(self.model, type) + edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning + conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning) CrossAttentionControl.clear_requests(self.model) @@ -102,7 +113,6 @@ class InvokeAIDiffuserComponent: return combined_next_x - # todo: make this work @classmethod def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale): @@ -153,250 +163,3 @@ class InvokeAIDiffuserComponent: return uncond_latents + deltas_merged * global_guidance_scale - -# adapted from bloc97's CrossAttentionControl colab -# https://github.com/bloc97/CrossAttentionControl - -class CrossAttentionControl: - - - @classmethod - def remove_cross_attention_control(cls, model): - cls.remove_attention_function(model) - - @classmethod - def setup_cross_attention_control(cls, model, - substitute_conditioning: torch.Tensor, - edit_opcodes: list): - """ - Inject attention parameters and functions into the passed in model to enable cross attention editing. - - :param model: The unet model to inject into. - :param substitute_conditioning: The "edited" conditioning vector, [Bx77x768] - :param edit_opcodes: Opcodes from difflib.SequenceMatcher describing how the base - conditionings map to the "edited" conditionings. - :return: - """ - - # adapted from init_attention_edit - device = substitute_conditioning.device - - # urgh. should this be hardcoded? 
- max_length = 77 - # mask=1 means use base prompt attention, mask=0 means use edited prompt attention - mask = torch.zeros(max_length) - indices_target = torch.arange(max_length, dtype=torch.long) - indices = torch.zeros(max_length, dtype=torch.long) - for name, a0, a1, b0, b1 in edit_opcodes: - if b0 < max_length: - if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0): - # these tokens have not been edited - indices[b0:b1] = indices_target[a0:a1] - mask[b0:b1] = 1 - - for m in cls.get_attention_modules(model, cls.AttentionType.SELF): - m.last_attn_slice_mask = None - m.last_attn_slice_indices = None - - for m in cls.get_attention_modules(model, cls.AttentionType.TOKENS): - m.last_attn_slice_mask = mask.to(device) - m.last_attn_slice_indices = indices.to(device) - - cls.inject_attention_function(model) - - - class AttentionType(Enum): - SELF = 1 - TOKENS = 2 - - - @classmethod - def get_attention_modules(cls, model, which: AttentionType): - which_attn = "attn1" if which is cls.AttentionType.SELF else "attn2" - return [module for name, module in model.named_modules() if - type(module).__name__ == "CrossAttention" and which_attn in name] - - @classmethod - def clear_requests(cls, model): - self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF) - tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS) - for m in self_attention_modules+tokens_attention_modules: - m.save_last_attn_slice = False - m.use_last_attn_slice = False - - @classmethod - def request_save_attention_maps(cls, model): - self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF) - tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS) - for m in self_attention_modules+tokens_attention_modules: - # clear out the saved slice in case the outermost dim changes - m.last_attn_slice = None - m.save_last_attn_slice = True - - @classmethod - def request_apply_saved_attention_maps(cls, model): - self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF) - tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS) - for m in self_attention_modules+tokens_attention_modules: - m.use_last_attn_slice = True - - - - @classmethod - def inject_attention_function(cls, unet): - # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 - - def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size): - - #print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim) - - attn_slice = suggested_attention_slice - if dim is not None: - start = offset - end = start+slice_size - #print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}") - #else: - # print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}") - - - if self.use_last_attn_slice: - this_attn_slice = attn_slice - if self.last_attn_slice_mask is not None: - # indices and mask operate on dim=2, no need to slice - base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices) - base_attn_slice_mask = self.last_attn_slice_mask - if dim is None: - base_attn_slice = base_attn_slice_full - #print("using whole base slice of shape", 
base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) - elif dim == 0: - base_attn_slice = base_attn_slice_full[start:end] - #print("using base dim 0 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) - elif dim == 1: - base_attn_slice = base_attn_slice_full[:, start:end] - #print("using base dim 1 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) - - attn_slice = this_attn_slice * (1 - base_attn_slice_mask) + \ - base_attn_slice * base_attn_slice_mask - else: - if dim is None: - attn_slice = self.last_attn_slice - #print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) - elif dim == 0: - attn_slice = self.last_attn_slice[start:end] - #print("took dim 0 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) - elif dim == 1: - attn_slice = self.last_attn_slice[:, start:end] - #print("took dim 1 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) - - if self.save_last_attn_slice: - if dim is None: - self.last_attn_slice = attn_slice - elif dim == 0: - # dynamically grow last_attn_slice if needed - if self.last_attn_slice is None: - self.last_attn_slice = attn_slice - #print("no last_attn_slice: shape now", self.last_attn_slice.shape) - elif self.last_attn_slice.shape[0] == start: - self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=0) - assert(self.last_attn_slice.shape[0] == end) - #print("last_attn_slice too small, appended dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape) - else: - # no need to grow - self.last_attn_slice[start:end] = attn_slice - #print("last_attn_slice shape is fine, setting dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape) - - elif dim == 1: - # dynamically grow last_attn_slice if needed - if self.last_attn_slice is None: - self.last_attn_slice = attn_slice - elif self.last_attn_slice.shape[1] == start: - self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=1) - assert(self.last_attn_slice.shape[1] == end) - else: - # no need to grow - self.last_attn_slice[:, start:end] = attn_slice - - if self.use_last_attn_weights and self.last_attn_slice_weights is not None: - if dim is None: - weights = self.last_attn_slice_weights - elif dim == 0: - weights = self.last_attn_slice_weights[start:end] - elif dim == 1: - weights = self.last_attn_slice_weights[:, start:end] - attn_slice = attn_slice * weights - - return attn_slice - - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == "CrossAttention": - module.last_attn_slice = None - module.last_attn_slice_indices = None - module.last_attn_slice_mask = None - module.use_last_attn_weights = False - module.use_last_attn_slice = False - module.save_last_attn_slice = False - module.set_attention_slice_wrangler(attention_slice_wrangler) - - @classmethod - def remove_attention_function(cls, unet): - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == "CrossAttention": - module.set_attention_slice_wrangler(None) - - -# original code below - -# Functions supporting Cross-Attention Control -# Copied from https://github.com/bloc97/CrossAttentionControl - -from difflib import SequenceMatcher - -import torch - - -def prompt_token(prompt, index, clip_tokenizer): - tokens = clip_tokenizer(prompt, - padding='max_length', - 
max_length=clip_tokenizer.model_max_length, - truncation=True, - return_tensors='pt', - return_overflowing_tokens=True - ).input_ids[0] - return clip_tokenizer.decode(tokens[index:index + 1]) - - -def use_last_tokens_attention(unet, use=True): - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == 'CrossAttention' and 'attn2' in name: - module.use_last_attn_slice = use - - -def use_last_tokens_attention_weights(unet, use=True): - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == 'CrossAttention' and 'attn2' in name: - module.use_last_attn_weights = use - - -def use_last_self_attention(unet, use=True): - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == 'CrossAttention' and 'attn1' in name: - module.use_last_attn_slice = use - - -def save_last_tokens_attention(unet, save=True): - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == 'CrossAttention' and 'attn2' in name: - module.save_last_attn_slice = save - - -def save_last_self_attention(unet, save=True): - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == 'CrossAttention' and 'attn1' in name: - module.save_last_attn_slice = save
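
Reviewer note (not part of the patch): a minimal sketch of how the edit_opcodes consumed by CrossAttentionControl.Arguments can be produced with difflib.SequenceMatcher, and how the s_start/s_end/t_start/t_end edit options gate which attention types are controlled on a given step. The token ids and option values below are invented for illustration; in the real code the token sequences come from the CLIP tokenizer via build_embeddings_and_tokens_for_flattened_prompt, and the options come from the prompt's .swap(...) arguments.

# Illustrative sketch only -- token ids and option values are assumptions, not patch output.
from difflib import SequenceMatcher

# hypothetical tokenizations of "a cat eating a hotdog" vs "a dog eating a hotdog"
original_tokens = [49406, 320, 2368, 4919, 320, 33811, 49407]
edited_tokens   = [49406, 320, 1929, 4919, 320, 33811, 49407]

# opcodes in the shape expected by CrossAttentionControl.Arguments.edit_opcodes
opcodes = SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
# -> [('equal', 0, 2, 0, 2), ('replace', 2, 3, 2, 3), ('equal', 3, 7, 3, 7)]
# setup_cross_attention_control only acts on the 'equal' spans: those token positions
# re-use the attention maps saved while processing the original prompt.

# per-edit options as read by get_active_cross_attention_control_types_for_step;
# these particular values are made up for the example
edit_options = {'s_start': 0.0, 's_end': 0.2, 't_start': 0.0, 't_end': 1.0}

def active_control_types(step_index: int, step_count: int, opts: dict) -> list[str]:
    # mirrors the percent_through gating in get_active_cross_attention_control_types_for_step;
    # percent_through never reaches 1.0 because step_index tops out at step_count - 1
    percent_through = float(step_index) / float(step_count)
    active = []
    if opts['s_start'] <= percent_through < opts['s_end']:
        active.append('SELF')
    if opts['t_start'] <= percent_through < opts['t_end']:
        active.append('TOKENS')
    return active

for step in (0, 10, 49):
    print(step, active_control_types(step, 50, edit_options))
# 0 -> ['SELF', 'TOKENS'], 10 -> ['TOKENS'], 49 -> ['TOKENS']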