simplify logic for retrieving execution devices

Lincoln Stein 2024-04-16 15:52:03 -04:00
parent fb9b7fb63a
commit 371f5bc782
3 changed files with 18 additions and 20 deletions

View File

@@ -85,7 +85,6 @@ class ModelManagerService(ModelManagerServiceBase):
         ram_cache = ModelCache(
             max_cache_size=app_config.ram,
             logger=logger,
-            execution_devices=execution_devices,
         )
         convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
         loader = ModelLoadService(

View File

@@ -65,7 +65,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         Initialize the model RAM cache.

         :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
-        :param execution_devices: Set of torch device to load active model into [calculated]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
         :param precision: Precision for loaded models [torch.float16]
         :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
@@ -88,7 +87,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         # device to thread id
         self._device_lock = threading.Lock()
         self._execution_devices: Dict[torch.device, int] = {
-            x: 0 for x in execution_devices or self._get_execution_devices()
+            x: 0 for x in TorchDevice.execution_devices()
         }
         self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
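
Note: the cache rations devices across threads with a counting semaphore sized to the device pool. A minimal standalone sketch of that pattern (the DevicePool class and its names are hypothetical, not part of this diff):

import threading
from contextlib import contextmanager

import torch

class DevicePool:
    """Hypothetical sketch: hand out devices from a fixed pool, one per thread."""

    def __init__(self, devices: set) -> None:
        self._devices = {d: 0 for d in devices}  # device -> owning thread id (0 = free)
        self._lock = threading.Lock()
        self._free = threading.BoundedSemaphore(len(self._devices))

    @contextmanager
    def acquire(self):
        self._free.acquire()  # block until at least one device is free
        with self._lock:
            device = next(d for d, tid in self._devices.items() if tid == 0)
            self._devices[device] = threading.get_ident()
        try:
            yield device
        finally:
            with self._lock:
                self._devices[device] = 0
            self._free.release()

Because the semaphore is acquired before the pool is scanned, the search under the lock is guaranteed to find a free device.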
@@ -403,17 +402,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._cache_stack.remove(cache_entry.key)
         del self._cached_models[cache_entry.key]

-    @staticmethod
-    def _get_execution_devices(devices: Optional[Set[torch.device]] = None) -> Set[torch.device]:
-        if not devices:
-            if torch.cuda.is_available():
-                devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
-            elif torch.backends.mps.is_available():
-                devices = {torch.device("mps")}
-            else:
-                devices = {torch.device("cpu")}
-        return devices
-
     @staticmethod
     def _device_name(device: torch.device) -> str:
         return f"{device.type}:{device.index}"

View File

@@ -1,6 +1,6 @@
 """Torch Device class provides torch device selection services."""

-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, Set

 import torch
 from deprecated import deprecated
@@ -72,12 +72,12 @@ class TorchDevice:
         return cls.normalize(device)

     @classmethod
-    def execution_devices(cls) -> List[torch.device]:
-        """Return a list of torch.devices that can be used for accelerated inference."""
-        if cls._model_cache:
-            return cls._model_cache.execution_devices
-        else:
-            return [cls.choose_torch_device()]
+    def execution_devices(cls) -> Set[torch.device]:
+        """Return the set of torch.devices that can be used for accelerated inference."""
+        app_config = get_config()
+        if app_config.devices is None:
+            return cls._lookup_execution_devices()
+        return {torch.device(x) for x in app_config.devices}

     @classmethod
     def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
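
With this change, a user-specified devices field in the app config takes precedence over autodetection. An illustration of the new precedence (the devices config key comes from this diff; the YAML layout and import path are assumptions and may differ in your checkout):

# invokeai.yaml (assumed layout)
# devices:
#   - cuda:0
#   - cuda:1

from invokeai.backend.util.devices import TorchDevice  # import path assumed

devices = TorchDevice.execution_devices()
# -> {torch.device('cuda:0'), torch.device('cuda:1')} when configured,
#    otherwise whatever _lookup_execution_devices() autodetects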
@@ -131,3 +131,14 @@ class TorchDevice:
     @classmethod
     def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
         return NAME_TO_PRECISION[precision_name]
+
+    @classmethod
+    def _lookup_execution_devices(cls) -> Set[torch.device]:
+        if torch.cuda.is_available():
+            devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
+        elif torch.backends.mps.is_available():
+            devices = {torch.device("mps")}
+        else:
+            devices = {torch.device("cpu")}
+        return devices
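
The detection cascade added above can be exercised on its own; a self-contained sketch of the same logic (every visible CUDA device, else MPS on Apple silicon, else the CPU):

import torch

def lookup_execution_devices() -> set:
    if torch.cuda.is_available():
        return {torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())}
    if torch.backends.mps.is_available():
        return {torch.device("mps")}
    return {torch.device("cpu")}

print(lookup_execution_devices())  # e.g. {device(type='cpu')} on a CPU-only host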