diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index 258cd58e8d..8ab4105284 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -100,7 +100,8 @@ class InvokeAIAppConfig(BaseSettings):
ram: Maximum memory amount used by memory model cache for rapid switching (GB).
convert_cache: Maximum size of on-disk converted models cache (GB).
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
- device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
+ device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda:0`, `cuda:1`, `cuda:2`, `cuda:3`, `cuda:4`, `cuda:5`, `mps`
+ devices: List of execution devices; if set, overrides the default device selection.
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: Attention type.
Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
@@ -108,6 +109,7 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: Maximum number of items in the session queue.
+ max_threads: Maximum number of session queue execution threads.
allow_nodes: List of nodes to allow. Omit to allow all.
deny_nodes: List of nodes to deny. Omit to deny none.
node_cache_size: How many cached nodes to keep in memory.
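
For orientation, a minimal usage sketch of the two settings documented above. It assumes `devices` and `max_threads` are ordinary pydantic fields (a list of device strings and an int); the docstring implies this but the hunk does not show the field definitions.

    # Hedged sketch only; field names come from the docstring above, types are assumed.
    from invokeai.app.services.config.config_default import InvokeAIAppConfig

    config = InvokeAIAppConfig(
        devices=["cuda:0", "cuda:1"],  # overrides the single `device` setting
        max_threads=2,                 # roughly one queue worker per execution device
    )
    print(config.devices, config.max_threads)
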
diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py
index 3088d99c5d..fd198d0ff9 100644
--- a/invokeai/app/services/session_processor/session_processor_default.py
+++ b/invokeai/app/services/session_processor/session_processor_default.py
@@ -181,47 +181,51 @@ class DefaultSessionProcessor(SessionProcessorBase):
if profiler is not None:
profiler.start(profile_id=session.session_id)
- # Prepare invocations and take the first
- with self._process_lock:
- invocation = session.session.next()
+ # Reserve a GPU for this session; this may block until a device is free
+ with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device() as gpu:
+ self._invoker.services.logger.debug(f"Session {session.item_id} has reserved gpu {gpu}")
- # Loop over invocations until the session is complete or canceled
- while invocation is not None:
- if self._stop_event.is_set():
- break
- self._resume_event.wait()
+ # Prepare invocations and take the first
+ with self._process_lock:
+ invocation = session.session.next()
- self._process_next_invocation(session, invocation, stats_service)
+ # Loop over invocations until the session is complete or canceled
+ while invocation is not None:
+ if self._stop_event.is_set():
+ break
+ self._resume_event.wait()
- # The session is complete if all invocations are complete or there was an error
- if session.session.is_complete():
- # Send complete event
- self._invoker.services.events.emit_graph_execution_complete(
- queue_batch_id=session.batch_id,
- queue_item_id=session.item_id,
- queue_id=session.queue_id,
- graph_execution_state_id=session.session.id,
- )
- # Log stats
- # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
- # we don't care about that - suppress the error.
- with suppress(GESStatsNotFoundError):
- stats_service.log_stats(session.session.id)
- stats_service.reset_stats()
+ self._process_next_invocation(session, invocation, stats_service)
- # If we are profiling, stop the profiler and dump the profile & stats
- if self._profiler:
- profile_path = self._profiler.stop()
- stats_path = profile_path.with_suffix(".json")
- stats_service.dump_stats(
- graph_execution_state_id=session.session.id, output_path=stats_path
+ # The session is complete if all invocations are complete or there was an error
+ if session.session.is_complete():
+ # Send complete event
+ self._invoker.services.events.emit_graph_execution_complete(
+ queue_batch_id=session.batch_id,
+ queue_item_id=session.item_id,
+ queue_id=session.queue_id,
+ graph_execution_state_id=session.session.id,
)
- self._queue_items.remove(session.item_id)
- invocation = None
- else:
- # Prepare the next invocation
- with self._process_lock:
- invocation = session.session.next()
+ # Log stats
+ # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
+ # we don't care about that - suppress the error.
+ with suppress(GESStatsNotFoundError):
+ stats_service.log_stats(session.session.id)
+ stats_service.reset_stats()
+
+ # If we are profiling, stop the profiler and dump the profile & stats
+ if self._profiler:
+ profile_path = self._profiler.stop()
+ stats_path = profile_path.with_suffix(".json")
+ stats_service.dump_stats(
+ graph_execution_state_id=session.session.id, output_path=stats_path
+ )
+ self._queue_items.remove(session.item_id)
+ invocation = None
+ else:
+ # Prepare the next invocation
+ with self._process_lock:
+ invocation = session.session.next()
except Exception:
# Non-fatal error in processor
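
The reservation pattern the processor now follows can be reduced to a self-contained sketch. `FakeCache` below is a stand-in for the model manager's RAM cache, not an InvokeAI class, and it omits one detail of the real implementation (a thread that already holds a device re-uses it instead of waiting):

    import threading
    from contextlib import contextmanager

    class FakeCache:
        """Toy device pool: a semaphore counts free devices, a dict maps device -> thread id."""

        def __init__(self, devices):
            self._devices = {d: 0 for d in devices}  # 0 means "free", mirroring the diff
            self._lock = threading.Lock()
            self._free = threading.BoundedSemaphore(len(devices))

        @contextmanager
        def reserve_execution_device(self):
            self._free.acquire()  # block until some device is free
            with self._lock:
                device = next(d for d, tid in self._devices.items() if tid == 0)
                self._devices[device] = threading.get_ident()
            try:
                yield device
            finally:
                with self._lock:
                    self._devices[device] = 0
                self._free.release()

    def worker(cache, name):
        # analogous to one session-queue thread: reserve once, run the whole session inside
        with cache.reserve_execution_device() as gpu:
            print(f"{name} running on {gpu}")

    cache = FakeCache(["cuda:0", "cuda:1"])
    threads = [threading.Thread(target=worker, args=(cache, f"session-{i}")) for i in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
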
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
index 1d6a4f15db..45640aff42 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
@@ -8,9 +8,10 @@ model will be cleared and (re)loaded from disk when next needed.
"""
from abc import ABC, abstractmethod
+from contextlib import contextmanager
from dataclasses import dataclass, field
from logging import Logger
-from typing import Dict, Generic, Optional, Set, TypeVar
+from typing import Dict, Generator, Generic, Optional, Set, TypeVar
import torch
@@ -93,20 +94,23 @@ class ModelCacheBase(ABC, Generic[T]):
"""Return the set of available execution devices."""
pass
+ @contextmanager
@abstractmethod
- def acquire_execution_device(self, timeout: int = 0) -> torch.device:
- """
- Pick the next available execution device.
-
- If all devices are currently engaged (locked), then
- block until timeout seconds have passed and raise a
- TimeoutError if no devices are available.
- """
+ def reserve_execution_device(self, timeout: Optional[int] = None) -> Generator[torch.device, None, None]:
+ """Reserve an execution device (a GPU) for exclusive use by the calling thread, keyed by its thread id."""
pass
@abstractmethod
- def release_execution_device(self, device: torch.device) -> None:
- """Release a previously-acquired execution device."""
+ def get_execution_device(self) -> torch.device:
+ """
+ Return an execution device that has been reserved for the current thread.
+
+ Note that reservations are done using the current thread's TID.
+ It would be better to do this using the session ID, but that involves
+ too many detailed changes to model manager calls.
+
+ May raise a ValueError if no GPU has been reserved.
+ """
pass
@property
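
The interface above stacks `@contextmanager` on top of `@abstractmethod`, which keeps the method abstract and documents that implementations are expected to be context managers. A simplified illustration of that decorator stacking, unrelated to the InvokeAI classes:

    from abc import ABC, abstractmethod
    from contextlib import contextmanager
    from typing import Generator

    class DevicePool(ABC):
        @contextmanager
        @abstractmethod
        def reserve_execution_device(self) -> Generator[str, None, None]:
            """Yield a device reserved for the calling thread."""
            ...

    class SingleDevicePool(DevicePool):
        @contextmanager
        def reserve_execution_device(self) -> Generator[str, None, None]:
            yield "cuda:0"

    # DevicePool() raises TypeError; the concrete subclass works as a context manager.
    with SingleDevicePool().reserve_execution_device() as device:
        print(device)
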
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index 4478360dfe..04cac01092 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -21,13 +21,12 @@ context. Use like this:
import gc
import sys
import threading
-from contextlib import suppress
+from contextlib import contextmanager, suppress
from logging import Logger
-from threading import BoundedSemaphore, Lock
-from typing import Dict, List, Optional, Set
+from threading import BoundedSemaphore
+from typing import Dict, Generator, List, Optional, Set
import torch
-from pydantic import BaseModel
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
@@ -51,26 +50,6 @@ GIG = 1073741824
MB = 2**20
-# GPU device can only be used by one thread at a time.
-# The refcount indicates the number of models stored
-# in it.
-class GPUDeviceStatus(BaseModel):
- """Track of which threads are using the GPU(s) on this system."""
-
- device: torch.device
- thread_id: int = 0
- refcount: int = 0
-
- class Config:
- """Configure the base model."""
-
- arbitrary_types_allowed = True
-
- def __hash__(self) -> int:
- """Allow to be added to a set."""
- return hash(str(torch.device))
-
-
class ModelCache(ModelCacheBase[AnyModel]):
"""Implementation of ModelCacheBase."""
@@ -100,9 +79,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
"""
self._precision: torch.dtype = precision
self._max_cache_size: float = max_cache_size
- self._execution_devices: Set[GPUDeviceStatus] = self._get_execution_devices(execution_devices)
self._storage_device: torch.device = storage_device
- self._lock = threading.Lock()
+ self._ram_lock = threading.Lock()
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
self._log_memory_usage = log_memory_usage
self._stats: Optional[CacheStats] = None
@@ -110,11 +88,15 @@ class ModelCache(ModelCacheBase[AnyModel]):
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
self._cache_stack: List[str] = []
- self._lock = Lock()
+ # device to thread id
+ self._device_lock = threading.Lock()
+ self._execution_devices: Dict[torch.device, int] = {
+ x: 0 for x in execution_devices or self._get_execution_devices()
+ }
self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
self.logger.info(
- f"Using rendering device(s): {', '.join(sorted([str(x.device) for x in self._execution_devices]))}"
+ f"Using rendering device(s): {', '.join(sorted([str(x) for x in self._execution_devices.keys()]))}"
)
@property
@@ -130,34 +112,61 @@ class ModelCache(ModelCacheBase[AnyModel]):
@property
def execution_devices(self) -> Set[torch.device]:
"""Return the set of available execution devices."""
- return {x.device for x in self._execution_devices}
+ devices = self._execution_devices.keys()
+ return set(devices)
- def acquire_execution_device(self, timeout: int = 0) -> torch.device:
- """Acquire and return an execution device (e.g. "cuda" for VRAM)."""
+ def get_execution_device(self) -> torch.device:
+ """
+ Return an execution device that has been reserved for the current thread.
+
+ Note that reservations are done using the current thread's TID.
+ It would be better to do this using the session ID, but that involves
+ too many detailed changes to model manager calls.
+
+ May raise a ValueError if no GPU has been reserved.
+ """
current_thread = threading.current_thread().ident
assert current_thread is not None
+ assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
+ if not assigned:
+ raise ValueError("No GPU has been reserved for the use of thread {current_thread}")
+ return assigned[0]
- # first try to assign a device that is already executing on this thread
- if claimed_devices := [x for x in self._execution_devices if x.thread_id == current_thread]:
- claimed_devices[0].refcount += 1
- return claimed_devices[0].device
- else:
- # this thread is not currently using any gpu. Wait for a free one
+ @contextmanager
+ def reserve_execution_device(self, timeout: Optional[int] = None) -> Generator[torch.device, None, None]:
+ """Reserve an execution device (e.g. GPU) for exclusive use by a generation thread.
+
+ Note that the reservation is done using the current thread's TID.
+ It would be better to do this using the session ID, but that involves
+ too many detailed changes to model manager calls.
+ """
+ device = None
+ with self._device_lock:
+ current_thread = threading.current_thread().ident
+ assert current_thread is not None
+
+ # look for a device that has already been assigned to this thread
+ assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
+ if assigned:
+ device = assigned[0]
+
+ # no device already assigned. Get one.
+ if device is None:
self._free_execution_device.acquire(timeout=timeout)
- unclaimed_devices = [x for x in self._execution_devices if x.refcount == 0]
- unclaimed_devices[0].thread_id = current_thread
- unclaimed_devices[0].refcount += 1
- return unclaimed_devices[0].device
+ with self._device_lock:
+ free_device = [x for x, tid in self._execution_devices.items() if tid == 0]
+ print(f"DEBUG: execution devices = {self._execution_devices}")
+ self._execution_devices[free_device[0]] = current_thread
+ device = free_device[0]
- def release_execution_device(self, device: torch.device) -> None:
- """Mark this execution device as unused."""
- current_thread = threading.current_thread().ident
- for x in self._execution_devices:
- if x.thread_id == current_thread and x.device == device:
- x.refcount -= 1
- if x.refcount == 0:
- x.thread_id = 0
- self._free_execution_device.release()
+ # yield outside the device lock so other threads can reserve or release devices while this one runs
+ try:
+ yield device
+ finally:
+ with self._device_lock:
+ self._execution_devices[device] = 0
+ self._free_execution_device.release()
+ torch.cuda.empty_cache()
@property
def max_cache_size(self) -> float:
@@ -203,7 +212,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""
- with self._lock:
+ with self._ram_lock:
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
return
@@ -228,7 +237,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
This may raise an IndexError if the model is not in the cache.
"""
- with self._lock:
+ with self._ram_lock:
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
if self.stats:
@@ -395,7 +404,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
raise torch.cuda.OutOfMemoryError
@staticmethod
- def _get_execution_devices(devices: Optional[Set[torch.device]] = None) -> Set[GPUDeviceStatus]:
+ def _get_execution_devices(devices: Optional[Set[torch.device]] = None) -> Set[torch.device]:
if not devices:
default_device = choose_torch_device()
if default_device != torch.device("cuda"):
@@ -403,7 +412,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
else:
# we get here if the default device is cuda, and return each of the cuda devices.
devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
- return {GPUDeviceStatus(device=x) for x in devices}
+ return devices
@staticmethod
def _device_name(device: torch.device) -> str:
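
One behavioural detail of the implementation above worth noting: `threading.BoundedSemaphore.acquire()` signals a timeout by returning `False` rather than raising, so if the optional `timeout` argument of `reserve_execution_device()` is ever exercised, the return value needs to be checked to get timeout-error semantics. A plain-stdlib sketch, not InvokeAI code:

    import threading

    free_devices = threading.BoundedSemaphore(2)

    def acquire_or_raise(timeout=None):
        # acquire() blocks forever when timeout is None; with a timeout it returns a bool
        if not free_devices.acquire(timeout=timeout):
            raise TimeoutError(f"No execution device became free within {timeout}s")

    acquire_or_raise(timeout=5)  # succeeds immediately while slots remain
    free_devices.release()
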
diff --git a/invokeai/backend/model_manager/load/model_cache/model_locker.py b/invokeai/backend/model_manager/load/model_cache/model_locker.py
index 0ea87fbe06..fd2465b517 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_locker.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_locker.py
@@ -56,8 +56,8 @@ class ModelLocker(ModelLockerBase):
self._cache_entry.lock()
try:
- # We wait for a gpu to be free - may raise a TimeoutError
- self._execution_device = self._cache.acquire_execution_device(MAX_GPU_WAIT)
+ # We wait for a gpu to be free - may raise a ValueError
+ self._execution_device = self._cache.get_execution_device()
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._execution_device}")
model_in_gpu = copy.deepcopy(self._cache_entry.model)
if hasattr(model_in_gpu, "to"):
@@ -77,14 +77,5 @@ class ModelLocker(ModelLockerBase):
"""Call upon exit from context."""
if not hasattr(self.model, "to"):
return
-
self._cache_entry.unlock()
- if self._execution_device:
- self._cache.release_execution_device(self._execution_device)
-
- try:
- torch.cuda.empty_cache()
- torch.mps.empty_cache()
- except Exception:
- pass
self._cache.print_cuda_stats()
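
The net effect of this hunk is a shift in responsibility: the locker no longer releases devices or empties the CUDA cache on exit; it only looks up the device its thread reserved earlier, while `reserve_execution_device()` handles release and `torch.cuda.empty_cache()` when the session ends. A stripped-down sketch of that lookup, using stand-in names rather than the real classes:

    import threading

    _reservations: dict[int, str] = {}  # thread id -> device, filled by the reservation context

    def get_execution_device() -> str:
        tid = threading.get_ident()
        if tid not in _reservations:
            raise ValueError(f"No GPU has been reserved for use by thread {tid}")
        return _reservations[tid]

    _reservations[threading.get_ident()] = "cuda:0"  # done once per session by the processor
    print(get_execution_device())                    # done by the locker when moving a model to VRAM
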
diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py
index 0be53c842a..e0fd0b1c9e 100644
--- a/invokeai/backend/util/devices.py
+++ b/invokeai/backend/util/devices.py
@@ -15,7 +15,15 @@ MPS_DEVICE = torch.device("mps")
def choose_torch_device() -> torch.device:
- """Convenience routine for guessing which GPU device to run model on"""
+ """Convenience routine for guessing which GPU device to run model on."""
+ # """Temporarily modified to use the model manager's get_execution_device()"""
+ # try:
+ # from invokeai.app.api.dependencies import ApiDependencies
+ # model_manager = ApiDependencies.invoker.services.model_manager
+ # device = model_manager.load.ram_cache.acquire_execution_device()
+ # print(f'DEBUG choose_torch_device returning {device}')
+ # return device
+ # except Exception:
config = get_config()
if config.device == "auto":
if torch.cuda.is_available():