mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
145 lines
5.5 KiB
Python
145 lines
5.5 KiB
Python
"""Torch Device class provides torch device selection services."""
|
|
|
|
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, Set
|
|
|
|
import torch
|
|
from deprecated import deprecated
|
|
|
|
from invokeai.app.services.config.config_default import get_config
|
|
|
|
if TYPE_CHECKING:
|
|
from invokeai.backend.model_manager.config import AnyModel
|
|
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
|
|
|
# legacy APIs

# Precision names understood by this module; each one is a key of NAME_TO_PRECISION.
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]

# Index-less device handles, one per backend type.
CPU_DEVICE = torch.device("cpu")
CUDA_DEVICE = torch.device("cuda")
MPS_DEVICE = torch.device("mps")
|
|
|
|
|
|
@deprecated("Use TorchDevice.choose_torch_dtype() instead.")  # type: ignore
def choose_precision(device: torch.device) -> TorchPrecisionNames:
    """Return the name of the recommended torch precision for `device`.

    Deprecated shim: asks `TorchDevice.choose_torch_dtype()` for the dtype and
    translates it to its string name via `PRECISION_TO_NAME`.
    """
    # Deliberately not called `torch_dtype`: that name belongs to a
    # module-level function and should not be shadowed here.
    dtype = TorchDevice.choose_torch_dtype(device)
    return PRECISION_TO_NAME[dtype]
|
|
|
|
|
|
@deprecated("Use TorchDevice.choose_torch_device() instead.")  # type: ignore
def choose_torch_device() -> torch.device:
    """Legacy shim that picks the torch.device for accelerated inference.

    Simply forwards to `TorchDevice.choose_torch_device()`.
    """
    selected_device = TorchDevice.choose_torch_device()
    return selected_device
|
|
|
|
|
|
@deprecated("Use TorchDevice.choose_torch_dtype() instead.")  # type: ignore
def torch_dtype(device: torch.device) -> torch.dtype:
    """Legacy shim that picks the torch precision (dtype) for `device`.

    Simply forwards to `TorchDevice.choose_torch_dtype(device)`.
    """
    selected_dtype = TorchDevice.choose_torch_dtype(device)
    return selected_dtype
|
|
|
|
|
|
# Forward lookup: precision name -> torch dtype.
NAME_TO_PRECISION: Dict[TorchPrecisionNames, torch.dtype] = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}
# Reverse lookup, derived from the table above so the two cannot drift apart.
PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {
    dtype: name for name, dtype in NAME_TO_PRECISION.items()
}
|
|
|
|
|
|
class TorchDevice:
    """Abstraction layer for torch devices."""

    # Model cache registered via set_model_cache(). While it is set, the cache
    # (not the app config) decides which device choose_torch_device() returns.
    _model_cache: Optional["ModelCacheBase[AnyModel]"] = None

    @classmethod
    def set_model_cache(cls, cache: "ModelCacheBase[AnyModel]") -> None:
        """Set the current model cache."""
        cls._model_cache = cache

    @classmethod
    def choose_torch_device(cls) -> torch.device:
        """Return the torch.device to use for accelerated inference.

        If a model cache has been registered, its `get_execution_device()`
        result is returned as-is. Otherwise the configured device is used
        unless it is "auto", in which case CUDA, then MPS, then CPU is picked
        based on availability. The config-derived device is passed through
        `normalize()` before being returned.
        """
        # Compare against None explicitly: a registered cache must be honoured
        # even if the object happens to evaluate as falsy (e.g. it defines
        # __len__ and is currently empty).
        if cls._model_cache is not None:
            return cls._model_cache.get_execution_device()
        app_config = get_config()
        if app_config.device != "auto":
            device = torch.device(app_config.device)
        elif torch.cuda.is_available():
            device = CUDA_DEVICE
        elif torch.backends.mps.is_available():
            device = MPS_DEVICE
        else:
            device = CPU_DEVICE
        return cls.normalize(device)

    @classmethod
    def execution_devices(cls) -> Set[torch.device]:
        """Return the set of torch.devices that can be used for accelerated inference.

        Uses the `devices` entry of the app config when it is set; otherwise
        the devices are discovered from the available torch backends.
        """
        app_config = get_config()
        if app_config.devices is None:
            return cls._lookup_execution_devices()
        return {torch.device(x) for x in app_config.devices}

    @classmethod
    def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
        """Return the precision to use for accelerated inference.

        Args:
            device: Device to choose a dtype for. Defaults to the result of
                `choose_torch_device()`.

        Returns:
            float32 for GTX 1650/1660 CUDA cards, float16 for other CUDA and
            MPS devices when the configured precision is "auto", the configured
            precision otherwise, and float32 for everything else.
        """
        if device is None:
            device = cls.choose_torch_device()
        config = get_config()
        if device.type == "cuda" and torch.cuda.is_available():
            device_name = torch.cuda.get_device_name(device)
            if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
                # These GPUs have limited support for float16
                return cls._to_dtype("float32")
            elif config.precision == "auto":
                # Default to float16 for CUDA devices
                return cls._to_dtype("float16")
            else:
                # Use the user-defined precision
                return cls._to_dtype(config.precision)

        elif device.type == "mps" and torch.backends.mps.is_available():
            if config.precision == "auto":
                # Default to float16 for MPS devices
                return cls._to_dtype("float16")
            else:
                # Use the user-defined precision
                return cls._to_dtype(config.precision)
        # CPU / safe fallback
        return cls._to_dtype("float32")

    @classmethod
    def get_torch_device_name(cls) -> str:
        """Return the device name for the current torch device.

        CUDA devices report the name given by `torch.cuda.get_device_name()`;
        any other device reports its upper-cased type (e.g. "CPU", "MPS").
        """
        device = cls.choose_torch_device()
        return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()

    @classmethod
    def normalize(cls, device: Union[str, torch.device]) -> torch.device:
        """Add the device index to CUDA devices.

        An index-less CUDA device is given the index of the current CUDA
        device, provided CUDA is available. Any other device is returned
        unchanged (after conversion to `torch.device`).
        """
        device = torch.device(device)
        if device.index is None and device.type == "cuda" and torch.cuda.is_available():
            device = torch.device(device.type, torch.cuda.current_device())
        return device

    @classmethod
    def empty_cache(cls) -> None:
        """Clear the GPU device cache."""
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @classmethod
    def _to_dtype(cls, precision_name: "TorchPrecisionNames") -> torch.dtype:
        """Translate a precision name into its torch dtype via NAME_TO_PRECISION."""
        return NAME_TO_PRECISION[precision_name]

    @classmethod
    def _lookup_execution_devices(cls) -> Set[torch.device]:
        """Discover usable devices: every CUDA device, else MPS, else CPU."""
        if torch.cuda.is_available():
            devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
        elif torch.backends.mps.is_available():
            devices = {torch.device("mps")}
        else:
            devices = {torch.device("cpu")}
        return devices
|
|
|