mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(cross_attention_control): re-order enum class for easier reference
This commit is contained in:
parent
853c6af623
commit
810fad9e06
@ -7,9 +7,6 @@ import torch
|
||||
# https://github.com/bloc97/CrossAttentionControl
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class Arguments:
|
||||
def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
|
||||
"""
|
||||
@ -31,6 +28,11 @@ class Arguments:
|
||||
self.edit_options = non_none_edit_options[0]
|
||||
|
||||
|
||||
class CrossAttentionType(enum.Enum):
|
||||
SELF = 1
|
||||
TOKENS = 2
|
||||
|
||||
|
||||
class Context:
|
||||
|
||||
class Action(enum.Enum):
|
||||
@ -61,13 +63,13 @@ class Context:
|
||||
CrossAttentionType.TOKENS):
|
||||
self.tokens_cross_attention_module_identifiers.append(name)
|
||||
|
||||
def request_save_attention_maps(self, cross_attention_type: 'CrossAttentionType'):
|
||||
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
|
||||
if cross_attention_type == CrossAttentionType.SELF:
|
||||
self.self_cross_attention_action = Context.Action.SAVE
|
||||
else:
|
||||
self.tokens_cross_attention_action = Context.Action.SAVE
|
||||
|
||||
def request_apply_saved_attention_maps(self, cross_attention_type: 'CrossAttentionType'):
|
||||
def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
|
||||
if cross_attention_type == CrossAttentionType.SELF:
|
||||
self.self_cross_attention_action = Context.Action.APPLY
|
||||
else:
|
||||
@ -91,7 +93,7 @@ class Context:
|
||||
return False
|
||||
|
||||
def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\
|
||||
-> list['CrossAttentionType']:
|
||||
-> list[CrossAttentionType]:
|
||||
"""
|
||||
Should cross-attention control be applied on the given step?
|
||||
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
|
||||
@ -194,11 +196,6 @@ def setup_cross_attention_control(model, context: Context):
|
||||
inject_attention_function(model, context)
|
||||
|
||||
|
||||
class CrossAttentionType(enum.Enum):
|
||||
SELF = 1
|
||||
TOKENS = 2
|
||||
|
||||
|
||||
def get_attention_modules(model, which: CrossAttentionType):
|
||||
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
|
||||
return [(name,module) for name, module in model.named_modules() if
|
||||
@ -258,4 +255,3 @@ def remove_attention_function(unet):
|
||||
if module_name == "CrossAttention":
|
||||
module.set_attention_slice_wrangler(None)
|
||||
module.set_slicing_strategy_getter(None)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user