Update mps_fixes.py - additional torch op for nodes

This fixes scaling in the nodes UI.
This commit is contained in:
gogurtenjoyer 2023-07-05 17:47:23 -04:00 committed by GitHub
parent 021e1eca8e
commit 169ff6368b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -50,4 +50,14 @@ def new_torch_lerp(input, end, weight, *, out=None):
else:
return _torch_lerp(input, end, weight, out=out)
torch.lerp = new_torch_lerp
torch.lerp = new_torch_lerp
_torch_interpolate = torch.nn.functional.interpolate
def new_torch_interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
if input.device.type == "mps" and input.dtype == torch.float16:
return _torch_interpolate(input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias).half()
else:
return _torch_interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)
torch.nn.functional.interpolate = new_torch_interpolate