Mac MPS FP16 fixes

This PR is to allow FP16 precision to work on Macs with MPS. In addition, it centralizes the torch fixes/workarounds
required for MPS into a new backend utility file `mps_fixes.py`. This is conditionally imported in `api_app.py`/`cli_app.py`.

Many MANY thanks to StAlKeR7779 for patiently working to debug and fix these issues.
This commit is contained in:
gogurtenjoyer
2023-07-04 18:05:01 -04:00
parent 92b163e95c
commit 233869b56a
9 changed files with 103 additions and 73 deletions

View File

@ -28,6 +28,8 @@ def choose_precision(device: torch.device) -> str:
device_name = torch.cuda.get_device_name(device)
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
return "float16"
elif device.type == "mps":
return "float16"
return "float32"