mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Mac MPS FP16 fixes
This PR is to allow FP16 precision to work on Macs with MPS. In addition, it centralizes the torch fixes/workarounds required for MPS into a new backend utility file `mps_fixes.py`. This is conditionally imported in `api_app.py`/`cli_app.py`. Many MANY thanks to StAlKeR7779 for patiently working to debug and fix these issues.
This commit is contained in:
@ -28,6 +28,8 @@ def choose_precision(device: torch.device) -> str:
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
|
||||
return "float16"
|
||||
elif device.type == "mps":
|
||||
return "float16"
|
||||
return "float32"
|
||||
|
||||
|
||||
|
53
invokeai/backend/util/mps_fixes.py
Normal file
53
invokeai/backend/util/mps_fixes.py
Normal file
@ -0,0 +1,53 @@
|
||||
import torch
|
||||
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
torch.empty = torch.zeros
|
||||
|
||||
|
||||
_torch_layer_norm = torch.nn.functional.layer_norm
|
||||
def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
|
||||
if input.device.type == "mps" and input.dtype == torch.float16:
|
||||
input = input.float()
|
||||
if weight is not None:
|
||||
weight = weight.float()
|
||||
if bias is not None:
|
||||
bias = bias.float()
|
||||
return _torch_layer_norm(input, normalized_shape, weight, bias, eps).half()
|
||||
else:
|
||||
return _torch_layer_norm(input, normalized_shape, weight, bias, eps)
|
||||
|
||||
torch.nn.functional.layer_norm = new_layer_norm
|
||||
|
||||
|
||||
_torch_tensor_permute = torch.Tensor.permute
|
||||
def new_torch_tensor_permute(input, *dims):
|
||||
result = _torch_tensor_permute(input, *dims)
|
||||
if input.device == "mps" and input.dtype == torch.float16:
|
||||
result = result.contiguous()
|
||||
return result
|
||||
|
||||
torch.Tensor.permute = new_torch_tensor_permute
|
||||
|
||||
|
||||
_torch_lerp = torch.lerp
|
||||
def new_torch_lerp(input, end, weight, *, out=None):
|
||||
if input.device.type == "mps" and input.dtype == torch.float16:
|
||||
input = input.float()
|
||||
end = end.float()
|
||||
if isinstance(weight, torch.Tensor):
|
||||
weight = weight.float()
|
||||
if out is not None:
|
||||
out_fp32 = torch.zeros_like(out, dtype=torch.float32)
|
||||
else:
|
||||
out_fp32 = None
|
||||
result = _torch_lerp(input, end, weight, out=out_fp32)
|
||||
if out is not None:
|
||||
out.copy_(out_fp32.half())
|
||||
del out_fp32
|
||||
return result.half()
|
||||
|
||||
else:
|
||||
return _torch_lerp(input, end, weight, out=out)
|
||||
|
||||
torch.lerp = new_torch_lerp
|
Reference in New Issue
Block a user