mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(cross_attention_control): type hints and other lint 🚮
This commit is contained in:
parent
810fad9e06
commit
47e6f94111
@ -35,6 +35,9 @@ class CrossAttentionType(enum.Enum):
|
||||
|
||||
class Context:
|
||||
|
||||
cross_attention_mask: Optional[torch.Tensor]
|
||||
cross_attention_index_map: Optional[torch.Tensor]
|
||||
|
||||
class Action(enum.Enum):
|
||||
NONE = 0
|
||||
SAVE = 1,
|
||||
@ -45,6 +48,10 @@ class Context:
|
||||
:param arguments: Arguments for the cross-attention control process
|
||||
:param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
|
||||
"""
|
||||
self.cross_attention_mask = None
|
||||
self.cross_attention_index_map = None
|
||||
self.self_cross_attention_action = Context.Action.NONE
|
||||
self.tokens_cross_attention_action = Context.Action.NONE
|
||||
self.arguments = arguments
|
||||
self.step_count = step_count
|
||||
|
||||
@ -56,11 +63,9 @@ class Context:
|
||||
self.clear_requests(cleanup=True)
|
||||
|
||||
def register_cross_attention_modules(self, model):
|
||||
for name,module in get_attention_modules(model,
|
||||
CrossAttentionType.SELF):
|
||||
for name,module in get_attention_modules(model, CrossAttentionType.SELF):
|
||||
self.self_cross_attention_module_identifiers.append(name)
|
||||
for name,module in get_attention_modules(model,
|
||||
CrossAttentionType.TOKENS):
|
||||
for name,module in get_attention_modules(model, CrossAttentionType.TOKENS):
|
||||
self.tokens_cross_attention_module_identifiers.append(name)
|
||||
|
||||
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
|
||||
@ -104,9 +109,9 @@ class Context:
|
||||
|
||||
opts = self.arguments.edit_options
|
||||
to_control = []
|
||||
if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
|
||||
if opts['s_start'] <= percent_through < opts['s_end']:
|
||||
to_control.append(CrossAttentionType.SELF)
|
||||
if opts['t_start'] <= percent_through and percent_through < opts['t_end']:
|
||||
if opts['t_start'] <= percent_through < opts['t_end']:
|
||||
to_control.append(CrossAttentionType.TOKENS)
|
||||
return to_control
|
||||
|
||||
@ -134,7 +139,7 @@ class Context:
|
||||
f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}")
|
||||
return saved_attention_dict['slices'][requested_offset]
|
||||
|
||||
if saved_attention_dict['dim'] == None:
|
||||
if saved_attention_dict['dim'] is None:
|
||||
whole_saved_attention = saved_attention_dict['slices'][0]
|
||||
if requested_dim == 0:
|
||||
return whole_saved_attention[requested_offset:requested_offset + slice_size]
|
||||
@ -143,7 +148,7 @@ class Context:
|
||||
|
||||
raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
|
||||
|
||||
def get_slicing_strategy(self, identifier: str) -> Optional[tuple[int, int]]:
|
||||
def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
|
||||
saved_attention = self.saved_cross_attention_maps.get(identifier, None)
|
||||
if saved_attention is None:
|
||||
return None, None
|
||||
|
Loading…
Reference in New Issue
Block a user