Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
c7562dd6c0
We can get black outputs when moving tensors from CPU to MPS. It appears MPS to CPU is fine. See:
- https://github.com/pytorch/pytorch/issues/107455
- https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28

Changes:
- Add properties for each device on `TorchDevice` as a convenience.
- Add `get_non_blocking` static method on `TorchDevice`. This utility takes a torch device and returns the flag to be used for non_blocking when moving a tensor to the device provided.
- Update model patching and caching APIs to use this new utility.

Fixes: #6545
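In practice, the utility gates tensor moves roughly like this (a minimal sketch, assuming the module's import path is `invokeai.backend.util.devices`; the tensor and variable names are illustrative, not taken from the diff):

import torch

from invokeai.backend.util.devices import TorchDevice

device = torch.device("mps")
t = torch.randn(16)

# Moves _to_ MPS must be blocking to avoid black outputs;
# get_non_blocking returns False for MPS and True for other device types.
t = t.to(device, non_blocking=TorchDevice.get_non_blocking(device))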
127 lines
4.9 KiB
Python
from typing import Dict, Literal, Optional, Union

import torch
from deprecated import deprecated

from invokeai.app.services.config.config_default import get_config

# legacy APIs
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
CPU_DEVICE = torch.device("cpu")
CUDA_DEVICE = torch.device("cuda")
MPS_DEVICE = torch.device("mps")


@deprecated("Use TorchDevice.choose_torch_dtype() instead.")  # type: ignore
def choose_precision(device: torch.device) -> TorchPrecisionNames:
    """Return the string name of the recommended torch dtype for the given device."""
    torch_dtype = TorchDevice.choose_torch_dtype(device)
    return PRECISION_TO_NAME[torch_dtype]


@deprecated("Use TorchDevice.choose_torch_device() instead.")  # type: ignore
def choose_torch_device() -> torch.device:
    """Return the torch.device to use for accelerated inference."""
    return TorchDevice.choose_torch_device()


@deprecated("Use TorchDevice.choose_torch_dtype() instead.")  # type: ignore
def torch_dtype(device: torch.device) -> torch.dtype:
    """Return the recommended torch dtype for the given device."""
    return TorchDevice.choose_torch_dtype(device)
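
# Migration sketch (illustrative, per the deprecation messages above):
#   choose_torch_device()    ->  TorchDevice.choose_torch_device()
#   torch_dtype(device)      ->  TorchDevice.choose_torch_dtype(device)
#   choose_precision(device) ->  PRECISION_TO_NAME[TorchDevice.choose_torch_dtype(device)]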

NAME_TO_PRECISION: Dict[TorchPrecisionNames, torch.dtype] = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}
PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NAME_TO_PRECISION.items()}


class TorchDevice:
    """Abstraction layer for torch devices."""

    CPU_DEVICE = torch.device("cpu")
    CUDA_DEVICE = torch.device("cuda")
    MPS_DEVICE = torch.device("mps")

    @classmethod
    def choose_torch_device(cls) -> torch.device:
        """Return the torch.device to use for accelerated inference."""
        app_config = get_config()
        if app_config.device != "auto":
            device = torch.device(app_config.device)
        elif torch.cuda.is_available():
            device = CUDA_DEVICE
        elif torch.backends.mps.is_available():
            device = MPS_DEVICE
        else:
            device = CPU_DEVICE
        return cls.normalize(device)
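
    # Illustrative behavior (assuming `device: auto` in the app config):
    #   TorchDevice.choose_torch_device() prefers CUDA, then MPS, then CPU,
    #   and normalize() turns a bare "cuda" into an indexed device such as
    #   torch.device("cuda", 0).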

    @classmethod
    def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
        """Return the precision to use for accelerated inference."""
        device = device or cls.choose_torch_device()
        config = get_config()
        if device.type == "cuda" and torch.cuda.is_available():
            device_name = torch.cuda.get_device_name(device)
            if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
                # These GPUs have limited support for float16
                return cls._to_dtype("float32")
            elif config.precision == "auto":
                # Default to float16 for CUDA devices
                return cls._to_dtype("float16")
            else:
                # Use the user-defined precision
                return cls._to_dtype(config.precision)

        elif device.type == "mps" and torch.backends.mps.is_available():
            if config.precision == "auto":
                # Default to float16 for MPS devices
                return cls._to_dtype("float16")
            else:
                # Use the user-defined precision
                return cls._to_dtype(config.precision)

        # CPU / safe fallback
        return cls._to_dtype("float32")
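
    # Illustrative outcomes (assuming the default `precision: auto`):
    #   CUDA GPU            -> torch.float16 (GTX 1650/1660 pinned to float32)
    #   MPS (Apple Silicon) -> torch.float16
    #   CPU                 -> torch.float32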

    @classmethod
    def get_torch_device_name(cls) -> str:
        """Return the device name for the current torch device."""
        device = cls.choose_torch_device()
        return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()

    @classmethod
    def normalize(cls, device: Union[str, torch.device]) -> torch.device:
        """Add the device index to CUDA devices."""
        device = torch.device(device)
        if device.index is None and device.type == "cuda" and torch.cuda.is_available():
            device = torch.device(device.type, torch.cuda.current_device())
        return device
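
    # Example (illustrative): on a single-GPU machine,
    #   TorchDevice.normalize("cuda") == torch.device("cuda", 0)
    # while non-CUDA devices pass through unchanged.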

    @classmethod
    def empty_cache(cls) -> None:
        """Clear the GPU device cache."""
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @classmethod
    def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
        """Map a precision name to its torch.dtype."""
        return NAME_TO_PRECISION[precision_name]

    @staticmethod
    def get_non_blocking(to_device: torch.device) -> bool:
        """Return the non_blocking flag to be used when moving a tensor to a given device.

        MPS may have unexpected errors with non-blocking operations; we should not use non-blocking when moving _to_ MPS.
        When moving _from_ MPS, we can use non-blocking operations.

        See:
        - https://github.com/pytorch/pytorch/issues/107455
        - https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28
        """
        return to_device.type != "mps"
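
# Example (illustrative):
#   TorchDevice.get_non_blocking(torch.device("mps"))   -> False  (block on moves to MPS)
#   TorchDevice.get_non_blocking(torch.device("cuda"))  -> True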