refactor(cross_attention_control): type hints and other lint 🚮

This commit is contained in:
Kevin Turner 2022-11-12 11:25:39 -08:00
parent 810fad9e06
commit 47e6f94111

View File

@ -35,6 +35,9 @@ class CrossAttentionType(enum.Enum):
class Context: class Context:
cross_attention_mask: Optional[torch.Tensor]
cross_attention_index_map: Optional[torch.Tensor]
class Action(enum.Enum): class Action(enum.Enum):
NONE = 0 NONE = 0
SAVE = 1, SAVE = 1,
@ -45,6 +48,10 @@ class Context:
:param arguments: Arguments for the cross-attention control process :param arguments: Arguments for the cross-attention control process
:param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run) :param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
""" """
self.cross_attention_mask = None
self.cross_attention_index_map = None
self.self_cross_attention_action = Context.Action.NONE
self.tokens_cross_attention_action = Context.Action.NONE
self.arguments = arguments self.arguments = arguments
self.step_count = step_count self.step_count = step_count
@ -56,11 +63,9 @@ class Context:
self.clear_requests(cleanup=True) self.clear_requests(cleanup=True)
def register_cross_attention_modules(self, model): def register_cross_attention_modules(self, model):
for name,module in get_attention_modules(model, for name,module in get_attention_modules(model, CrossAttentionType.SELF):
CrossAttentionType.SELF):
self.self_cross_attention_module_identifiers.append(name) self.self_cross_attention_module_identifiers.append(name)
for name,module in get_attention_modules(model, for name,module in get_attention_modules(model, CrossAttentionType.TOKENS):
CrossAttentionType.TOKENS):
self.tokens_cross_attention_module_identifiers.append(name) self.tokens_cross_attention_module_identifiers.append(name)
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType): def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
@ -104,9 +109,9 @@ class Context:
opts = self.arguments.edit_options opts = self.arguments.edit_options
to_control = [] to_control = []
if opts['s_start'] <= percent_through and percent_through < opts['s_end']: if opts['s_start'] <= percent_through < opts['s_end']:
to_control.append(CrossAttentionType.SELF) to_control.append(CrossAttentionType.SELF)
if opts['t_start'] <= percent_through and percent_through < opts['t_end']: if opts['t_start'] <= percent_through < opts['t_end']:
to_control.append(CrossAttentionType.TOKENS) to_control.append(CrossAttentionType.TOKENS)
return to_control return to_control
@ -134,7 +139,7 @@ class Context:
f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}") f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}")
return saved_attention_dict['slices'][requested_offset] return saved_attention_dict['slices'][requested_offset]
if saved_attention_dict['dim'] == None: if saved_attention_dict['dim'] is None:
whole_saved_attention = saved_attention_dict['slices'][0] whole_saved_attention = saved_attention_dict['slices'][0]
if requested_dim == 0: if requested_dim == 0:
return whole_saved_attention[requested_offset:requested_offset + slice_size] return whole_saved_attention[requested_offset:requested_offset + slice_size]
@ -143,7 +148,7 @@ class Context:
raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}") raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
def get_slicing_strategy(self, identifier: str) -> Optional[tuple[int, int]]: def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
saved_attention = self.saved_cross_attention_maps.get(identifier, None) saved_attention = self.saved_cross_attention_maps.get(identifier, None)
if saved_attention is None: if saved_attention is None:
return None, None return None, None