mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
work around an apparent MPS torch bug that causes conditioning to have no effect
This commit is contained in:
parent
e9a0f07033
commit
adaa1c7c3e
@ -102,8 +102,11 @@ class InvokeAIDiffuserComponent:
|
|||||||
x_twice = torch.cat([x] * 2)
|
x_twice = torch.cat([x] * 2)
|
||||||
sigma_twice = torch.cat([sigma] * 2)
|
sigma_twice = torch.cat([sigma] * 2)
|
||||||
both_conditionings = torch.cat([unconditioning, conditioning])
|
both_conditionings = torch.cat([unconditioning, conditioning])
|
||||||
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice,
|
both_results = self.model_forward_callback(x_twice, sigma_twice, both_conditionings)
|
||||||
both_conditionings).chunk(2)
|
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
||||||
|
if conditioned_next_x.device.type == 'mps':
|
||||||
|
# prevent a result filled with zeros. seems to be a torch bug.
|
||||||
|
conditioned_next_x = conditioned_next_x.clone()
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user