pass missing value

This commit is contained in:
Damian Stewart 2023-01-22 18:19:01 +01:00
parent 313b206ff8
commit c0610f7cb9

View File

@ -329,7 +329,7 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers
# urgh. should this be hardcoded?
max_length = 77
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
mask = torch.zeros(max_length, dtype=torch_dtype())
mask = torch.zeros(max_length, dtype=torch_dtype(device))
indices_target = torch.arange(max_length, dtype=torch.long)
indices = torch.arange(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
@ -338,7 +338,7 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers
# these tokens have not been edited
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
b
context.cross_attention_mask = mask.to(device)
context.cross_attention_index_map = indices.to(device)
if is_running_diffusers: