diff --git a/invokeai/app/services/invocation_processor/invocation_processor_default.py b/invokeai/app/services/invocation_processor/invocation_processor_default.py
index d2ebe235e6..26030cad98 100644
--- a/invokeai/app/services/invocation_processor/invocation_processor_default.py
+++ b/invokeai/app/services/invocation_processor/invocation_processor_default.py
@@ -24,8 +24,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
     __threadLimit: BoundedSemaphore
 
     def start(self, invoker: Invoker) -> None:
-        # if we do want multithreading at some point, we could make this configurable
-        self.__threadLimit = BoundedSemaphore(1)
+        # LS - this will probably break
+        # but the idea is to enable multithreading up to the number of available
+        # GPUs. Nodes will block on model loading if no GPU is free.
+        self.__threadLimit = BoundedSemaphore(invoker.services.model_manager.gpu_count)
         self.__invoker = invoker
         self.__stop_event = Event()
         self.__invoker_thread = Thread(
diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py
index cc80333e93..cdd59f4e74 100644
--- a/invokeai/app/services/model_load/model_load_base.py
+++ b/invokeai/app/services/model_load/model_load_base.py
@@ -38,3 +38,8 @@ class ModelLoadServiceBase(ABC):
     @abstractmethod
     def convert_cache(self) -> ModelConvertCacheBase:
         """Return the checkpoint convert cache used by this loader."""
+
+    @property
+    @abstractmethod
+    def gpu_count(self) -> int:
+        """Return the number of GPUs we are configured to use."""
diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py
index 15c6283d8a..c6d829db5f 100644
--- a/invokeai/app/services/model_load/model_load_default.py
+++ b/invokeai/app/services/model_load/model_load_default.py
@@ -40,6 +40,7 @@ class ModelLoadService(ModelLoadServiceBase):
         self._registry = registry
 
     def start(self, invoker: Invoker) -> None:
+        """Start the service."""
         self._invoker = invoker
 
     @property
@@ -47,6 +48,11 @@ class ModelLoadService(ModelLoadServiceBase):
         """Return the RAM cache used by this loader."""
         return self._ram_cache
 
+    @property
+    def gpu_count(self) -> int:
+        """Return the number of GPUs available for our uses."""
+        return len(self._ram_cache.execution_devices)
+
     @property
     def convert_cache(self) -> ModelConvertCacheBase:
         """Return the checkpoint convert cache used by this loader."""
diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py
index c25aa6fb47..ee17739882 100644
--- a/invokeai/app/services/model_manager/model_manager_base.py
+++ b/invokeai/app/services/model_manager/model_manager_base.py
@@ -98,3 +98,8 @@ class ModelManagerServiceBase(ABC):
         context_data: Optional[InvocationContextData] = None,
     ) -> LoadedModel:
         pass
+
+    @property
+    @abstractmethod
+    def gpu_count(self) -> int:
+        """Return the number of GPUs we are configured to use."""
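The processor above now sizes its thread limit from model_manager.gpu_count, so at most one node runs per execution device at a time. A minimal standalone sketch of that pattern using only the standard library follows; the WorkerPool class and submit method are hypothetical names for illustration, not part of this diff.

from threading import BoundedSemaphore, Thread
from typing import Any, Callable


class WorkerPool:
    """Illustrative only: cap concurrent work at the number of available GPUs."""

    def __init__(self, gpu_count: int) -> None:
        self._limit = BoundedSemaphore(gpu_count)

    def submit(self, fn: Callable[..., Any], *args: Any) -> Thread:
        def _worker() -> None:
            with self._limit:  # blocks here while all slots are taken
                fn(*args)

        thread = Thread(target=_worker)
        thread.start()
        return thread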
diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py
index d029f9e033..20e8896365 100644
--- a/invokeai/app/services/model_manager/model_manager_default.py
+++ b/invokeai/app/services/model_manager/model_manager_default.py
@@ -112,6 +112,11 @@ class ModelManagerService(ModelManagerServiceBase):
         else:
             return self.load.load_model(configs[0], submodel, context_data)
 
+    @property
+    def gpu_count(self) -> int:
+        """Return the number of GPUs we are using."""
+        return self.load.gpu_count
+
     @classmethod
     def build_model_manager(
         cls,
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
index 4a4a3c7d29..cdb29c181c 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
@@ -10,7 +10,7 @@ model will be cleared and (re)loaded from disk when next needed.
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from logging import Logger
-from typing import Dict, Generic, Optional, TypeVar
+from typing import Dict, Generic, Optional, Set, TypeVar
 
 import torch
 
@@ -89,8 +89,24 @@ class ModelCacheBase(ABC, Generic[T]):
 
     @property
     @abstractmethod
-    def execution_device(self) -> torch.device:
-        """Return the exection device (e.g. "cuda" for VRAM)."""
+    def execution_devices(self) -> Set[torch.device]:
+        """Return the set of available execution devices."""
+        pass
+
+    @abstractmethod
+    def acquire_execution_device(self, timeout: int = 0) -> torch.device:
+        """
+        Pick the next available execution device.
+
+        If all devices are currently engaged (locked), then
+        block until timeout seconds have passed and raise a
+        TimeoutError if no devices are available.
+        """
+        pass
+
+    @abstractmethod
+    def release_execution_device(self, device: torch.device) -> None:
+        """Release a previously-acquired execution device."""
         pass
 
     @property
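A hedged sketch of how a caller would be expected to use the acquire/release pair declared above; run_on_free_gpu is a hypothetical helper, not part of this diff, and cache stands for any ModelCacheBase implementation.

import torch


def run_on_free_gpu(cache, model: torch.nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    """Illustrative only: acquire a device, run the model on it, always release it."""
    device = cache.acquire_execution_device(timeout=600)  # may raise TimeoutError
    try:
        model.to(device)
        with torch.no_grad():
            return model(inputs.to(device))
    finally:
        cache.release_execution_device(device)

Releasing in a finally block matters: a device that is never released would permanently shrink the pool for every later node.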
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index 02ce1266c7..b179a190f7 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -25,7 +25,8 @@ import sys
 import time
 from contextlib import suppress
 from logging import Logger
-from typing import Dict, List, Optional
+from threading import BoundedSemaphore, Lock
+from typing import Dict, List, Optional, Set
 
 import torch
 
@@ -61,8 +62,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self,
         max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
         max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
-        execution_device: torch.device = torch.device("cuda"),
         storage_device: torch.device = torch.device("cpu"),
+        execution_devices: Optional[Set[torch.device]] = None,
         precision: torch.dtype = torch.float16,
         sequential_offload: bool = False,
         lazy_offloading: bool = True,
@@ -74,7 +75,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         Initialize the model RAM cache.
 
         :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
-        :param execution_device: Torch device to load active model into [torch.device('cuda')]
+        :param execution_devices: Set of torch devices to load active model into [calculated]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
         :param precision: Precision for loaded models [torch.float16]
         :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
@@ -89,7 +90,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._precision: torch.dtype = precision
         self._max_cache_size: float = max_cache_size
         self._max_vram_cache_size: float = max_vram_cache_size
-        self._execution_device: torch.device = execution_device
+        self._execution_devices: Set[torch.device] = execution_devices or self._get_execution_devices()
         self._storage_device: torch.device = storage_device
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
         self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG
@@ -99,6 +100,10 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
         self._cache_stack: List[str] = []
 
+        self._lock = Lock()
+        self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
+        self._busy_execution_devices: Set[torch.device] = set()
+
     @property
     def logger(self) -> Logger:
         """Return the logger used by the cache."""
@@ -115,9 +120,24 @@ class ModelCache(ModelCacheBase[AnyModel]):
         return self._storage_device
 
     @property
-    def execution_device(self) -> torch.device:
-        """Return the exection device (e.g. "cuda" for VRAM)."""
-        return self._execution_device
+    def execution_devices(self) -> Set[torch.device]:
+        """Return the set of available execution devices."""
+        return self._execution_devices
+
+    def acquire_execution_device(self, timeout: int = 0) -> torch.device:
+        """Acquire and return an execution device (e.g. "cuda" for VRAM)."""
+        with self._lock:
+            self._free_execution_device.acquire(timeout=timeout)
+            free_devices = self.execution_devices - self._busy_execution_devices
+            chosen_device = list(free_devices)[0]
+            self._busy_execution_devices.add(chosen_device)
+        return chosen_device
+
+    def release_execution_device(self, device: torch.device) -> None:
+        """Mark this execution device as unused."""
+        with self._lock:
+            self._free_execution_device.release()
+            self._busy_execution_devices.remove(device)
 
     @property
     def max_cache_size(self) -> float:
@@ -405,3 +425,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
             mps.empty_cache()
 
         self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
+
+    @staticmethod
+    def _get_execution_devices() -> Set[torch.device]:
+        default_device = choose_torch_device()
+        if default_device != torch.device("cuda"):
+            return {default_device}
+
+        # we get here if the default device is cuda, and return each of the
+        # cuda devices.
+        return {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
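The implementation above tracks devices with a Lock, a BoundedSemaphore sized to the device count, and a busy set. Below is a self-contained sketch of that bookkeeping under the assumption of a fixed device set known at construction time; DevicePool is a hypothetical name, not InvokeAI code, and this sketch takes the semaphore before the lock so a waiting acquire never holds the lock that release needs.

from threading import BoundedSemaphore, Lock
from typing import Set

import torch


class DevicePool:
    """Illustrative device pool: hand out each device to at most one caller at a time."""

    def __init__(self, devices: Set[torch.device]) -> None:
        self._devices = set(devices)
        self._lock = Lock()
        self._free = BoundedSemaphore(len(self._devices))
        self._busy: Set[torch.device] = set()

    def acquire(self, timeout: float = 0) -> torch.device:
        # Wait for a free slot before touching the bookkeeping lock.
        if not self._free.acquire(timeout=timeout):
            raise TimeoutError("no free execution device")
        with self._lock:
            device = next(iter(self._devices - self._busy))
            self._busy.add(device)
            return device

    def release(self, device: torch.device) -> None:
        with self._lock:
            self._busy.remove(device)
        self._free.release()

With two devices, two concurrent acquires succeed immediately and a third raises TimeoutError once its timeout expires.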
""" +from typing import Optional + +import torch + from invokeai.backend.model_manager import AnyModel from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase +MAX_GPU_WAIT = 600 # wait up to 10 minutes for a GPU to become free + class ModelLocker(ModelLockerBase): """Internal class that mediates movement in and out of GPU.""" @@ -19,6 +25,7 @@ class ModelLocker(ModelLockerBase): """ self._cache = cache self._cache_entry = cache_entry + self._execution_device: Optional[torch.device] = None @property def model(self) -> AnyModel: @@ -37,10 +44,12 @@ class ModelLocker(ModelLockerBase): if self._cache.lazy_offloading: self._cache.offload_unlocked_models(self._cache_entry.size) - self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device) + # We wait for a gpu to be free - may raise a TimeoutError + self._execution_device = self._cache.acquire_execution_device(MAX_GPU_WAIT) + self._cache.move_model_to_device(self._cache_entry, self._execution_device) self._cache_entry.loaded = True - self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}") + self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._execution_device}") self._cache.print_cuda_stats() except Exception: @@ -54,6 +63,8 @@ class ModelLocker(ModelLockerBase): return self._cache_entry.unlock() + if self._execution_device: + self._cache.release_execution_device(self._execution_device) if not self._cache.lazy_offloading: self._cache.offload_unlocked_models(self._cache_entry.size) self._cache.print_cuda_stats() diff --git a/tests/aa_nodes/test_invoker.py b/tests/aa_nodes/test_invoker.py index f67b5a2ac5..34340c5f78 100644 --- a/tests/aa_nodes/test_invoker.py +++ b/tests/aa_nodes/test_invoker.py @@ -65,7 +65,7 @@ def mock_services() -> InvocationServices: images=None, # type: ignore invocation_cache=MemoryInvocationCache(max_cache_size=0), logger=logging, # type: ignore - model_manager=Mock(), # type: ignore + model_manager=Mock(gpu_count=1), # type: ignore download_queue=None, # type: ignore names=None, # type: ignore performance_statistics=InvocationStatsService(), diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index df54e2f926..12252fb315 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -112,7 +112,7 @@ def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadata @pytest.fixture -def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase: +def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase: ram_cache = ModelCache( logger=InvokeAILogger.get_logger(), max_cache_size=mm2_app_config.ram_cache_size,