diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py
index 45294ac993..fcb9f52dde 100644
--- a/ldm/models/diffusion/cross_attention_control.py
+++ b/ldm/models/diffusion/cross_attention_control.py
@@ -329,7 +329,7 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers
     # urgh. should this be hardcoded?
     max_length = 77
     # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
-    mask = torch.zeros(max_length, dtype=torch_dtype())
+    mask = torch.zeros(max_length, dtype=torch_dtype(device))
     indices_target = torch.arange(max_length, dtype=torch.long)
     indices = torch.arange(max_length, dtype=torch.long)
     for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
@@ -338,7 +338,7 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers
                 # these tokens have not been edited
                 indices[b0:b1] = indices_target[a0:a1]
                 mask[b0:b1] = 1
-b
+
     context.cross_attention_mask = mask.to(device)
     context.cross_attention_index_map = indices.to(device)
     if is_running_diffusers:
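
The first hunk passes `device` to `torch_dtype()` so the cross-attention mask is created in the same dtype the model is running in, rather than whatever the helper defaults to; the second hunk removes a stray "b" character that had been left on its own line. The sketch below illustrates why the helper plausibly needs the device argument: the precision choice depends on where the tensor will live. The helper body shown here is an assumption for illustration only; the real `torch_dtype()` lives elsewhere in this repository and its exact precision logic may differ.

import torch

def torch_dtype(device: torch.device = None) -> torch.dtype:
    # Assumed behaviour: half precision only when the tensor will sit on a
    # CUDA device; CPU (and other backends) stay in float32.
    if device is not None and device.type == "cuda":
        return torch.float16
    return torch.float32

# With the patched call, the mask's dtype follows the target device, so it can
# be multiplied against the attention tensors without an implicit upcast.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mask = torch.zeros(77, dtype=torch_dtype(device))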