mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Update mps_fixes.py - additional torch op for nodes
This fixes scaling in the nodes UI.
This commit is contained in:
parent
021e1eca8e
commit
169ff6368b
@ -50,4 +50,14 @@ def new_torch_lerp(input, end, weight, *, out=None):
|
||||
else:
|
||||
return _torch_lerp(input, end, weight, out=out)
|
||||
|
||||
torch.lerp = new_torch_lerp
|
||||
torch.lerp = new_torch_lerp
|
||||
|
||||
|
||||
_torch_interpolate = torch.nn.functional.interpolate
|
||||
def new_torch_interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
|
||||
if input.device.type == "mps" and input.dtype == torch.float16:
|
||||
return _torch_interpolate(input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias).half()
|
||||
else:
|
||||
return _torch_interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)
|
||||
|
||||
torch.nn.functional.interpolate = new_torch_interpolate
|
||||
|
Loading…
Reference in New Issue
Block a user