refactor(cross_attention_control): remove outer CrossAttentionControl class

Python has modules. We don't need to use a class to provide a namespace.
This commit is contained in:
Kevin Turner 2022-11-12 11:01:10 -08:00
parent 1b6bbfb4db
commit 853c6af623
3 changed files with 246 additions and 246 deletions

View File

@ -14,7 +14,7 @@ import torch
from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \ from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization
from ..models.diffusion.cross_attention_control import CrossAttentionControl from ..models.diffusion import cross_attention_control
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
@ -50,7 +50,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
print(f">> Parsed prompt to {parsed_prompt}") print(f">> Parsed prompt to {parsed_prompt}")
conditioning = None conditioning = None
cac_args:CrossAttentionControl.Arguments = None cac_args:cross_attention_control.Arguments = None
if type(parsed_prompt) is Blend: if type(parsed_prompt) is Blend:
blend: Blend = parsed_prompt blend: Blend = parsed_prompt
@ -121,7 +121,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
conditioning = original_embeddings conditioning = original_embeddings
edited_conditioning = edited_embeddings edited_conditioning = edited_embeddings
#print('>> got edit_opcodes', edit_opcodes, 'options', edit_options) #print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
cac_args = CrossAttentionControl.Arguments( cac_args = cross_attention_control.Arguments(
edited_conditioning = edited_conditioning, edited_conditioning = edited_conditioning,
edit_opcodes = edit_opcodes, edit_opcodes = edit_opcodes,
edit_options = edit_options edit_options = edit_options

View File

@ -8,7 +8,7 @@ import torch
class CrossAttentionControl:
class Arguments: class Arguments:
def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict): def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
@ -38,7 +38,7 @@ class CrossAttentionControl:
SAVE = 1, SAVE = 1,
APPLY = 2 APPLY = 2
def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int): def __init__(self, arguments: Arguments, step_count: int):
""" """
:param arguments: Arguments for the cross-attention control process :param arguments: Arguments for the cross-attention control process
:param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run) :param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
@ -54,58 +54,58 @@ class CrossAttentionControl:
self.clear_requests(cleanup=True) self.clear_requests(cleanup=True)
def register_cross_attention_modules(self, model): def register_cross_attention_modules(self, model):
for name,module in CrossAttentionControl.get_attention_modules(model, for name,module in get_attention_modules(model,
CrossAttentionControl.CrossAttentionType.SELF): CrossAttentionType.SELF):
self.self_cross_attention_module_identifiers.append(name) self.self_cross_attention_module_identifiers.append(name)
for name,module in CrossAttentionControl.get_attention_modules(model, for name,module in get_attention_modules(model,
CrossAttentionControl.CrossAttentionType.TOKENS): CrossAttentionType.TOKENS):
self.tokens_cross_attention_module_identifiers.append(name) self.tokens_cross_attention_module_identifiers.append(name)
def request_save_attention_maps(self, cross_attention_type: 'CrossAttentionControl.CrossAttentionType'): def request_save_attention_maps(self, cross_attention_type: 'CrossAttentionType'):
if cross_attention_type == CrossAttentionControl.CrossAttentionType.SELF: if cross_attention_type == CrossAttentionType.SELF:
self.self_cross_attention_action = CrossAttentionControl.Context.Action.SAVE self.self_cross_attention_action = Context.Action.SAVE
else: else:
self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.SAVE self.tokens_cross_attention_action = Context.Action.SAVE
def request_apply_saved_attention_maps(self, cross_attention_type: 'CrossAttentionControl.CrossAttentionType'): def request_apply_saved_attention_maps(self, cross_attention_type: 'CrossAttentionType'):
if cross_attention_type == CrossAttentionControl.CrossAttentionType.SELF: if cross_attention_type == CrossAttentionType.SELF:
self.self_cross_attention_action = CrossAttentionControl.Context.Action.APPLY self.self_cross_attention_action = Context.Action.APPLY
else: else:
self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.APPLY self.tokens_cross_attention_action = Context.Action.APPLY
def is_tokens_cross_attention(self, module_identifier) -> bool: def is_tokens_cross_attention(self, module_identifier) -> bool:
return module_identifier in self.tokens_cross_attention_module_identifiers return module_identifier in self.tokens_cross_attention_module_identifiers
def get_should_save_maps(self, module_identifier: str) -> bool: def get_should_save_maps(self, module_identifier: str) -> bool:
if module_identifier in self.self_cross_attention_module_identifiers: if module_identifier in self.self_cross_attention_module_identifiers:
return self.self_cross_attention_action == CrossAttentionControl.Context.Action.SAVE return self.self_cross_attention_action == Context.Action.SAVE
elif module_identifier in self.tokens_cross_attention_module_identifiers: elif module_identifier in self.tokens_cross_attention_module_identifiers:
return self.tokens_cross_attention_action == CrossAttentionControl.Context.Action.SAVE return self.tokens_cross_attention_action == Context.Action.SAVE
return False return False
def get_should_apply_saved_maps(self, module_identifier: str) -> bool: def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
if module_identifier in self.self_cross_attention_module_identifiers: if module_identifier in self.self_cross_attention_module_identifiers:
return self.self_cross_attention_action == CrossAttentionControl.Context.Action.APPLY return self.self_cross_attention_action == Context.Action.APPLY
elif module_identifier in self.tokens_cross_attention_module_identifiers: elif module_identifier in self.tokens_cross_attention_module_identifiers:
return self.tokens_cross_attention_action == CrossAttentionControl.Context.Action.APPLY return self.tokens_cross_attention_action == Context.Action.APPLY
return False return False
def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\ def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\
-> list['CrossAttentionControl.CrossAttentionType']: -> list['CrossAttentionType']:
""" """
Should cross-attention control be applied on the given step? Should cross-attention control be applied on the given step?
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0. :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
:return: A list of attention types that cross-attention control should be performed for on the given step. May be []. :return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
""" """
if percent_through is None: if percent_through is None:
return [CrossAttentionControl.CrossAttentionType.SELF, CrossAttentionControl.CrossAttentionType.TOKENS] return [CrossAttentionType.SELF, CrossAttentionType.TOKENS]
opts = self.arguments.edit_options opts = self.arguments.edit_options
to_control = [] to_control = []
if opts['s_start'] <= percent_through and percent_through < opts['s_end']: if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
to_control.append(CrossAttentionControl.CrossAttentionType.SELF) to_control.append(CrossAttentionType.SELF)
if opts['t_start'] <= percent_through and percent_through < opts['t_end']: if opts['t_start'] <= percent_through and percent_through < opts['t_end']:
to_control.append(CrossAttentionControl.CrossAttentionType.TOKENS) to_control.append(CrossAttentionType.TOKENS)
return to_control return to_control
def save_slice(self, identifier: str, slice: torch.Tensor, dim: Optional[int], offset: int, def save_slice(self, identifier: str, slice: torch.Tensor, dim: Optional[int], offset: int,
@ -148,8 +148,8 @@ class CrossAttentionControl:
return saved_attention['dim'], saved_attention['slice_size'] return saved_attention['dim'], saved_attention['slice_size']
def clear_requests(self, cleanup=True): def clear_requests(self, cleanup=True):
self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.NONE self.tokens_cross_attention_action = Context.Action.NONE
self.self_cross_attention_action = CrossAttentionControl.Context.Action.NONE self.self_cross_attention_action = Context.Action.NONE
if cleanup: if cleanup:
self.saved_cross_attention_maps = {} self.saved_cross_attention_maps = {}
@ -158,12 +158,12 @@ class CrossAttentionControl:
for offset, slice in map_dict['slices'].items(): for offset, slice in map_dict['slices'].items():
map_dict[offset] = slice.to('cpu') map_dict[offset] = slice.to('cpu')
@classmethod
def remove_cross_attention_control(cls, model):
cls.remove_attention_function(model)
@classmethod def remove_cross_attention_control(model):
def setup_cross_attention_control(cls, model, context: Context): remove_attention_function(model)
def setup_cross_attention_control(model, context: Context):
""" """
Inject attention parameters and functions into the passed in model to enable cross attention editing. Inject attention parameters and functions into the passed in model to enable cross attention editing.
@ -191,22 +191,21 @@ class CrossAttentionControl:
context.register_cross_attention_modules(model) context.register_cross_attention_modules(model)
context.cross_attention_mask = mask.to(device) context.cross_attention_mask = mask.to(device)
context.cross_attention_index_map = indices.to(device) context.cross_attention_index_map = indices.to(device)
cls.inject_attention_function(model, context) inject_attention_function(model, context)
class CrossAttentionType(enum.Enum): class CrossAttentionType(enum.Enum):
SELF = 1 SELF = 1
TOKENS = 2 TOKENS = 2
@classmethod
def get_attention_modules(cls, model, which: CrossAttentionType): def get_attention_modules(model, which: CrossAttentionType):
which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2" which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
return [(name,module) for name, module in model.named_modules() if return [(name,module) for name, module in model.named_modules() if
type(module).__name__ == "CrossAttention" and which_attn in name] type(module).__name__ == "CrossAttention" and which_attn in name]
@classmethod def inject_attention_function(unet, context: Context):
def inject_attention_function(cls, unet, context: 'CrossAttentionControl.Context'):
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim, offset, slice_size): def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim, offset, slice_size):
@ -251,8 +250,8 @@ class CrossAttentionControl:
module.set_slicing_strategy_getter(lambda module, module_identifier=name: \ module.set_slicing_strategy_getter(lambda module, module_identifier=name: \
context.get_slicing_strategy(module_identifier)) context.get_slicing_strategy(module_identifier))
@classmethod
def remove_attention_function(cls, unet): def remove_attention_function(unet):
# clear wrangler callback # clear wrangler callback
for name, module in unet.named_modules(): for name, module in unet.named_modules():
module_name = type(module).__name__ module_name = type(module).__name__

View File

@ -4,7 +4,8 @@ from typing import Callable, Optional, Union
import torch import torch
from ldm.models.diffusion.cross_attention_control import CrossAttentionControl from ldm.models.diffusion.cross_attention_control import Arguments, \
remove_cross_attention_control, setup_cross_attention_control, Context
from ldm.modules.attention import get_mem_free_total from ldm.modules.attention import get_mem_free_total
@ -20,7 +21,7 @@ class InvokeAIDiffuserComponent:
class ExtraConditioningInfo: class ExtraConditioningInfo:
def __init__(self, cross_attention_control_args: Optional[CrossAttentionControl.Arguments]): def __init__(self, cross_attention_control_args: Optional[Arguments]):
self.cross_attention_control_args = cross_attention_control_args self.cross_attention_control_args = cross_attention_control_args
@property @property
@ -40,16 +41,16 @@ class InvokeAIDiffuserComponent:
def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int): def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
self.conditioning = conditioning self.conditioning = conditioning
self.cross_attention_control_context = CrossAttentionControl.Context( self.cross_attention_control_context = Context(
arguments=self.conditioning.cross_attention_control_args, arguments=self.conditioning.cross_attention_control_args,
step_count=step_count step_count=step_count
) )
CrossAttentionControl.setup_cross_attention_control(self.model, self.cross_attention_control_context) setup_cross_attention_control(self.model, self.cross_attention_control_context)
def remove_cross_attention_control(self): def remove_cross_attention_control(self):
self.conditioning = None self.conditioning = None
self.cross_attention_control_context = None self.cross_attention_control_context = None
CrossAttentionControl.remove_cross_attention_control(self.model) remove_cross_attention_control(self.model)
@ -71,7 +72,7 @@ class InvokeAIDiffuserComponent:
cross_attention_control_types_to_do = [] cross_attention_control_types_to_do = []
context: CrossAttentionControl.Context = self.cross_attention_control_context context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None: if self.cross_attention_control_context is not None:
percent_through = self.estimate_percent_through(step_index, sigma) percent_through = self.estimate_percent_through(step_index, sigma)
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through) cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
@ -133,7 +134,7 @@ class InvokeAIDiffuserComponent:
# representing batched uncond + cond, but then when it comes to applying the saved attention, the # representing batched uncond + cond, but then when it comes to applying the saved attention, the
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.) # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well. # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
context:CrossAttentionControl.Context = self.cross_attention_control_context context:Context = self.cross_attention_control_context
try: try:
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning) unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)