Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
simplify logic for retrieving execution devices
This commit is contained in:
parent fb9b7fb63a
commit 371f5bc782
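In brief: ModelCache no longer accepts an execution_devices argument (first two hunks). Its constructor now asks TorchDevice.execution_devices() for the device set (third hunk), the private ModelCache._get_execution_devices() helper is deleted (fourth hunk), and an equivalent TorchDevice._lookup_execution_devices() is added as the autodetection fallback, used only when the app config does not pin a devices list (remaining hunks).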
@@ -85,7 +85,6 @@ class ModelManagerService(ModelManagerServiceBase):
         ram_cache = ModelCache(
             max_cache_size=app_config.ram,
             logger=logger,
-            execution_devices=execution_devices,
         )
         convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
         loader = ModelLoadService(
@@ -65,7 +65,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         Initialize the model RAM cache.
 
         :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
-        :param execution_devices: Set of torch device to load active model into [calculated]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
         :param precision: Precision for loaded models [torch.float16]
         :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
@@ -88,7 +87,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         # device to thread id
         self._device_lock = threading.Lock()
         self._execution_devices: Dict[torch.device, int] = {
-            x: 0 for x in execution_devices or self._get_execution_devices()
+            x: 0 for x in TorchDevice.execution_devices()
         }
         self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
 
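For readers unfamiliar with the pattern the hunk above keeps in place, here is a minimal, self-contained sketch of a semaphore-guarded device pool: a device-to-thread-id map behind a lock, plus a BoundedSemaphore sized to the number of devices so that acquirers block once every device is claimed. The DevicePool class and its acquire/release methods are illustrative stand-ins, not code from this commit.

import threading
from typing import Dict, Set

import torch

class DevicePool:
    """Illustrative sketch only: hand out one execution device per thread."""

    def __init__(self, devices: Set[torch.device]) -> None:
        self._device_lock = threading.Lock()
        # device -> id of the thread currently using it (0 means free)
        self._execution_devices: Dict[torch.device, int] = {d: 0 for d in devices}
        # Counts free devices; acquire() blocks when all are claimed.
        self._free_execution_device = threading.BoundedSemaphore(len(self._execution_devices))

    def acquire(self) -> torch.device:
        self._free_execution_device.acquire()  # wait until some device is free
        with self._device_lock:
            for device, owner in self._execution_devices.items():
                if owner == 0:
                    self._execution_devices[device] = threading.get_ident()
                    return device
        raise RuntimeError("semaphore and device map are out of sync")

    def release(self, device: torch.device) -> None:
        with self._device_lock:
            self._execution_devices[device] = 0
        self._free_execution_device.release()

A worker would call pool.acquire(), run inference on the returned device, and call pool.release(device) in a finally block.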
@@ -403,17 +402,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._cache_stack.remove(cache_entry.key)
         del self._cached_models[cache_entry.key]
 
-    @staticmethod
-    def _get_execution_devices(devices: Optional[Set[torch.device]] = None) -> Set[torch.device]:
-        if not devices:
-            if torch.cuda.is_available():
-                devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
-            elif torch.backends.mps.is_available():
-                devices = {torch.device("mps")}
-            else:
-                devices = {torch.device("cpu")}
-        return devices
-
     @staticmethod
     def _device_name(device: torch.device) -> str:
         return f"{device.type}:{device.index}"
@@ -1,6 +1,6 @@
 """Torch Device class provides torch device selection services."""
 
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, Set
 
 import torch
 from deprecated import deprecated
@@ -72,12 +72,12 @@ class TorchDevice:
         return cls.normalize(device)
 
     @classmethod
-    def execution_devices(cls) -> List[torch.device]:
+    def execution_devices(cls) -> Set[torch.device]:
         """Return a list of torch.devices that can be used for accelerated inference."""
-        if cls._model_cache:
-            return cls._model_cache.execution_devices
-        else:
-            return [cls.choose_torch_device]
+        app_config = get_config()
+        if app_config.devices is None:
+            return cls._lookup_execution_devices()
+        return {torch.device(x) for x in app_config.devices}
 
     @classmethod
     def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
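The behavioral change in this hunk: instead of deferring to a live model cache, execution_devices() now consults the app config first and only autodetects when no devices list is configured. A hedged sketch of that precedence, assuming config entries are strings that torch.device() accepts (which the set comprehension above implies); pick_devices is a hypothetical stand-in, not the real method:

from typing import Optional, Set

import torch

def pick_devices(configured: Optional[Set[str]]) -> Set[torch.device]:
    """Hypothetical stand-in: `configured` plays the role of get_config().devices."""
    if configured is None:
        # No override in the config: probe the machine, mirroring
        # _lookup_execution_devices() in the next hunk.
        if torch.cuda.is_available():
            return {torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())}
        if torch.backends.mps.is_available():
            return {torch.device("mps")}
        return {torch.device("cpu")}
    # Config override wins: take the user's list verbatim.
    return {torch.device(d) for d in configured}

print(pick_devices(None))                  # autodetected, e.g. {cuda:0, cuda:1}
print(pick_devices({"cuda:0", "cuda:1"}))  # pinned by the config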
@@ -131,3 +131,14 @@ class TorchDevice:
     @classmethod
     def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
         return NAME_TO_PRECISION[precision_name]
+
+    @classmethod
+    def _lookup_execution_devices(cls) -> Set[torch.device]:
+        if torch.cuda.is_available():
+            devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
+        elif torch.backends.mps.is_available():
+            devices = {torch.device("mps")}
+        else:
+            devices = {torch.device("cpu")}
+        return devices
+
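Returning a Set rather than a List (the signature change above) makes the result duplicate-free and order-independent, which fits its two uses in this commit: as the keys of ModelCache._execution_devices and as the count behind the BoundedSemaphore. Multi-GPU hosts get one entry per CUDA ordinal; MPS and CPU hosts get a single device.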