From 349cc25433e1d358c6c61cca4a6c300ecc8ca3d7 Mon Sep 17 00:00:00 2001
From: damian0815
Date: Tue, 1 Nov 2022 20:08:52 +0100
Subject: [PATCH] fix crash (be a little less aggressive clearing out the
 attention slice)

---
 ldm/models/diffusion/cross_attention_control.py   | 5 +++--
 ldm/models/diffusion/shared_invokeai_diffusion.py | 2 +-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py
index 95113b4406..2f1470512f 100644
--- a/ldm/models/diffusion/cross_attention_control.py
+++ b/ldm/models/diffusion/cross_attention_control.py
@@ -110,13 +110,14 @@ class CrossAttentionControl:
                 type(module).__name__ == "CrossAttention" and which_attn in name]
 
     @classmethod
-    def clear_requests(cls, model):
+    def clear_requests(cls, model, clear_attn_slice=True):
         self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF)
         tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS)
         for m in self_attention_modules+tokens_attention_modules:
             m.save_last_attn_slice = False
             m.use_last_attn_slice = False
-            m.last_attn_slice = None
+            if clear_attn_slice:
+                m.last_attn_slice = None
 
     @classmethod
     def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py
index d273f3922e..5a9cc3eb74 100644
--- a/ldm/models/diffusion/shared_invokeai_diffusion.py
+++ b/ldm/models/diffusion/shared_invokeai_diffusion.py
@@ -142,7 +142,7 @@ class InvokeAIDiffuserComponent:
         for type in cross_attention_control_types_to_do:
             CrossAttentionControl.request_save_attention_maps(self.model, type)
         _ = self.model_forward_callback(x, sigma, conditioning)
-        CrossAttentionControl.clear_requests(self.model)
+        CrossAttentionControl.clear_requests(self.model, clear_attn_slice=False)
 
         # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
         for type in cross_attention_control_types_to_do:
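
The sketch below is not part of the patch; it is a minimal, self-contained illustration of the behaviour the new `clear_attn_slice` flag controls, using a stand-in `FakeAttention` class in place of the real `CrossAttention` modules that `CrossAttentionControl.get_attention_modules()` would return.

```python
# Minimal sketch (assumptions noted above): shows that with clear_attn_slice=False
# the request flags are reset but the attention slice saved during the first
# forward pass survives, so the second pass still has maps to apply.

class FakeAttention:
    """Stand-in for a CrossAttention module; only the fields touched by clear_requests."""
    def __init__(self):
        self.save_last_attn_slice = False
        self.use_last_attn_slice = False
        self.last_attn_slice = None

def clear_requests(modules, clear_attn_slice=True):
    # Mirrors the patched CrossAttentionControl.clear_requests() body.
    for m in modules:
        m.save_last_attn_slice = False
        m.use_last_attn_slice = False
        if clear_attn_slice:
            m.last_attn_slice = None

modules = [FakeAttention()]
modules[0].last_attn_slice = "attention maps saved during the first forward pass"

clear_requests(modules, clear_attn_slice=False)
assert modules[0].last_attn_slice is not None  # saved maps kept for the second pass

clear_requests(modules)  # old default behaviour
assert modules[0].last_attn_slice is None      # slice wiped out
```

Per the commit subject, the crash came from clearing too aggressively between the two passes; the caller in `shared_invokeai_diffusion.py` therefore passes `clear_attn_slice=False` after the first (attention-saving) forward pass.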