fix crash (be a little less aggressive clearing out the attention slice)

This commit is contained in:
damian0815 2022-11-01 20:08:52 +01:00 committed by Lincoln Stein
parent 214d276379
commit 349cc25433
2 changed files with 4 additions and 3 deletions

View File

@ -110,12 +110,13 @@ class CrossAttentionControl:
type(module).__name__ == "CrossAttention" and which_attn in name] type(module).__name__ == "CrossAttention" and which_attn in name]
@classmethod @classmethod
def clear_requests(cls, model): def clear_requests(cls, model, clear_attn_slice=True):
self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF) self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF)
tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS) tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS)
for m in self_attention_modules+tokens_attention_modules: for m in self_attention_modules+tokens_attention_modules:
m.save_last_attn_slice = False m.save_last_attn_slice = False
m.use_last_attn_slice = False m.use_last_attn_slice = False
if clear_attn_slice:
m.last_attn_slice = None m.last_attn_slice = None
@classmethod @classmethod

View File

@ -142,7 +142,7 @@ class InvokeAIDiffuserComponent:
for type in cross_attention_control_types_to_do: for type in cross_attention_control_types_to_do:
CrossAttentionControl.request_save_attention_maps(self.model, type) CrossAttentionControl.request_save_attention_maps(self.model, type)
_ = self.model_forward_callback(x, sigma, conditioning) _ = self.model_forward_callback(x, sigma, conditioning)
CrossAttentionControl.clear_requests(self.model) CrossAttentionControl.clear_requests(self.model, clear_attn_slice=False)
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
for type in cross_attention_control_types_to_do: for type in cross_attention_control_types_to_do: