check if torch.backends has mps before calling it (#245)

Co-authored-by: James Reynolds <magnsuviri@me.com>
This commit is contained in:
James Reynolds 2022-08-31 08:56:38 -06:00 committed by GitHub
parent 31b77dbaf8
commit a547c33327
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,7 +4,7 @@ def choose_torch_device() -> str:
'''Convenience routine for guessing which GPU device to run model on'''
if torch.cuda.is_available():
return 'cuda'
if torch.backends.mps.is_available():
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
return 'mps'
return 'cpu'