squash float16/float32 mismatch on linux

commit 313b206ff8 (parent bffe199ad7)
Author: Damian Stewart
Date:   2023-01-22 18:12:11 +01:00


@@ -329,7 +329,7 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers
# urgh. should this be hardcoded?
max_length = 77
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
-    mask = torch.zeros(max_length)
+    mask = torch.zeros(max_length, dtype=torch_dtype())
indices_target = torch.arange(max_length, dtype=torch.long)
indices = torch.arange(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
@@ -338,7 +338,7 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers
# these tokens have not been edited
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
context.cross_attention_mask = mask.to(device)
context.cross_attention_index_map = indices.to(device)
if is_running_diffusers:
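
For context, a minimal sketch of the mismatch being squashed, not InvokeAI's actual code. torch.zeros() defaults to float32 regardless of the model's precision, so on a half-precision model the mask's dtype disagrees with the attention tensors it is combined with. The torch_dtype() helper below is an assumption standing in for the project's device helper, which this diff does not show; it is assumed to resolve to float16 on CUDA and float32 on CPU.

import torch

# Assumed stand-in for the project's helper: report the model's
# working precision (float16 on CUDA, float32 on CPU).
def torch_dtype() -> torch.dtype:
    return torch.float16 if torch.cuda.is_available() else torch.float32

max_length = 77

# Before: torch.zeros() always yields float32, whatever the model runs in.
mask = torch.zeros(max_length)
assert mask.dtype == torch.float32

# A float32 mask fed into dtype-strict ops (torch.lerp is one example)
# alongside float16 attention tensors raises a RuntimeError; plain
# elementwise arithmetic instead silently upcasts the result to float32.

# After: build the mask in the working dtype from the start.
mask = torch.zeros(max_length, dtype=torch_dtype())
attn = torch.rand(max_length, dtype=torch_dtype())
assert (attn * mask).dtype == torch_dtype()

Creating the mask with the right dtype up front, rather than casting at each use site, keeps every downstream consumer of context.cross_attention_mask consistent with the model's precision.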