refactor(cross_attention_control): remove outer CrossAttentionControl class (#1459)

I was working on attention control in #1384 and started making a few
changes to improve the typing and make the code easier to work with. The
whitespace changes ended up touching so many lines that it seemed worth
separating these refactoring operations into their own PR, so they don't
get mixed up with other functional changes.

It would be helpful to merge this into `development` before continuing
work on attention control in #1384.

The GitHub diff isn't easy to read here, since the changes touch
whitespace on so many lines. It may be easier to review the individual
commits one at a time, and/or to toggle the "hide whitespace
differences" option in the diff view.
Commit 9bf6013fdd by Damian Stewart, 2022-11-13 14:20:18 +01:00, committed via GitHub.
3 changed files with 248 additions and 247 deletions

File: ldm/invoke/conditioning.py

@@ -14,7 +14,7 @@ import torch
from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization
from ..models.diffusion.cross_attention_control import CrossAttentionControl
from ..models.diffusion import cross_attention_control
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
@@ -50,7 +50,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
print(f">> Parsed prompt to {parsed_prompt}")
conditioning = None
cac_args:CrossAttentionControl.Arguments = None
cac_args:cross_attention_control.Arguments = None
if type(parsed_prompt) is Blend:
blend: Blend = parsed_prompt
@@ -121,7 +121,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
conditioning = original_embeddings
edited_conditioning = edited_embeddings
#print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
cac_args = CrossAttentionControl.Arguments(
cac_args = cross_attention_control.Arguments(
edited_conditioning = edited_conditioning,
edit_opcodes = edit_opcodes,
edit_options = edit_options

File: ldm/models/diffusion/cross_attention_control.py

@@ -7,9 +7,6 @@ import torch
# https://github.com/bloc97/CrossAttentionControl
class CrossAttentionControl:
class Arguments:
def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
"""
@@ -31,18 +28,30 @@ class CrossAttentionControl:
self.edit_options = non_none_edit_options[0]
class CrossAttentionType(enum.Enum):
SELF = 1
TOKENS = 2
class Context:
cross_attention_mask: Optional[torch.Tensor]
cross_attention_index_map: Optional[torch.Tensor]
class Action(enum.Enum):
NONE = 0
SAVE = 1,
APPLY = 2
def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int):
def __init__(self, arguments: Arguments, step_count: int):
"""
:param arguments: Arguments for the cross-attention control process
:param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
"""
self.cross_attention_mask = None
self.cross_attention_index_map = None
self.self_cross_attention_action = Context.Action.NONE
self.tokens_cross_attention_action = Context.Action.NONE
self.arguments = arguments
self.step_count = step_count
@@ -54,58 +63,56 @@ class CrossAttentionControl:
self.clear_requests(cleanup=True)
def register_cross_attention_modules(self, model):
for name,module in CrossAttentionControl.get_attention_modules(model,
CrossAttentionControl.CrossAttentionType.SELF):
for name,module in get_attention_modules(model, CrossAttentionType.SELF):
self.self_cross_attention_module_identifiers.append(name)
for name,module in CrossAttentionControl.get_attention_modules(model,
CrossAttentionControl.CrossAttentionType.TOKENS):
for name,module in get_attention_modules(model, CrossAttentionType.TOKENS):
self.tokens_cross_attention_module_identifiers.append(name)
def request_save_attention_maps(self, cross_attention_type: 'CrossAttentionControl.CrossAttentionType'):
if cross_attention_type == CrossAttentionControl.CrossAttentionType.SELF:
self.self_cross_attention_action = CrossAttentionControl.Context.Action.SAVE
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
if cross_attention_type == CrossAttentionType.SELF:
self.self_cross_attention_action = Context.Action.SAVE
else:
self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.SAVE
self.tokens_cross_attention_action = Context.Action.SAVE
def request_apply_saved_attention_maps(self, cross_attention_type: 'CrossAttentionControl.CrossAttentionType'):
if cross_attention_type == CrossAttentionControl.CrossAttentionType.SELF:
self.self_cross_attention_action = CrossAttentionControl.Context.Action.APPLY
def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
if cross_attention_type == CrossAttentionType.SELF:
self.self_cross_attention_action = Context.Action.APPLY
else:
self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.APPLY
self.tokens_cross_attention_action = Context.Action.APPLY
def is_tokens_cross_attention(self, module_identifier) -> bool:
return module_identifier in self.tokens_cross_attention_module_identifiers
def get_should_save_maps(self, module_identifier: str) -> bool:
if module_identifier in self.self_cross_attention_module_identifiers:
return self.self_cross_attention_action == CrossAttentionControl.Context.Action.SAVE
return self.self_cross_attention_action == Context.Action.SAVE
elif module_identifier in self.tokens_cross_attention_module_identifiers:
return self.tokens_cross_attention_action == CrossAttentionControl.Context.Action.SAVE
return self.tokens_cross_attention_action == Context.Action.SAVE
return False
def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
if module_identifier in self.self_cross_attention_module_identifiers:
return self.self_cross_attention_action == CrossAttentionControl.Context.Action.APPLY
return self.self_cross_attention_action == Context.Action.APPLY
elif module_identifier in self.tokens_cross_attention_module_identifiers:
return self.tokens_cross_attention_action == CrossAttentionControl.Context.Action.APPLY
return self.tokens_cross_attention_action == Context.Action.APPLY
return False
def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\
-> list['CrossAttentionControl.CrossAttentionType']:
-> list[CrossAttentionType]:
"""
Should cross-attention control be applied on the given step?
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
"""
if percent_through is None:
return [CrossAttentionControl.CrossAttentionType.SELF, CrossAttentionControl.CrossAttentionType.TOKENS]
return [CrossAttentionType.SELF, CrossAttentionType.TOKENS]
opts = self.arguments.edit_options
to_control = []
if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
to_control.append(CrossAttentionControl.CrossAttentionType.SELF)
if opts['t_start'] <= percent_through and percent_through < opts['t_end']:
to_control.append(CrossAttentionControl.CrossAttentionType.TOKENS)
if opts['s_start'] <= percent_through < opts['s_end']:
to_control.append(CrossAttentionType.SELF)
if opts['t_start'] <= percent_through < opts['t_end']:
to_control.append(CrossAttentionType.TOKENS)
return to_control
def save_slice(self, identifier: str, slice: torch.Tensor, dim: Optional[int], offset: int,
@@ -132,7 +139,7 @@ class CrossAttentionControl:
f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}")
return saved_attention_dict['slices'][requested_offset]
if saved_attention_dict['dim'] == None:
if saved_attention_dict['dim'] is None:
whole_saved_attention = saved_attention_dict['slices'][0]
if requested_dim == 0:
return whole_saved_attention[requested_offset:requested_offset + slice_size]
@@ -141,15 +148,15 @@ class CrossAttentionControl:
raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
def get_slicing_strategy(self, identifier: str) -> Optional[tuple[int, int]]:
def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
saved_attention = self.saved_cross_attention_maps.get(identifier, None)
if saved_attention is None:
return None, None
return saved_attention['dim'], saved_attention['slice_size']
def clear_requests(self, cleanup=True):
self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.NONE
self.self_cross_attention_action = CrossAttentionControl.Context.Action.NONE
self.tokens_cross_attention_action = Context.Action.NONE
self.self_cross_attention_action = Context.Action.NONE
if cleanup:
self.saved_cross_attention_maps = {}
@@ -158,12 +165,12 @@ class CrossAttentionControl:
for offset, slice in map_dict['slices'].items():
map_dict[offset] = slice.to('cpu')
@classmethod
def remove_cross_attention_control(cls, model):
cls.remove_attention_function(model)
@classmethod
def setup_cross_attention_control(cls, model, context: Context):
def remove_cross_attention_control(model):
remove_attention_function(model)
def setup_cross_attention_control(model, context: Context):
"""
Inject attention parameters and functions into the passed in model to enable cross attention editing.
@@ -191,22 +198,16 @@ class CrossAttentionControl:
context.register_cross_attention_modules(model)
context.cross_attention_mask = mask.to(device)
context.cross_attention_index_map = indices.to(device)
cls.inject_attention_function(model, context)
inject_attention_function(model, context)
class CrossAttentionType(enum.Enum):
SELF = 1
TOKENS = 2
@classmethod
def get_attention_modules(cls, model, which: CrossAttentionType):
which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2"
def get_attention_modules(model, which: CrossAttentionType):
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
return [(name,module) for name, module in model.named_modules() if
type(module).__name__ == "CrossAttention" and which_attn in name]
@classmethod
def inject_attention_function(cls, unet, context: 'CrossAttentionControl.Context'):
def inject_attention_function(unet, context: Context):
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim, offset, slice_size):
@@ -251,12 +252,11 @@ class CrossAttentionControl:
module.set_slicing_strategy_getter(lambda module, module_identifier=name: \
context.get_slicing_strategy(module_identifier))
@classmethod
def remove_attention_function(cls, unet):
def remove_attention_function(unet):
# clear wrangler callback
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == "CrossAttention":
module.set_attention_slice_wrangler(None)
module.set_slicing_strategy_getter(None)
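
For orientation, here is the rough shape of the module after this change, reconstructed only from the hunks shown above; bodies and a few helper methods are elided, and anything the diff doesn't show is omitted rather than guessed:

```python
# Rough post-refactor shape of ldm/models/diffusion/cross_attention_control.py
# (stub sketch only; reconstructed from the hunks above).
import enum
from typing import Optional

import torch


class Arguments:
    def __init__(self, edited_conditioning: torch.Tensor,
                 edit_opcodes: list[tuple], edit_options: dict): ...


class CrossAttentionType(enum.Enum):
    SELF = 1
    TOKENS = 2


class Context:
    # also nests an Action enum (NONE / SAVE / APPLY) used to track per-module requests
    def __init__(self, arguments: Arguments, step_count: int): ...
    def register_cross_attention_modules(self, model): ...
    def request_save_attention_maps(self, cross_attention_type: CrossAttentionType): ...
    def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType): ...
    def get_active_cross_attention_control_types_for_step(
            self, percent_through: float = None) -> list[CrossAttentionType]: ...
    def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]: ...
    def clear_requests(self, cleanup=True): ...


# former @classmethods, now plain module-level functions
def setup_cross_attention_control(model, context: Context): ...
def remove_cross_attention_control(model): ...
def get_attention_modules(model, which: CrossAttentionType): ...
def inject_attention_function(unet, context: Context): ...
def remove_attention_function(unet): ...
```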

File: ldm/models/diffusion/shared_invokeai_diffusion.py

@@ -4,7 +4,8 @@ from typing import Callable, Optional, Union
import torch
from ldm.models.diffusion.cross_attention_control import CrossAttentionControl
from ldm.models.diffusion.cross_attention_control import Arguments, \
remove_cross_attention_control, setup_cross_attention_control, Context
from ldm.modules.attention import get_mem_free_total
@@ -20,7 +21,7 @@ class InvokeAIDiffuserComponent:
class ExtraConditioningInfo:
def __init__(self, cross_attention_control_args: Optional[CrossAttentionControl.Arguments]):
def __init__(self, cross_attention_control_args: Optional[Arguments]):
self.cross_attention_control_args = cross_attention_control_args
@property
@@ -40,16 +41,16 @@ class InvokeAIDiffuserComponent:
def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
self.conditioning = conditioning
self.cross_attention_control_context = CrossAttentionControl.Context(
self.cross_attention_control_context = Context(
arguments=self.conditioning.cross_attention_control_args,
step_count=step_count
)
CrossAttentionControl.setup_cross_attention_control(self.model, self.cross_attention_control_context)
setup_cross_attention_control(self.model, self.cross_attention_control_context)
def remove_cross_attention_control(self):
self.conditioning = None
self.cross_attention_control_context = None
CrossAttentionControl.remove_cross_attention_control(self.model)
remove_cross_attention_control(self.model)
@@ -71,7 +72,7 @@ class InvokeAIDiffuserComponent:
cross_attention_control_types_to_do = []
context: CrossAttentionControl.Context = self.cross_attention_control_context
context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None:
percent_through = self.estimate_percent_through(step_index, sigma)
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
@@ -133,7 +134,7 @@ class InvokeAIDiffuserComponent:
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
context:CrossAttentionControl.Context = self.cross_attention_control_context
context:Context = self.cross_attention_control_context
try:
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
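
Finally, a hedged sketch of how a caller wires the refactored names through `InvokeAIDiffuserComponent`. The sampler loop and the component's constructor are not part of this diff, so the helper below and its arguments are illustrative assumptions, not code from this PR:

```python
# Hypothetical wiring, based only on the methods visible in the diff above.
# `diffuser` is assumed to be an already-constructed InvokeAIDiffuserComponent and
# `cac_args` an Arguments instance built as in ldm/invoke/conditioning.py.
from ldm.models.diffusion import cross_attention_control
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent


def run_with_cross_attention_control(diffuser: InvokeAIDiffuserComponent,
                                     cac_args: cross_attention_control.Arguments,
                                     step_count: int):
    extra_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(
        cross_attention_control_args=cac_args)
    diffuser.setup_cross_attention_control(extra_info, step_count=step_count)
    try:
        ...  # run the sampler here; the injected attention wrangler consults the Context each step
    finally:
        diffuser.remove_cross_attention_control()
```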