From 853c6af6237895c0d958b4c5c2c280d6ca7bbee2 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sat, 12 Nov 2022 11:01:10 -0800 Subject: [PATCH] refactor(cross_attention_control): remove outer CrossAttentionControl class Python has modules. We don't need to use a class to provide a namespace. --- ldm/invoke/conditioning.py | 6 +- .../diffusion/cross_attention_control.py | 471 +++++++++--------- .../diffusion/shared_invokeai_diffusion.py | 15 +- 3 files changed, 246 insertions(+), 246 deletions(-) diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index 04fbd7c10a..61553d6e2b 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -14,7 +14,7 @@ import torch from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \ CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization -from ..models.diffusion.cross_attention_control import CrossAttentionControl +from ..models.diffusion import cross_attention_control from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder @@ -50,7 +50,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n print(f">> Parsed prompt to {parsed_prompt}") conditioning = None - cac_args:CrossAttentionControl.Arguments = None + cac_args:cross_attention_control.Arguments = None if type(parsed_prompt) is Blend: blend: Blend = parsed_prompt @@ -121,7 +121,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n conditioning = original_embeddings edited_conditioning = edited_embeddings #print('>> got edit_opcodes', edit_opcodes, 'options', edit_options) - cac_args = CrossAttentionControl.Arguments( + cac_args = cross_attention_control.Arguments( edited_conditioning = edited_conditioning, edit_opcodes = edit_opcodes, edit_options = edit_options diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index ff90a24856..6fb8548338 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -8,255 +8,254 @@ import torch -class CrossAttentionControl: - - class Arguments: - def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict): - """ - :param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768] - :param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required) - :param edit_options: if doing cross-attention control, per-edit options. there should be 1 item in edit_options for each item in edit_opcodes. 
- """ - # todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector - self.edited_conditioning = edited_conditioning - self.edit_opcodes = edit_opcodes - - if edited_conditioning is not None: - assert len(edit_opcodes) == len(edit_options), \ - "there must be 1 edit_options dict for each edit_opcodes tuple" - non_none_edit_options = [x for x in edit_options if x is not None] - assert len(non_none_edit_options)>0, "missing edit_options" - if len(non_none_edit_options)>1: - print('warning: cross-attention control options are not working properly for >1 edit') - self.edit_options = non_none_edit_options[0] - class Context: +class Arguments: + def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict): + """ + :param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768] + :param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required) + :param edit_options: if doing cross-attention control, per-edit options. there should be 1 item in edit_options for each item in edit_opcodes. + """ + # todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector + self.edited_conditioning = edited_conditioning + self.edit_opcodes = edit_opcodes - class Action(enum.Enum): - NONE = 0 - SAVE = 1, - APPLY = 2 + if edited_conditioning is not None: + assert len(edit_opcodes) == len(edit_options), \ + "there must be 1 edit_options dict for each edit_opcodes tuple" + non_none_edit_options = [x for x in edit_options if x is not None] + assert len(non_none_edit_options)>0, "missing edit_options" + if len(non_none_edit_options)>1: + print('warning: cross-attention control options are not working properly for >1 edit') + self.edit_options = non_none_edit_options[0] - def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int): - """ - :param arguments: Arguments for the cross-attention control process - :param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run) - """ - self.arguments = arguments - self.step_count = step_count - self.self_cross_attention_module_identifiers = [] - self.tokens_cross_attention_module_identifiers = [] +class Context: + class Action(enum.Enum): + NONE = 0 + SAVE = 1, + APPLY = 2 + + def __init__(self, arguments: Arguments, step_count: int): + """ + :param arguments: Arguments for the cross-attention control process + :param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run) + """ + self.arguments = arguments + self.step_count = step_count + + self.self_cross_attention_module_identifiers = [] + self.tokens_cross_attention_module_identifiers = [] + + self.saved_cross_attention_maps = {} + + self.clear_requests(cleanup=True) + + def register_cross_attention_modules(self, model): + for name,module in get_attention_modules(model, + CrossAttentionType.SELF): + self.self_cross_attention_module_identifiers.append(name) + for name,module in get_attention_modules(model, + CrossAttentionType.TOKENS): + self.tokens_cross_attention_module_identifiers.append(name) + + def request_save_attention_maps(self, cross_attention_type: 'CrossAttentionType'): + if cross_attention_type == 
CrossAttentionType.SELF: + self.self_cross_attention_action = Context.Action.SAVE + else: + self.tokens_cross_attention_action = Context.Action.SAVE + + def request_apply_saved_attention_maps(self, cross_attention_type: 'CrossAttentionType'): + if cross_attention_type == CrossAttentionType.SELF: + self.self_cross_attention_action = Context.Action.APPLY + else: + self.tokens_cross_attention_action = Context.Action.APPLY + + def is_tokens_cross_attention(self, module_identifier) -> bool: + return module_identifier in self.tokens_cross_attention_module_identifiers + + def get_should_save_maps(self, module_identifier: str) -> bool: + if module_identifier in self.self_cross_attention_module_identifiers: + return self.self_cross_attention_action == Context.Action.SAVE + elif module_identifier in self.tokens_cross_attention_module_identifiers: + return self.tokens_cross_attention_action == Context.Action.SAVE + return False + + def get_should_apply_saved_maps(self, module_identifier: str) -> bool: + if module_identifier in self.self_cross_attention_module_identifiers: + return self.self_cross_attention_action == Context.Action.APPLY + elif module_identifier in self.tokens_cross_attention_module_identifiers: + return self.tokens_cross_attention_action == Context.Action.APPLY + return False + + def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\ + -> list['CrossAttentionType']: + """ + Should cross-attention control be applied on the given step? + :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0. + :return: A list of attention types that cross-attention control should be performed for on the given step. May be []. + """ + if percent_through is None: + return [CrossAttentionType.SELF, CrossAttentionType.TOKENS] + + opts = self.arguments.edit_options + to_control = [] + if opts['s_start'] <= percent_through and percent_through < opts['s_end']: + to_control.append(CrossAttentionType.SELF) + if opts['t_start'] <= percent_through and percent_through < opts['t_end']: + to_control.append(CrossAttentionType.TOKENS) + return to_control + + def save_slice(self, identifier: str, slice: torch.Tensor, dim: Optional[int], offset: int, + slice_size: Optional[int]): + if identifier not in self.saved_cross_attention_maps: + self.saved_cross_attention_maps[identifier] = { + 'dim': dim, + 'slice_size': slice_size, + 'slices': {offset or 0: slice} + } + else: + self.saved_cross_attention_maps[identifier]['slices'][offset or 0] = slice + + def get_slice(self, identifier: str, requested_dim: Optional[int], requested_offset: int, slice_size: int): + saved_attention_dict = self.saved_cross_attention_maps[identifier] + if requested_dim is None: + if saved_attention_dict['dim'] is not None: + raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}") + return saved_attention_dict['slices'][0] + + if saved_attention_dict['dim'] == requested_dim: + if slice_size != saved_attention_dict['slice_size']: + raise RuntimeError( + f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}") + return saved_attention_dict['slices'][requested_offset] + + if saved_attention_dict['dim'] == None: + whole_saved_attention = saved_attention_dict['slices'][0] + if requested_dim == 0: + return whole_saved_attention[requested_offset:requested_offset + slice_size] + elif requested_dim == 1: + return whole_saved_attention[:, 
requested_offset:requested_offset + slice_size] + + raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}") + + def get_slicing_strategy(self, identifier: str) -> Optional[tuple[int, int]]: + saved_attention = self.saved_cross_attention_maps.get(identifier, None) + if saved_attention is None: + return None, None + return saved_attention['dim'], saved_attention['slice_size'] + + def clear_requests(self, cleanup=True): + self.tokens_cross_attention_action = Context.Action.NONE + self.self_cross_attention_action = Context.Action.NONE + if cleanup: self.saved_cross_attention_maps = {} - self.clear_requests(cleanup=True) + def offload_saved_attention_slices_to_cpu(self): + for key, map_dict in self.saved_cross_attention_maps.items(): + for offset, slice in map_dict['slices'].items(): + map_dict[offset] = slice.to('cpu') - def register_cross_attention_modules(self, model): - for name,module in CrossAttentionControl.get_attention_modules(model, - CrossAttentionControl.CrossAttentionType.SELF): - self.self_cross_attention_module_identifiers.append(name) - for name,module in CrossAttentionControl.get_attention_modules(model, - CrossAttentionControl.CrossAttentionType.TOKENS): - self.tokens_cross_attention_module_identifiers.append(name) - def request_save_attention_maps(self, cross_attention_type: 'CrossAttentionControl.CrossAttentionType'): - if cross_attention_type == CrossAttentionControl.CrossAttentionType.SELF: - self.self_cross_attention_action = CrossAttentionControl.Context.Action.SAVE +def remove_cross_attention_control(model): + remove_attention_function(model) + + +def setup_cross_attention_control(model, context: Context): + """ + Inject attention parameters and functions into the passed in model to enable cross attention editing. + + :param model: The unet model to inject into. + :param cross_attention_control_args: Arugments passeed to the CrossAttentionControl implementations + :return: None + """ + + # adapted from init_attention_edit + device = context.arguments.edited_conditioning.device + + # urgh. should this be hardcoded? 
+ max_length = 77 + # mask=1 means use base prompt attention, mask=0 means use edited prompt attention + mask = torch.zeros(max_length) + indices_target = torch.arange(max_length, dtype=torch.long) + indices = torch.zeros(max_length, dtype=torch.long) + for name, a0, a1, b0, b1 in context.arguments.edit_opcodes: + if b0 < max_length: + if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0): + # these tokens have not been edited + indices[b0:b1] = indices_target[a0:a1] + mask[b0:b1] = 1 + + context.register_cross_attention_modules(model) + context.cross_attention_mask = mask.to(device) + context.cross_attention_index_map = indices.to(device) + inject_attention_function(model, context) + + +class CrossAttentionType(enum.Enum): + SELF = 1 + TOKENS = 2 + + +def get_attention_modules(model, which: CrossAttentionType): + which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2" + return [(name,module) for name, module in model.named_modules() if + type(module).__name__ == "CrossAttention" and which_attn in name] + + +def inject_attention_function(unet, context: Context): + # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 + + def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim, offset, slice_size): + + #memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement() + + attention_slice = suggested_attention_slice + + if context.get_should_save_maps(module.identifier): + #print(module.identifier, "saving suggested_attention_slice of shape", + # suggested_attention_slice.shape, "dim", dim, "offset", offset) + slice_to_save = attention_slice.to('cpu') if dim is not None else attention_slice + context.save_slice(module.identifier, slice_to_save, dim=dim, offset=offset, slice_size=slice_size) + elif context.get_should_apply_saved_maps(module.identifier): + #print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset) + saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size) + + # slice may have been offloaded to CPU + saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device) + + if context.is_tokens_cross_attention(module.identifier): + index_map = context.cross_attention_index_map + remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map) + this_attention_slice = suggested_attention_slice + + mask = context.cross_attention_mask + saved_mask = mask + this_mask = 1 - mask + attention_slice = remapped_saved_attention_slice * saved_mask + \ + this_attention_slice * this_mask else: - self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.SAVE + # just use everything + attention_slice = saved_attention_slice - def request_apply_saved_attention_maps(self, cross_attention_type: 'CrossAttentionControl.CrossAttentionType'): - if cross_attention_type == CrossAttentionControl.CrossAttentionType.SELF: - self.self_cross_attention_action = CrossAttentionControl.Context.Action.APPLY - else: - self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.APPLY + return attention_slice - def is_tokens_cross_attention(self, module_identifier) -> bool: - return module_identifier in self.tokens_cross_attention_module_identifiers - - def get_should_save_maps(self, module_identifier: str) -> bool: - if module_identifier in self.self_cross_attention_module_identifiers: - return 
self.self_cross_attention_action == CrossAttentionControl.Context.Action.SAVE - elif module_identifier in self.tokens_cross_attention_module_identifiers: - return self.tokens_cross_attention_action == CrossAttentionControl.Context.Action.SAVE - return False - - def get_should_apply_saved_maps(self, module_identifier: str) -> bool: - if module_identifier in self.self_cross_attention_module_identifiers: - return self.self_cross_attention_action == CrossAttentionControl.Context.Action.APPLY - elif module_identifier in self.tokens_cross_attention_module_identifiers: - return self.tokens_cross_attention_action == CrossAttentionControl.Context.Action.APPLY - return False - - def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\ - -> list['CrossAttentionControl.CrossAttentionType']: - """ - Should cross-attention control be applied on the given step? - :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0. - :return: A list of attention types that cross-attention control should be performed for on the given step. May be []. - """ - if percent_through is None: - return [CrossAttentionControl.CrossAttentionType.SELF, CrossAttentionControl.CrossAttentionType.TOKENS] - - opts = self.arguments.edit_options - to_control = [] - if opts['s_start'] <= percent_through and percent_through < opts['s_end']: - to_control.append(CrossAttentionControl.CrossAttentionType.SELF) - if opts['t_start'] <= percent_through and percent_through < opts['t_end']: - to_control.append(CrossAttentionControl.CrossAttentionType.TOKENS) - return to_control - - def save_slice(self, identifier: str, slice: torch.Tensor, dim: Optional[int], offset: int, - slice_size: Optional[int]): - if identifier not in self.saved_cross_attention_maps: - self.saved_cross_attention_maps[identifier] = { - 'dim': dim, - 'slice_size': slice_size, - 'slices': {offset or 0: slice} - } - else: - self.saved_cross_attention_maps[identifier]['slices'][offset or 0] = slice - - def get_slice(self, identifier: str, requested_dim: Optional[int], requested_offset: int, slice_size: int): - saved_attention_dict = self.saved_cross_attention_maps[identifier] - if requested_dim is None: - if saved_attention_dict['dim'] is not None: - raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}") - return saved_attention_dict['slices'][0] - - if saved_attention_dict['dim'] == requested_dim: - if slice_size != saved_attention_dict['slice_size']: - raise RuntimeError( - f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}") - return saved_attention_dict['slices'][requested_offset] - - if saved_attention_dict['dim'] == None: - whole_saved_attention = saved_attention_dict['slices'][0] - if requested_dim == 0: - return whole_saved_attention[requested_offset:requested_offset + slice_size] - elif requested_dim == 1: - return whole_saved_attention[:, requested_offset:requested_offset + slice_size] - - raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}") - - def get_slicing_strategy(self, identifier: str) -> Optional[tuple[int, int]]: - saved_attention = self.saved_cross_attention_maps.get(identifier, None) - if saved_attention is None: - return None, None - return saved_attention['dim'], saved_attention['slice_size'] - - def clear_requests(self, cleanup=True): - self.tokens_cross_attention_action = 
CrossAttentionControl.Context.Action.NONE - self.self_cross_attention_action = CrossAttentionControl.Context.Action.NONE - if cleanup: - self.saved_cross_attention_maps = {} - - def offload_saved_attention_slices_to_cpu(self): - for key, map_dict in self.saved_cross_attention_maps.items(): - for offset, slice in map_dict['slices'].items(): - map_dict[offset] = slice.to('cpu') - - @classmethod - def remove_cross_attention_control(cls, model): - cls.remove_attention_function(model) - - @classmethod - def setup_cross_attention_control(cls, model, context: Context): - """ - Inject attention parameters and functions into the passed in model to enable cross attention editing. - - :param model: The unet model to inject into. - :param cross_attention_control_args: Arugments passeed to the CrossAttentionControl implementations - :return: None - """ - - # adapted from init_attention_edit - device = context.arguments.edited_conditioning.device - - # urgh. should this be hardcoded? - max_length = 77 - # mask=1 means use base prompt attention, mask=0 means use edited prompt attention - mask = torch.zeros(max_length) - indices_target = torch.arange(max_length, dtype=torch.long) - indices = torch.zeros(max_length, dtype=torch.long) - for name, a0, a1, b0, b1 in context.arguments.edit_opcodes: - if b0 < max_length: - if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0): - # these tokens have not been edited - indices[b0:b1] = indices_target[a0:a1] - mask[b0:b1] = 1 - - context.register_cross_attention_modules(model) - context.cross_attention_mask = mask.to(device) - context.cross_attention_index_map = indices.to(device) - cls.inject_attention_function(model, context) + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == "CrossAttention": + module.identifier = name + module.set_attention_slice_wrangler(attention_slice_wrangler) + module.set_slicing_strategy_getter(lambda module, module_identifier=name: \ + context.get_slicing_strategy(module_identifier)) - class CrossAttentionType(enum.Enum): - SELF = 1 - TOKENS = 2 - - @classmethod - def get_attention_modules(cls, model, which: CrossAttentionType): - which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2" - return [(name,module) for name, module in model.named_modules() if - type(module).__name__ == "CrossAttention" and which_attn in name] - - - @classmethod - def inject_attention_function(cls, unet, context: 'CrossAttentionControl.Context'): - # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 - - def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim, offset, slice_size): - - #memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement() - - attention_slice = suggested_attention_slice - - if context.get_should_save_maps(module.identifier): - #print(module.identifier, "saving suggested_attention_slice of shape", - # suggested_attention_slice.shape, "dim", dim, "offset", offset) - slice_to_save = attention_slice.to('cpu') if dim is not None else attention_slice - context.save_slice(module.identifier, slice_to_save, dim=dim, offset=offset, slice_size=slice_size) - elif context.get_should_apply_saved_maps(module.identifier): - #print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset) - saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size) - - # slice may have 
been offloaded to CPU - saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device) - - if context.is_tokens_cross_attention(module.identifier): - index_map = context.cross_attention_index_map - remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map) - this_attention_slice = suggested_attention_slice - - mask = context.cross_attention_mask - saved_mask = mask - this_mask = 1 - mask - attention_slice = remapped_saved_attention_slice * saved_mask + \ - this_attention_slice * this_mask - else: - # just use everything - attention_slice = saved_attention_slice - - return attention_slice - - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == "CrossAttention": - module.identifier = name - module.set_attention_slice_wrangler(attention_slice_wrangler) - module.set_slicing_strategy_getter(lambda module, module_identifier=name: \ - context.get_slicing_strategy(module_identifier)) - - @classmethod - def remove_attention_function(cls, unet): - # clear wrangler callback - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == "CrossAttention": - module.set_attention_slice_wrangler(None) - module.set_slicing_strategy_getter(None) +def remove_attention_function(unet): + # clear wrangler callback + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == "CrossAttention": + module.set_attention_slice_wrangler(None) + module.set_slicing_strategy_getter(None) diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index 0a18eb25c8..d748c9a673 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -4,7 +4,8 @@ from typing import Callable, Optional, Union import torch -from ldm.models.diffusion.cross_attention_control import CrossAttentionControl +from ldm.models.diffusion.cross_attention_control import Arguments, \ + remove_cross_attention_control, setup_cross_attention_control, Context from ldm.modules.attention import get_mem_free_total @@ -20,7 +21,7 @@ class InvokeAIDiffuserComponent: class ExtraConditioningInfo: - def __init__(self, cross_attention_control_args: Optional[CrossAttentionControl.Arguments]): + def __init__(self, cross_attention_control_args: Optional[Arguments]): self.cross_attention_control_args = cross_attention_control_args @property @@ -40,16 +41,16 @@ class InvokeAIDiffuserComponent: def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int): self.conditioning = conditioning - self.cross_attention_control_context = CrossAttentionControl.Context( + self.cross_attention_control_context = Context( arguments=self.conditioning.cross_attention_control_args, step_count=step_count ) - CrossAttentionControl.setup_cross_attention_control(self.model, self.cross_attention_control_context) + setup_cross_attention_control(self.model, self.cross_attention_control_context) def remove_cross_attention_control(self): self.conditioning = None self.cross_attention_control_context = None - CrossAttentionControl.remove_cross_attention_control(self.model) + remove_cross_attention_control(self.model) @@ -71,7 +72,7 @@ class InvokeAIDiffuserComponent: cross_attention_control_types_to_do = [] - context: CrossAttentionControl.Context = self.cross_attention_control_context + context: Context = self.cross_attention_control_context if 
self.cross_attention_control_context is not None: percent_through = self.estimate_percent_through(step_index, sigma) cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through) @@ -133,7 +134,7 @@ class InvokeAIDiffuserComponent: # representing batched uncond + cond, but then when it comes to applying the saved attention, the # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.) # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well. - context:CrossAttentionControl.Context = self.cross_attention_control_context + context:Context = self.cross_attention_control_context try: unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
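
Note for reviewers: a minimal sketch of what call sites look like once the outer class is gone. The imports and the Arguments/Context/setup/remove names come straight from the refactored module and mirror the shared_invokeai_diffusion.py hunk above; the two wrapper functions are hypothetical, added here only to show the shape of the new module-level API, and are not part of this patch.

    # Illustrative only -- not part of the patch.
    from ldm.models.diffusion.cross_attention_control import (
        Arguments,
        Context,
        remove_cross_attention_control,
        setup_cross_attention_control,
    )


    def enable_cross_attention_control(model, args: Arguments, step_count: int) -> Context:
        """Hypothetical wrapper: wire cross-attention control into a unet model.

        Mirrors InvokeAIDiffuserComponent.setup_cross_attention_control in the hunk above:
        build a Context from the parsed Arguments, then inject the attention functions.
        """
        context = Context(arguments=args, step_count=step_count)
        setup_cross_attention_control(model, context)
        return context


    def disable_cross_attention_control(model) -> None:
        """Hypothetical wrapper: remove the injected attention functions again."""
        remove_cross_attention_control(model)

Before this patch the same calls went through the class used purely as a namespace (CrossAttentionControl.Context, CrossAttentionControl.setup_cross_attention_control, ...); after it, the module itself is the namespace, which is the point of the commit message.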