From 371f5bc782a364de4eea3389a7bed2d387934186 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Tue, 16 Apr 2024 15:52:03 -0400
Subject: [PATCH] simplify logic for retrieving execution devices

---
 .../model_manager/model_manager_default.py   |  1 -
 .../load/model_cache/model_cache_default.py  | 14 +----------
 invokeai/backend/util/devices.py             | 23 ++++++++++++++-----
 3 files changed, 18 insertions(+), 20 deletions(-)

diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py
index 241259c803..902501c1f9 100644
--- a/invokeai/app/services/model_manager/model_manager_default.py
+++ b/invokeai/app/services/model_manager/model_manager_default.py
@@ -85,7 +85,6 @@ class ModelManagerService(ModelManagerServiceBase):
         ram_cache = ModelCache(
             max_cache_size=app_config.ram,
             logger=logger,
-            execution_devices=execution_devices,
         )
         convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
         loader = ModelLoadService(
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index 026bb8aec5..9eab509e64 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -65,7 +65,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         Initialize the model RAM cache.
 
         :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
-        :param execution_devices: Set of torch device to load active model into [calculated]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
         :param precision: Precision for loaded models [torch.float16]
         :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
@@ -88,7 +87,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         # device to thread id
         self._device_lock = threading.Lock()
         self._execution_devices: Dict[torch.device, int] = {
-            x: 0 for x in execution_devices or self._get_execution_devices()
+            x: 0 for x in TorchDevice.execution_devices()
         }
         self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
 
@@ -403,17 +402,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
             self._cache_stack.remove(cache_entry.key)
             del self._cached_models[cache_entry.key]
 
-    @staticmethod
-    def _get_execution_devices(devices: Optional[Set[torch.device]] = None) -> Set[torch.device]:
-        if not devices:
-            if torch.cuda.is_available():
-                devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
-            elif torch.backends.mps.is_available():
-                devices = {torch.device("mps")}
-            else:
-                devices = {torch.device("cpu")}
-        return devices
-
     @staticmethod
     def _device_name(device: torch.device) -> str:
         return f"{device.type}:{device.index}"
diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py
index b88206c5f7..d1c432a53f 100644
--- a/invokeai/backend/util/devices.py
+++ b/invokeai/backend/util/devices.py
@@ -1,6 +1,6 @@
 """Torch Device class provides torch device selection services."""
 
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, Set
 
 import torch
 from deprecated import deprecated
@@ -72,12 +72,12 @@ class TorchDevice:
         return cls.normalize(device)
 
     @classmethod
-    def execution_devices(cls) -> List[torch.device]:
+    def execution_devices(cls) -> Set[torch.device]:
         """Return a list of torch.devices that can be used for accelerated inference."""
-        if cls._model_cache:
-            return cls._model_cache.execution_devices
-        else:
-            return [cls.choose_torch_device]
+        app_config = get_config()
+        if app_config.devices is None:
+            return cls._lookup_execution_devices()
+        return {torch.device(x) for x in app_config.devices}
 
     @classmethod
     def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
@@ -131,3 +131,14 @@ class TorchDevice:
     @classmethod
     def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
         return NAME_TO_PRECISION[precision_name]
+
+    @classmethod
+    def _lookup_execution_devices(cls) -> Set[torch.device]:
+        if torch.cuda.is_available():
+            devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
+        elif torch.backends.mps.is_available():
+            devices = {torch.device("mps")}
+        else:
+            devices = {torch.device("cpu")}
+        return devices
+
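
A note on the change (not part of the commit): device selection now lives in a
single place, TorchDevice.execution_devices(), which prefers an explicit
`devices` list from the app config and falls back to hardware detection only
when none is configured. The following is a minimal standalone sketch of that
resolution order, assuming only torch is installed; AppConfig and
resolve_execution_devices are hypothetical stand-ins for illustration, not
InvokeAI APIs.

# Standalone sketch of the resolution order this patch introduces.
# AppConfig and resolve_execution_devices are hypothetical stand-ins;
# the real code paths are TorchDevice.execution_devices() and
# TorchDevice._lookup_execution_devices(). Requires only torch.
from dataclasses import dataclass
from typing import List, Optional, Set

import torch


@dataclass
class AppConfig:
    # Plays the role of app_config.devices: None means "auto-detect".
    devices: Optional[List[str]] = None


def resolve_execution_devices(config: AppConfig) -> Set[torch.device]:
    # An explicit config entry wins, mirroring execution_devices().
    if config.devices is not None:
        return {torch.device(d) for d in config.devices}
    # Otherwise detect hardware the way _lookup_execution_devices() does:
    # every visible CUDA device, else MPS, else CPU.
    if torch.cuda.is_available():
        return {torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())}
    if torch.backends.mps.is_available():
        return {torch.device("mps")}
    return {torch.device("cpu")}


if __name__ == "__main__":
    print(resolve_execution_devices(AppConfig()))                    # auto-detected
    print(resolve_execution_devices(AppConfig(devices=["cuda:0"])))  # pinned by config

Because ModelCache keys its dict by torch.device and sizes its
BoundedSemaphore from len(self._execution_devices), duplicate entries in a
configured device list collapse to a single execution slot.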