diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py
index 1e5b073a3d..95113b4406 100644
--- a/ldm/models/diffusion/cross_attention_control.py
+++ b/ldm/models/diffusion/cross_attention_control.py
@@ -116,6 +116,7 @@ class CrossAttentionControl:
         for m in self_attention_modules+tokens_attention_modules:
             m.save_last_attn_slice = False
             m.use_last_attn_slice = False
+            m.last_attn_slice = None
 
     @classmethod
     def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py
index dd2643cd0a..d273f3922e 100644
--- a/ldm/models/diffusion/shared_invokeai_diffusion.py
+++ b/ldm/models/diffusion/shared_invokeai_diffusion.py
@@ -134,23 +134,31 @@ class InvokeAIDiffuserComponent:
         # representing batched uncond + cond, but then when it comes to applying the saved attention, the
         # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
         # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
-        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
 
-        # process x using the original prompt, saving the attention maps
-        for type in cross_attention_control_types_to_do:
-            CrossAttentionControl.request_save_attention_maps(self.model, type)
-        _ = self.model_forward_callback(x, sigma, conditioning)
-        CrossAttentionControl.clear_requests(self.model)
+        try:
+            unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
 
-        # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
-        for type in cross_attention_control_types_to_do:
-            CrossAttentionControl.request_apply_saved_attention_maps(self.model, type)
-        edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
-        conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
+            # process x using the original prompt, saving the attention maps
+            for type in cross_attention_control_types_to_do:
+                CrossAttentionControl.request_save_attention_maps(self.model, type)
+            _ = self.model_forward_callback(x, sigma, conditioning)
+            CrossAttentionControl.clear_requests(self.model)
 
-        CrossAttentionControl.clear_requests(self.model)
+            # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
+            for type in cross_attention_control_types_to_do:
+                CrossAttentionControl.request_apply_saved_attention_maps(self.model, type)
+            edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
+            conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
 
-        return unconditioned_next_x, conditioned_next_x
+            CrossAttentionControl.clear_requests(self.model)
+
+            return unconditioned_next_x, conditioned_next_x
+
+        except RuntimeError:
+            # make sure we clean out the attention slices we're storing on the model
+            # TODO don't store things on the model
+            CrossAttentionControl.clear_requests(self.model)
+            raise
 
     def estimate_percent_through(self, step_index, sigma):
         if step_index is not None and self.cross_attention_control_context is not None:
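
For context, the change above does two things: clear_requests() now also drops the cached last_attn_slice instead of only resetting the flags, and the cross-attention-controlled sampling passes are wrapped so that clear_requests() runs even when a forward pass raises (for example a CUDA out-of-memory RuntimeError). Below is a minimal, self-contained sketch of that cleanup-on-error pattern; FakeAttentionStore, flaky_forward and run_step are hypothetical stand-ins for illustration only, not InvokeAI APIs.

# Minimal sketch of the cleanup-on-error pattern used in the diff above.
# FakeAttentionStore, flaky_forward and run_step are hypothetical stand-ins.

class FakeAttentionStore:
    """Stands in for the per-module state CrossAttentionControl keeps on the model."""
    def __init__(self):
        self.save_last_attn_slice = False
        self.use_last_attn_slice = False
        self.last_attn_slice = None

    def clear(self):
        # mirrors clear_requests(): reset the flags *and* drop the cached slice
        self.save_last_attn_slice = False
        self.use_last_attn_slice = False
        self.last_attn_slice = None


def flaky_forward(store: FakeAttentionStore, fail: bool):
    # stand-in for model_forward_callback(): caches a slice, then maybe fails
    store.save_last_attn_slice = True
    store.last_attn_slice = [0.1, 0.2, 0.3]  # pretend this is a large attention tensor
    if fail:
        raise RuntimeError("simulated CUDA out-of-memory")
    return sum(store.last_attn_slice)


def run_step(store: FakeAttentionStore, fail: bool = False):
    try:
        result = flaky_forward(store, fail)
        store.clear()
        return result
    except RuntimeError:
        # without this, the cached slice would stay attached to the store/model
        # after a failed step, holding on to memory
        store.clear()
        raise


if __name__ == "__main__":
    store = FakeAttentionStore()
    print(run_step(store))             # normal path: prints 0.6..., state cleared
    try:
        run_step(store, fail=True)     # error path
    except RuntimeError:
        pass
    assert store.last_attn_slice is None  # cleanup ran despite the exception

A try/finally block would give the same guarantee without naming an exception type; the diff's except RuntimeError + bare raise keeps the original traceback while scoping the cleanup path to the error class the failing forward pass actually produces.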