Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
9ad4c03277
1) Downgrade numpy to avoid a dependency conflict with numba (see the sketch after this list).
2) Move all non-ldm/invoke files into `invokeai`; this includes assets, backend, frontend, and configs.
3) Fix up the way the backend finds the frontend and the generator finds the NSFW caution.png icon.
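For item 1, a downgrade like this usually lands as a version cap in the project's dependency list. The constraint below is a hypothetical illustration of the shape of such a pin, not the literal change from this commit:

# requirements.txt (hypothetical pin, not the literal diff from this commit);
# numba declares a maximum supported numpy version, so capping numpy below
# numba's ceiling resolves the dependency conflict
numpy<1.24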
39 lines
1.3 KiB
Python
import torch
from torch import autocast
from contextlib import nullcontext

from ldm.invoke.globals import Globals


def choose_torch_device() -> str:
    '''Convenience routine for guessing which GPU device to run model on'''
    if Globals.always_use_cpu:
        return 'cpu'
    if torch.cuda.is_available():
        return 'cuda'
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return 'mps'
    return 'cpu'


def choose_precision(device) -> str:
    '''Returns an appropriate precision for the given torch device'''
    if device.type == 'cuda':
        device_name = torch.cuda.get_device_name(device)
        # GTX 1660/1650 cards are known to produce black or corrupted images
        # when running in half precision, so keep them at float32.
        if not ('GeForce GTX 1660' in device_name or 'GeForce GTX 1650' in device_name):
            return 'float16'
    return 'float32'


def torch_dtype(device) -> torch.dtype:
    '''Returns the torch dtype to use, honoring the global full-precision flag'''
    if Globals.full_precision:
        return torch.float32
    if choose_precision(device) == 'float16':
        return torch.float16
    else:
        return torch.float32


def choose_autocast(precision):
    '''Returns an autocast context or nullcontext for the given precision string'''
    # float16 currently requires autocast to avoid errors like:
    # 'expected scalar type Half but found Float'
    if precision == 'autocast' or precision == 'float16':
        return autocast
    return nullcontext
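A minimal usage sketch of these helpers, assuming the file is importable as ldm.invoke.devices (consistent with the ldm.invoke.globals import above); the forward pass is a placeholder:

import torch

from ldm.invoke.devices import (
    choose_autocast,
    choose_precision,
    choose_torch_device,
    torch_dtype,
)

# choose_torch_device returns a string, while choose_precision reads
# device.type, so wrap the string in a torch.device first.
device = torch.device(choose_torch_device())
precision = choose_precision(device)   # 'float16' or 'float32'
dtype = torch_dtype(device)            # the matching torch.dtype

x = torch.zeros(1, 3, 64, 64, device=device, dtype=dtype)

# choose_autocast returns either torch.autocast or contextlib.nullcontext;
# both accept one positional argument, so the same call form works for each.
scope = choose_autocast(precision)
with scope(device.type):
    pass  # a model forward pass would go here

The uniform call form is the point of choose_autocast: callers do not branch on precision, they just enter whatever context it hands back.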