simplify logic for retrieving execution devices

This commit is contained in:
Lincoln Stein 2024-04-16 15:52:03 -04:00
parent fb9b7fb63a
commit 371f5bc782
3 changed files with 18 additions and 20 deletions

View File

@ -85,7 +85,6 @@ class ModelManagerService(ModelManagerServiceBase):
ram_cache = ModelCache(
max_cache_size=app_config.ram,
logger=logger,
execution_devices=execution_devices,
)
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
loader = ModelLoadService(

View File

@ -65,7 +65,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
Initialize the model RAM cache.
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
:param execution_devices: Set of torch device to load active model into [calculated]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
@ -88,7 +87,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
# device to thread id
self._device_lock = threading.Lock()
self._execution_devices: Dict[torch.device, int] = {
x: 0 for x in execution_devices or self._get_execution_devices()
x: 0 for x in TorchDevice.execution_devices()
}
self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
@ -403,17 +402,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
self._cache_stack.remove(cache_entry.key)
del self._cached_models[cache_entry.key]
@staticmethod
def _get_execution_devices(devices: Optional[Set[torch.device]] = None) -> Set[torch.device]:
if not devices:
if torch.cuda.is_available():
devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
elif torch.backends.mps.is_available():
devices = {torch.device("mps")}
else:
devices = {torch.device("cpu")}
return devices
@staticmethod
def _device_name(device: torch.device) -> str:
return f"{device.type}:{device.index}"

View File

@ -1,6 +1,6 @@
"""Torch Device class provides torch device selection services."""
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, Set
import torch
from deprecated import deprecated
@ -72,12 +72,12 @@ class TorchDevice:
return cls.normalize(device)
@classmethod
def execution_devices(cls) -> List[torch.device]:
def execution_devices(cls) -> Set[torch.device]:
"""Return the set of torch.devices that can be used for accelerated inference."""
if cls._model_cache:
return cls._model_cache.execution_devices
else:
return [cls.choose_torch_device]
app_config = get_config()
if app_config.devices is None:
return cls._lookup_execution_devices()
return {torch.device(x) for x in app_config.devices}
@classmethod
def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
@ -131,3 +131,14 @@ class TorchDevice:
@classmethod
def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
return NAME_TO_PRECISION[precision_name]
@classmethod
def _lookup_execution_devices(cls) -> Set[torch.device]:
if torch.cuda.is_available():
devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
elif torch.backends.mps.is_available():
devices = {torch.device("mps")}
else:
devices = {torch.device("cpu")}
return devices