From ce1c5e70b8f260955592b837db78332db0cf70f5 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Mon, 16 Jan 2023 23:18:43 -0500
Subject: [PATCH] fix autocast dependency in cross_attention_control

---
 ldm/models/diffusion/cross_attention_control.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py
index 7415f1435b..03d5a5bcec 100644
--- a/ldm/models/diffusion/cross_attention_control.py
+++ b/ldm/models/diffusion/cross_attention_control.py
@@ -7,6 +7,7 @@ import torch
 import diffusers
 from torch import nn
 from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from ldm.invoke.devices import torch_dtype
 
 # adapted from bloc97's CrossAttentionControl colab
 # https://github.com/bloc97/CrossAttentionControl
@@ -383,7 +384,7 @@ def inject_attention_function(unet, context: Context):
                 remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
                 this_attention_slice = suggested_attention_slice
 
-                mask = context.cross_attention_mask
+                mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device))
                 saved_mask = mask
                 this_mask = 1 - mask
                 attention_slice = remapped_saved_attention_slice * saved_mask + \
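
Editorial note (not part of the patch): the attention slices produced by the UNet typically run in half precision on CUDA, while context.cross_attention_mask is built in float32, so blending the two either promotes the result to float32 or relies on an autocast context being active. Casting the mask to the dtype returned by torch_dtype() keeps the blend in the model's dtype. The sketch below illustrates that idea; the torch_dtype helper here is a hypothetical stand-in for ldm.invoke.devices.torch_dtype, assumed to map CUDA devices to float16 and everything else to float32, and the tensor shapes are made up for illustration.

import torch

def torch_dtype(device: torch.device) -> torch.dtype:
    # hypothetical stand-in for ldm.invoke.devices.torch_dtype:
    # assume half precision on CUDA, full precision elsewhere
    return torch.float16 if device.type == "cuda" else torch.float32

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# attention slices run in the model's dtype (float16 on CUDA under this assumption)
suggested_attention_slice = torch.rand(2, 77, 64, device=device, dtype=torch_dtype(device))
remapped_saved_attention_slice = torch.rand_like(suggested_attention_slice)

# the cross-attention mask is typically created once in float32
mask = torch.ones(77, 64, device=device)

# without this cast, float16 * float32 promotes the blend to float32 (or needs autocast);
# casting up front keeps the result in the attention slice's dtype
mask = mask.to(torch_dtype(suggested_attention_slice.device))

blended = remapped_saved_attention_slice * mask + suggested_attention_slice * (1 - mask)
assert blended.dtype == suggested_attention_slice.dtype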