refactor(cross_attention_control): type hints and other lint 🚮
parent 810fad9e06
commit 47e6f94111
@@ -35,6 +35,9 @@ class CrossAttentionType(enum.Enum):


 class Context:
+    cross_attention_mask: Optional[torch.Tensor]
+    cross_attention_index_map: Optional[torch.Tensor]
+
     class Action(enum.Enum):
         NONE = 0
         SAVE = 1,
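The two added annotations declare the attributes at class level, so a type checker knows they exist and may be None before any assignment happens. A minimal sketch of the pattern (the usage code around the class is illustrative, not from the commit):

    from typing import Optional
    import torch

    class Context:
        # Annotation only, no assignment: tells checkers like mypy that
        # the attribute exists and may hold a tensor or None at runtime.
        cross_attention_mask: Optional[torch.Tensor]

        def __init__(self):
            self.cross_attention_mask = None

    ctx = Context()
    if ctx.cross_attention_mask is not None:
        # Inside this branch the type is narrowed to torch.Tensor.
        print(ctx.cross_attention_mask.shape)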
@@ -45,6 +48,10 @@ class Context:
         :param arguments: Arguments for the cross-attention control process
         :param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
         """
+        self.cross_attention_mask = None
+        self.cross_attention_index_map = None
+        self.self_cross_attention_action = Context.Action.NONE
+        self.tokens_cross_attention_action = Context.Action.NONE
         self.arguments = arguments
         self.step_count = step_count

@@ -56,11 +63,9 @@ class Context:
             self.clear_requests(cleanup=True)

     def register_cross_attention_modules(self, model):
-        for name,module in get_attention_modules(model,
-                                                 CrossAttentionType.SELF):
+        for name,module in get_attention_modules(model, CrossAttentionType.SELF):
             self.self_cross_attention_module_identifiers.append(name)
-        for name,module in get_attention_modules(model,
-                                                 CrossAttentionType.TOKENS):
+        for name,module in get_attention_modules(model, CrossAttentionType.TOKENS):
             self.tokens_cross_attention_module_identifiers.append(name)

     def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
@@ -104,9 +109,9 @@ class Context:

         opts = self.arguments.edit_options
         to_control = []
-        if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
+        if opts['s_start'] <= percent_through < opts['s_end']:
             to_control.append(CrossAttentionType.SELF)
-        if opts['t_start'] <= percent_through and percent_through < opts['t_end']:
+        if opts['t_start'] <= percent_through < opts['t_end']:
             to_control.append(CrossAttentionType.TOKENS)
         return to_control

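Replacing the two-clause test with the chained form is behavior-preserving for plain numbers and evaluates the middle operand only once, while reading closer to the mathematical interval notation. A quick self-contained check:

    # Chained comparison s_start <= x < s_end is equivalent to the
    # two-clause form for ordinary floats.
    percent_through = 0.4
    s_start, s_end = 0.0, 0.5
    assert (s_start <= percent_through and percent_through < s_end) == \
           (s_start <= percent_through < s_end)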
@@ -134,7 +139,7 @@ class Context:
                     f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}")
             return saved_attention_dict['slices'][requested_offset]

-        if saved_attention_dict['dim'] == None:
+        if saved_attention_dict['dim'] is None:
             whole_saved_attention = saved_attention_dict['slices'][0]
             if requested_dim == 0:
                 return whole_saved_attention[requested_offset:requested_offset + slice_size]
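Switching from == None to is None satisfies the standard lint rule (flake8 E711) and is also semantically safer: == dispatches to a custom __eq__, while is compares object identity and cannot be overridden. A minimal illustration of the difference:

    class AlwaysEqual:
        # Pathological __eq__ that claims equality with everything,
        # including None.
        def __eq__(self, other):
            return True

    obj = AlwaysEqual()
    print(obj == None)   # True  -- __eq__ intercepts the comparison
    print(obj is None)   # False -- identity check is not fooled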
@@ -143,7 +148,7 @@ class Context:

         raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")

-    def get_slicing_strategy(self, identifier: str) -> Optional[tuple[int, int]]:
+    def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
         saved_attention = self.saved_cross_attention_maps.get(identifier, None)
         if saved_attention is None:
             return None, None
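The old annotation Optional[tuple[int, int]] promises either a whole 2-tuple of ints or None itself, but the fallback path returns None, None, which is a tuple containing Nones. The new tuple[Optional[int], Optional[int]] describes what callers actually receive. A sketch of the distinction (function names are hypothetical):

    from typing import Optional

    def old_style() -> Optional[tuple[int, int]]:
        return None, None   # mypy error: tuple[None, None] matches neither branch

    def new_style() -> tuple[Optional[int], Optional[int]]:
        return None, None   # OK: each element may independently be None

    dim, size = new_style()  # callers can always unpack two values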