mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Update mps_fixes.py - additional torch op for nodes
This fixes scaling in the nodes UI.
This commit is contained in:
parent
021e1eca8e
commit
169ff6368b
@ -51,3 +51,13 @@ def new_torch_lerp(input, end, weight, *, out=None):
|
|||||||
return _torch_lerp(input, end, weight, out=out)
|
return _torch_lerp(input, end, weight, out=out)
|
||||||
|
|
||||||
torch.lerp = new_torch_lerp
|
torch.lerp = new_torch_lerp
|
||||||
|
|
||||||
|
|
||||||
|
_torch_interpolate = torch.nn.functional.interpolate
|
||||||
|
def new_torch_interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
|
||||||
|
if input.device.type == "mps" and input.dtype == torch.float16:
|
||||||
|
return _torch_interpolate(input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias).half()
|
||||||
|
else:
|
||||||
|
return _torch_interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)
|
||||||
|
|
||||||
|
torch.nn.functional.interpolate = new_torch_interpolate
|
||||||
|
Loading…
Reference in New Issue
Block a user