diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py
index 5a60a9b39f..a4362e0770 100644
--- a/ldm/models/diffusion/cross_attention_control.py
+++ b/ldm/models/diffusion/cross_attention_control.py
@@ -35,6 +35,9 @@ class CrossAttentionType(enum.Enum):
 
 class Context:
 
+    cross_attention_mask: Optional[torch.Tensor]
+    cross_attention_index_map: Optional[torch.Tensor]
+
     class Action(enum.Enum):
         NONE = 0
         SAVE = 1,
@@ -45,6 +48,10 @@ class Context:
         :param arguments: Arguments for the cross-attention control process
         :param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
         """
+        self.cross_attention_mask = None
+        self.cross_attention_index_map = None
+        self.self_cross_attention_action = Context.Action.NONE
+        self.tokens_cross_attention_action = Context.Action.NONE
         self.arguments = arguments
         self.step_count = step_count
 
@@ -56,11 +63,9 @@ class Context:
         self.clear_requests(cleanup=True)
 
     def register_cross_attention_modules(self, model):
-        for name,module in get_attention_modules(model,
-                                                 CrossAttentionType.SELF):
+        for name,module in get_attention_modules(model, CrossAttentionType.SELF):
             self.self_cross_attention_module_identifiers.append(name)
-        for name,module in get_attention_modules(model,
-                                                 CrossAttentionType.TOKENS):
+        for name,module in get_attention_modules(model, CrossAttentionType.TOKENS):
             self.tokens_cross_attention_module_identifiers.append(name)
 
     def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
@@ -104,9 +109,9 @@ class Context:
 
         opts = self.arguments.edit_options
         to_control = []
-        if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
+        if opts['s_start'] <= percent_through < opts['s_end']:
             to_control.append(CrossAttentionType.SELF)
-        if opts['t_start'] <= percent_through and percent_through < opts['t_end']:
+        if opts['t_start'] <= percent_through < opts['t_end']:
             to_control.append(CrossAttentionType.TOKENS)
         return to_control
 
@@ -134,7 +139,7 @@ class Context:
                     f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}")
             return saved_attention_dict['slices'][requested_offset]
 
-        if saved_attention_dict['dim'] == None:
+        if saved_attention_dict['dim'] is None:
             whole_saved_attention = saved_attention_dict['slices'][0]
             if requested_dim == 0:
                 return whole_saved_attention[requested_offset:requested_offset + slice_size]
@@ -143,7 +148,7 @@ class Context:
 
         raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
 
-    def get_slicing_strategy(self, identifier: str) -> Optional[tuple[int, int]]:
+    def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
         saved_attention = self.saved_cross_attention_maps.get(identifier, None)
         if saved_attention is None:
             return None, None
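
Note on the first hunk: the two class-level annotations added to `Context` are bare annotations. Since `Context` is not a dataclass, they create no class attributes at runtime; they only declare the expected instance-attribute types for type checkers, while the actual assignments happen in the `__init__` additions in the second hunk. A minimal sketch, using a hypothetical `AnnotatedOnly` class:

```python
from typing import Optional


class AnnotatedOnly:
    value: Optional[int]  # bare annotation: recorded, but no attribute is created


print(AnnotatedOnly.__annotations__)    # {'value': typing.Optional[int]}
print(hasattr(AnnotatedOnly, 'value'))  # False -- instances must assign it themselves
```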
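
The middle hunks are behavior-preserving idiom fixes. `a <= x and x < b` becomes the chained comparison `a <= x < b`, which is equivalent but evaluates `x` only once, and `== None` becomes `is None`, which tests identity instead of dispatching to a possibly overridden `__eq__`. A standalone illustration (the `Weird` class is hypothetical, not project code):

```python
class Weird:
    """Overrides __eq__ to claim equality with everything, None included."""
    def __eq__(self, other):
        return True


percent_through = 0.4
opts = {'s_start': 0.0, 's_end': 0.5}
# The chained form reads like the mathematical notation and is equivalent here.
assert (opts['s_start'] <= percent_through < opts['s_end']) == \
       (opts['s_start'] <= percent_through and percent_through < opts['s_end'])

w = Weird()
assert (w == None) is True   # __eq__ lies about being None
assert (w is None) is False  # identity comparison cannot be fooled
```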
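
The last hunk corrects the return annotation of `get_slicing_strategy`: the body ends in `return None, None`, i.e. it always returns a 2-tuple whose elements may each be `None`, so `tuple[Optional[int], Optional[int]]` is the accurate type; the old `Optional[tuple[int, int]]` would instead mean "a complete tuple or a bare `None`". A self-contained sketch under that reading (the `SlicingDemo` class and its dict layout are illustrative, not the project's actual `Context`):

```python
from typing import Optional


class SlicingDemo:
    def __init__(self) -> None:
        # identifier -> {'dim': Optional[int], 'slice_size': Optional[int]}
        self.saved_cross_attention_maps: dict[str, dict] = {}

    def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
        saved_attention = self.saved_cross_attention_maps.get(identifier, None)
        if saved_attention is None:
            return None, None  # still a 2-tuple, so callers can always unpack
        return saved_attention['dim'], saved_attention['slice_size']


demo = SlicingDemo()
dim, slice_size = demo.get_slicing_strategy('missing')  # unpacking never raises
assert (dim, slice_size) == (None, None)
```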