check if torch.backends has mps before calling it

This commit is contained in:
James Reynolds 2022-08-31 03:29:37 -06:00
parent 2aa8393272
commit 84c10346fb

View File

@ -4,7 +4,7 @@ def choose_torch_device() -> str:
'''Convenience routine for guessing which GPU device to run model on'''
if torch.cuda.is_available():
return 'cuda'
if torch.backends.mps.is_available():
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
return 'mps'
return 'cpu'