mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
38343917f8
In #6490 we enabled non-blocking torch device transfers throughout the model manager's memory management code. When this feature is used, torch is supposed to wait until a tensor transfer has completed before allowing any access to the tensor, which in theory should make it safe to use. It provides a small performance improvement, but in practice it causes race conditions in some situations: specific platforms/systems are affected, and complicated data dependencies can make it unsafe.

- Intermittent black images on MPS devices - reported on discord and in #6545, fixed with special handling in #6549.
- Intermittent OOMs and black images on a P4000 GPU on Windows - reported in #6613, fixed in this commit.

On my system, I haven't experienced any issues with generation, but targeted testing of non-blocking ops did expose a race condition when moving tensors from CUDA to CPU.

One workaround is to use torch streams with manual sync points. Our application logic is complicated enough that this would be a lot of work and feels ripe for edge cases and missed spots. Much safer is to fully revert non-blocking transfers - which is what this change does.
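To illustrate the failure mode and the stream workaround mentioned above, here is a minimal sketch (not code from this commit; the tensor names and shapes are invented):

import torch

gpu_tensor = torch.randn(1024, 1024, device="cuda")

# Non-blocking transfer: the call can return before the copy has finished.
# Reading cpu_tensor before the copy completes can observe stale or garbage
# data - the intermittent black images and bad outputs described above.
cpu_tensor = gpu_tensor.to("cpu", non_blocking=True)

# Stream workaround: an explicit sync point before the data is touched.
torch.cuda.current_stream().synchronize()
value = cpu_tensor.sum()  # safe only after the synchronize() above

# The reverted-to pattern: a blocking transfer that is safe by construction.
cpu_tensor = gpu_tensor.to("cpu")  # returns once the copy is complete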
115 lines
4.3 KiB
Python
from typing import Dict, Literal, Optional, Union

import torch
from deprecated import deprecated

from invokeai.app.services.config.config_default import get_config

# legacy APIs
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
CPU_DEVICE = torch.device("cpu")
CUDA_DEVICE = torch.device("cuda")
MPS_DEVICE = torch.device("mps")


@deprecated("Use TorchDevice.choose_torch_dtype() instead.")  # type: ignore
def choose_precision(device: torch.device) -> TorchPrecisionNames:
    """Return the string representation of the recommended torch dtype."""
    torch_dtype = TorchDevice.choose_torch_dtype(device)
    return PRECISION_TO_NAME[torch_dtype]


@deprecated("Use TorchDevice.choose_torch_device() instead.")  # type: ignore
def choose_torch_device() -> torch.device:
    """Return the torch.device to use for accelerated inference."""
    return TorchDevice.choose_torch_device()


@deprecated("Use TorchDevice.choose_torch_dtype() instead.")  # type: ignore
def torch_dtype(device: torch.device) -> torch.dtype:
    """Return the torch precision for the recommended torch device."""
    return TorchDevice.choose_torch_dtype(device)


NAME_TO_PRECISION: Dict[TorchPrecisionNames, torch.dtype] = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}

PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NAME_TO_PRECISION.items()}


class TorchDevice:
    """Abstraction layer for torch devices."""

    CPU_DEVICE = torch.device("cpu")
    CUDA_DEVICE = torch.device("cuda")
    MPS_DEVICE = torch.device("mps")

    @classmethod
    def choose_torch_device(cls) -> torch.device:
        """Return the torch.device to use for accelerated inference."""
        app_config = get_config()
        if app_config.device != "auto":
            device = torch.device(app_config.device)
        elif torch.cuda.is_available():
            device = CUDA_DEVICE
        elif torch.backends.mps.is_available():
            device = MPS_DEVICE
        else:
            device = CPU_DEVICE
        return cls.normalize(device)

    @classmethod
    def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
        """Return the precision to use for accelerated inference."""
        device = device or cls.choose_torch_device()
        config = get_config()
        if device.type == "cuda" and torch.cuda.is_available():
            device_name = torch.cuda.get_device_name(device)
            if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
                # These GPUs have limited support for float16
                return cls._to_dtype("float32")
            elif config.precision == "auto":
                # Default to float16 for CUDA devices
                return cls._to_dtype("float16")
            else:
                # Use the user-defined precision
                return cls._to_dtype(config.precision)

        elif device.type == "mps" and torch.backends.mps.is_available():
            if config.precision == "auto":
                # Default to float16 for MPS devices
                return cls._to_dtype("float16")
            else:
                # Use the user-defined precision
                return cls._to_dtype(config.precision)

        # CPU / safe fallback
        return cls._to_dtype("float32")

    @classmethod
    def get_torch_device_name(cls) -> str:
        """Return the device name for the current torch device."""
        device = cls.choose_torch_device()
        return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()

    @classmethod
    def normalize(cls, device: Union[str, torch.device]) -> torch.device:
        """Add the device index to CUDA devices."""
        device = torch.device(device)
        if device.index is None and device.type == "cuda" and torch.cuda.is_available():
            device = torch.device(device.type, torch.cuda.current_device())
        return device

    @classmethod
    def empty_cache(cls) -> None:
        """Clear the GPU device cache."""
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @classmethod
    def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
        return NAME_TO_PRECISION[precision_name]
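For reference, a sketch of how calling code might use this module (a hypothetical snippet, assuming the module's import path is invokeai.backend.util.devices; the Linear layer is a stand-in for any torch.nn.Module):

import torch

from invokeai.backend.util.devices import TorchDevice

# Pick the best available device and a compatible dtype.
device = TorchDevice.choose_torch_device()
dtype = TorchDevice.choose_torch_dtype(device)

# Move a small stand-in model to the chosen device and precision.
model = torch.nn.Linear(8, 8).to(device=device, dtype=dtype)

# Release cached accelerator memory once the model is no longer needed.
del model
TorchDevice.empty_cache()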