devices.py - Update MPS FP16 check to account for upcoming MacOS Sonoma

float16 doesn't seem to work on MacOS Sonoma due to further changes with Metal. This'll default back to float32 for Sonoma users.
This commit is contained in:
gogurtenjoyer 2023-07-21 19:59:22 -04:00 committed by psychedelicious
parent d162b78767
commit ecabfc252b

View File

@ -1,6 +1,8 @@
from __future__ import annotations
from contextlib import nullcontext
from packaging import version
import platform
import torch
from torch import autocast
@ -30,7 +32,7 @@ def choose_precision(device: torch.device) -> str:
device_name = torch.cuda.get_device_name(device)
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
return "float16"
elif device.type == "mps":
elif device.type == "mps" and version.parse(platform.mac_ver()[0]) < version.parse('14.0.0'):
return "float16"
return "float32"