fix crash (be a little less aggressive clearing out the attention slice)

damian0815 2022-11-01 20:08:52 +01:00 committed by Lincoln Stein
parent 0b72a4a35e
commit de2686d323
2 changed files with 4 additions and 3 deletions


@@ -110,13 +110,14 @@ class CrossAttentionControl:
                           type(module).__name__ == "CrossAttention" and which_attn in name]

     @classmethod
-    def clear_requests(cls, model):
+    def clear_requests(cls, model, clear_attn_slice=True):
         self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF)
         tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS)
         for m in self_attention_modules+tokens_attention_modules:
             m.save_last_attn_slice = False
             m.use_last_attn_slice = False
-            m.last_attn_slice = None
+            if clear_attn_slice:
+                m.last_attn_slice = None

     @classmethod
     def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
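
In plain terms, clear_requests() still resets the save/use flags on every attention module, but with clear_attn_slice=False it now leaves last_attn_slice in place so a later pass can reuse it. Below is a minimal, self-contained sketch of that behaviour; FakeAttention is a hypothetical stand-in for the real CrossAttention modules and only models the three attributes the method touches.

# Stand-in for a CrossAttention module, modelling only the attributes
# that clear_requests() manipulates.
class FakeAttention:
    def __init__(self):
        self.save_last_attn_slice = True
        self.use_last_attn_slice = True
        self.last_attn_slice = "saved attention slice"

def clear_requests(modules, clear_attn_slice=True):
    # Mirrors the patched classmethod: the request flags are always reset,
    # but the saved slice is only dropped when clear_attn_slice is True.
    for m in modules:
        m.save_last_attn_slice = False
        m.use_last_attn_slice = False
        if clear_attn_slice:
            m.last_attn_slice = None

modules = [FakeAttention()]
clear_requests(modules, clear_attn_slice=False)
assert modules[0].last_attn_slice is not None   # slice survives for the second pass
clear_requests(modules, clear_attn_slice=True)
assert modules[0].last_attn_slice is None       # full cleanup discards it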


@@ -142,7 +142,7 @@ class InvokeAIDiffuserComponent:
             for type in cross_attention_control_types_to_do:
                 CrossAttentionControl.request_save_attention_maps(self.model, type)
             _ = self.model_forward_callback(x, sigma, conditioning)
-            CrossAttentionControl.clear_requests(self.model)
+            CrossAttentionControl.clear_requests(self.model, clear_attn_slice=False)

             # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
             for type in cross_attention_control_types_to_do:
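
At this call site the flag matters because the model is run twice: once with the original conditioning while the attention modules save their slices, then again with the edited conditioning while those saved slices are reapplied. Passing clear_attn_slice=False between the passes resets the request flags without discarding the slices the second pass needs. The following is a rough, runnable sketch of that two-pass idea; saved_slices, forward_pass and the prompt strings are hypothetical stand-ins, and only the keep-the-slice-between-passes behaviour mirrors the actual change.

# Stand-in for the per-module saved attention slices.
saved_slices = {}

def forward_pass(conditioning, save=False, use=False):
    # Stand-in for model_forward_callback plus the attention hooks.
    if save:
        saved_slices["attn"] = f"attention computed for {conditioning!r}"
    if use:
        # Clearing the slice between the two passes would make this lookup fail.
        return saved_slices["attn"]

forward_pass("original prompt", save=True)
# clear_requests(..., clear_attn_slice=False) happens here: flags reset, slices kept.
reused = forward_pass("edited prompt", use=True)
assert reused is not None   # the saved slice is still available for the second pass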