mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
squash float16/float32 mismatch on linux
This commit is contained in:
parent
bffe199ad7
commit
313b206ff8
@ -329,7 +329,7 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers
|
|||||||
# urgh. should this be hardcoded?
|
# urgh. should this be hardcoded?
|
||||||
max_length = 77
|
max_length = 77
|
||||||
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
|
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
|
||||||
mask = torch.zeros(max_length)
|
mask = torch.zeros(max_length, dtype=torch_dtype())
|
||||||
indices_target = torch.arange(max_length, dtype=torch.long)
|
indices_target = torch.arange(max_length, dtype=torch.long)
|
||||||
indices = torch.arange(max_length, dtype=torch.long)
|
indices = torch.arange(max_length, dtype=torch.long)
|
||||||
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
|
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
|
||||||
@ -338,7 +338,7 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers
|
|||||||
# these tokens have not been edited
|
# these tokens have not been edited
|
||||||
indices[b0:b1] = indices_target[a0:a1]
|
indices[b0:b1] = indices_target[a0:a1]
|
||||||
mask[b0:b1] = 1
|
mask[b0:b1] = 1
|
||||||
|
b
|
||||||
context.cross_attention_mask = mask.to(device)
|
context.cross_attention_mask = mask.to(device)
|
||||||
context.cross_attention_index_map = indices.to(device)
|
context.cross_attention_index_map = indices.to(device)
|
||||||
if is_running_diffusers:
|
if is_running_diffusers:
|
||||||
|
Loading…
Reference in New Issue
Block a user