2023-07-04 22:05:01 +00:00
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
# True only when PyTorch reports the MPS backend as available.
# NOTE(review): which of the following statements this guard is meant to cover
# cannot be determined from the layout here — confirm the extent of the `if` body.
if torch.backends.mps.is_available():
|
|
|
|
# Rebind torch.empty to torch.zeros, so every later torch.empty(...) call
# returns a zero-filled tensor. Presumably a workaround for problems caused by
# uninitialized memory on MPS — TODO confirm.
torch.empty = torch.zeros
|
|
|
|
|
|
|
|
|
|
|
|
# Keep a handle on the stock implementation before it is replaced below.
_torch_layer_norm = torch.nn.functional.layer_norm


def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
    """Layer norm that runs in float32 for float16 tensors on MPS.

    For every other device/dtype combination the call is forwarded untouched
    to the original ``torch.nn.functional.layer_norm``.

    Args:
        input: Tensor to normalize.
        normalized_shape: Forwarded unchanged to the original function.
        weight: Optional scale tensor; upcast to float32 on the MPS/fp16 path.
        bias: Optional shift tensor; upcast to float32 on the MPS/fp16 path.
        eps: Forwarded unchanged to the original function.

    Returns:
        The normalized tensor. On the MPS/fp16 path the float32 result is
        converted back to float16 before being returned.
    """
    needs_upcast = input.device.type == "mps" and input.dtype == torch.float16
    if not needs_upcast:
        return _torch_layer_norm(input, normalized_shape, weight, bias, eps)

    # Compute in float32, then hand the caller back the half dtype it passed in.
    input_fp32 = input.float()
    weight_fp32 = None if weight is None else weight.float()
    bias_fp32 = None if bias is None else bias.float()
    normalized = _torch_layer_norm(
        input_fp32, normalized_shape, weight_fp32, bias_fp32, eps
    )
    return normalized.half()


# Install the wrapper in place of the stock function.
torch.nn.functional.layer_norm = new_layer_norm
|
|
|
|
|
|
|
|
|
|
|
|
# Keep a handle on the stock method before it is replaced below.
_torch_tensor_permute = torch.Tensor.permute


def new_torch_tensor_permute(input, *dims, **kwargs):
    """Permute ``input`` and force a contiguous result for float16 on MPS.

    Installed as ``torch.Tensor.permute``, so ``input`` is the tensor the
    method is called on.

    Args:
        input: Tensor whose dimensions are reordered.
        *dims: Positional dimension arguments, forwarded unchanged.
        **kwargs: Keyword arguments (e.g. ``dims=(1, 0)``), forwarded unchanged
            so the keyword spelling of the original method keeps working.

    Returns:
        The permuted tensor. For float16 tensors on an MPS device the result
        is passed through ``.contiguous()``; otherwise it is exactly what the
        original ``permute`` returned.
    """
    result = _torch_tensor_permute(input, *dims, **kwargs)
    # Compare the device *type*: a torch.device never compares equal to a
    # plain string, so `input.device == "mps"` would always be False and the
    # workaround below would never run.
    if input.device.type == "mps" and input.dtype == torch.float16:
        result = result.contiguous()
    return result


# Install the wrapper in place of the stock method.
torch.Tensor.permute = new_torch_tensor_permute
|
|
|
|
|
|
|
|
|
|
|
|
# Keep a handle on the stock implementation before it is replaced below.
_torch_lerp = torch.lerp


def new_torch_lerp(input, end, weight, *, out=None):
    """Linear interpolation that runs in float32 for float16 tensors on MPS.

    For every other device/dtype combination the call is forwarded untouched
    to the original ``torch.lerp``.

    Args:
        input: Start tensor; its device and dtype select the code path.
        end: End tensor; upcast to float32 on the MPS/fp16 path.
        weight: Interpolation weight. A tensor weight is upcast to float32 on
            the MPS/fp16 path; any other value is forwarded as given.
        out: Optional destination tensor (keyword-only).

    Returns:
        ``out`` when it is supplied, otherwise a new tensor. On the MPS/fp16
        path the float32 result is converted to float16 first.
    """
    if not (input.device.type == "mps" and input.dtype == torch.float16):
        return _torch_lerp(input, end, weight, out=out)

    if isinstance(weight, torch.Tensor):
        weight = weight.float()
    # Compute once in float32 and convert once; no scratch `out` buffer needed.
    result = _torch_lerp(input.float(), end.float(), weight).half()
    if out is None:
        return result

    # Honour the `out=` convention: fill the caller's tensor and return that
    # same tensor rather than a separate copy.
    out.copy_(result)
    return out


# Install the wrapper in place of the stock function.
torch.lerp = new_torch_lerp
|
|
|
|
|
|
|
|
|
|
|
|
# Keep a handle on the stock implementation before it is replaced below.
_torch_interpolate = torch.nn.functional.interpolate


def new_torch_interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
    """Interpolate in float32 when given a float16 tensor on MPS.

    All arguments after ``input`` are forwarded positionally, unchanged, to
    the original ``torch.nn.functional.interpolate``.

    Returns:
        The resampled tensor. On the MPS/fp16 path the input is upcast to
        float32 for the call and the result is converted back to float16;
        otherwise the original function's result is returned as-is.
    """
    on_mps_half = input.device.type == "mps" and input.dtype == torch.float16
    source = input.float() if on_mps_half else input
    resized = _torch_interpolate(
        source,
        size,
        scale_factor,
        mode,
        align_corners,
        recompute_scale_factor,
        antialias,
    )
    return resized.half() if on_mps_half else resized


# Install the wrapper in place of the stock function.
torch.nn.functional.interpolate = new_torch_interpolate
|