import torch
from torch import autocast
from contextlib import contextmanager, nullcontext

def choose_torch_device() -> str:
    '''Guess which torch device name the model should run on.

    Preference order is CUDA first, then Apple's Metal backend (MPS), and
    finally the CPU. The MPS lookup tolerates torch builds whose
    ``torch.backends`` has no ``mps`` attribute at all.

    Returns:
        One of the strings 'cuda', 'mps' or 'cpu'.
    '''
    if torch.cuda.is_available():
        chosen = 'cuda'
    elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
        chosen = 'mps'
    else:
        chosen = 'cpu'
    return chosen

def choose_autocast_device(device):
    '''Return an autocast-compatible device type and a matching context factory.

    Args:
        device: a ``torch.device``. A device string such as ``'cuda:1'`` is
            also accepted and is converted with ``torch.device``.

    Returns:
        A ``(device_type, context_factory)`` tuple. ``context_factory`` is
        either ``torch.autocast`` or ``contextlib.nullcontext``. Presumably
        callers invoke it as ``context_factory(device_type)``; verify this
        against the call sites.

        - CUDA devices get ``('cuda', autocast)``.
        - GeForce GTX 16xx cards are the exception and get
          ``('cuda', nullcontext)``.
        - Every other device type, including 'mps' on Apple silicon, maps
          to ``('cpu', nullcontext)``.
    '''
    if isinstance(device, str):
        device = torch.device(device)
    device_type = device.type  # this returns 'mps' on M1
    if device_type != 'cuda':
        # Autocast is only enabled for CUDA here; everything else runs without it.
        return 'cpu', nullcontext
    # Ask about the card this device actually refers to, not just the
    # current default CUDA device, so multi-GPU setups check the right model.
    device_name = torch.cuda.get_device_name(device)
    # GeForce GTX 16xx cards (1630/1650/1660 and their Ti/Super/Max-Q variants)
    # have issues with autocast, so autocast is skipped for them.
    if 'GeForce GTX 16' in device_name:
        return device_type, nullcontext
    return device_type, autocast