mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
squash float16/float32 mismatch on linux
This commit is contained in:
parent
bffe199ad7
commit
313b206ff8
@ -329,7 +329,7 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers
|
||||
# urgh. should this be hardcoded?
|
||||
max_length = 77
|
||||
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
|
||||
mask = torch.zeros(max_length)
|
||||
mask = torch.zeros(max_length, dtype=torch_dtype())
|
||||
indices_target = torch.arange(max_length, dtype=torch.long)
|
||||
indices = torch.arange(max_length, dtype=torch.long)
|
||||
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
|
||||
@ -338,7 +338,7 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers
|
||||
# these tokens have not been edited
|
||||
indices[b0:b1] = indices_target[a0:a1]
|
||||
mask[b0:b1] = 1
|
||||
|
||||
b
|
||||
context.cross_attention_mask = mask.to(device)
|
||||
context.cross_attention_index_map = indices.to(device)
|
||||
if is_running_diffusers:
|
||||
|
Loading…
Reference in New Issue
Block a user