refactor(cross_attention_control): re-order enum class for easier reference

Kevin Turner 2022-11-12 11:05:33 -08:00
parent 853c6af623
commit 810fad9e06
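With CrossAttentionType defined near the bottom of the module, the annotations on Context's methods had to use string forward references ('CrossAttentionType'). Moving the enum above Context lets those annotations name the class directly. A minimal sketch of the difference, using names from this diff:

# before: the enum is defined later in the module, so annotations
# need a string forward reference
def request_save_attention_maps(self, cross_attention_type: 'CrossAttentionType'): ...

# after: the enum is defined before Context and can be referenced directly
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType): ...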


@@ -7,9 +7,6 @@ import torch
 # https://github.com/bloc97/CrossAttentionControl
 class Arguments:
     def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
         """
@@ -31,6 +28,11 @@ class Arguments:
         self.edit_options = non_none_edit_options[0]
+class CrossAttentionType(enum.Enum):
+    SELF = 1
+    TOKENS = 2
 class Context:
     class Action(enum.Enum):
@@ -61,13 +63,13 @@ class Context:
                                                   CrossAttentionType.TOKENS):
             self.tokens_cross_attention_module_identifiers.append(name)
-    def request_save_attention_maps(self, cross_attention_type: 'CrossAttentionType'):
+    def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
         if cross_attention_type == CrossAttentionType.SELF:
             self.self_cross_attention_action = Context.Action.SAVE
         else:
             self.tokens_cross_attention_action = Context.Action.SAVE
-    def request_apply_saved_attention_maps(self, cross_attention_type: 'CrossAttentionType'):
+    def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
         if cross_attention_type == CrossAttentionType.SELF:
             self.self_cross_attention_action = Context.Action.APPLY
         else:
@@ -91,7 +93,7 @@ class Context:
         return False
     def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\
-            -> list['CrossAttentionType']:
+            -> list[CrossAttentionType]:
         """
         Should cross-attention control be applied on the given step?
         :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
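A hedged sketch of how a sampler loop might consult this method on each denoising step; the loop and its variables (num_steps, context) are illustrative, not part of this file:

# hypothetical caller: once per step, ask which attention types should
# have their saved maps re-applied, given progress through sampling
for step_index in range(num_steps):              # num_steps is assumed
    percent_through = step_index / num_steps     # 0.0 = pure noise
    active = context.get_active_cross_attention_control_types_for_step(percent_through)
    for attn_type in active:
        context.request_apply_saved_attention_maps(attn_type)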
@@ -194,11 +196,6 @@ def setup_cross_attention_control(model, context: Context):
     inject_attention_function(model, context)
-class CrossAttentionType(enum.Enum):
-    SELF = 1
-    TOKENS = 2
 def get_attention_modules(model, which: CrossAttentionType):
     which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
     return [(name,module) for name, module in model.named_modules() if
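For reference, the "attn1"/"attn2" suffixes follow the usual LDM/Stable Diffusion transformer-block naming, where attn1 is a block's self-attention and attn2 is its cross-attention over the text conditioning. A small usage sketch, assuming a loaded UNet in `model`:

# collect (name, module) pairs for each attention flavor
self_modules = get_attention_modules(model, CrossAttentionType.SELF)     # *.attn1 modules
token_modules = get_attention_modules(model, CrossAttentionType.TOKENS)  # *.attn2 modules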
@@ -258,4 +255,3 @@ def remove_attention_function(unet):
         if module_name == "CrossAttention":
             module.set_attention_slice_wrangler(None)
             module.set_slicing_strategy_getter(None)
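Taken together, a rough sketch of the intended call pattern, assumed from the API surface shown in this diff rather than taken verbatim from the repo (the Context constructor arguments and step bookkeeping are assumptions):

context = Context(arguments, step_count)   # constructor arguments assumed
setup_cross_attention_control(model, context)

# first pass: denoise with the original conditioning, saving attention maps
for t in (CrossAttentionType.SELF, CrossAttentionType.TOKENS):
    context.request_save_attention_maps(t)
# ... run a denoising step with the original conditioning ...

# second pass: denoise with the edited conditioning, re-applying saved maps
for t in context.get_active_cross_attention_control_types_for_step(percent_through):
    context.request_apply_saved_attention_maps(t)
# ... run the same step with the edited conditioning ...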