implement session-level reservation of gpus

Lincoln Stein 2024-04-01 16:01:43 -04:00
parent eca29c41d0
commit 3d69372785
6 changed files with 132 additions and 114 deletions


@@ -100,7 +100,8 @@ class InvokeAIAppConfig(BaseSettings):
ram: Maximum memory amount used by memory model cache for rapid switching (GB).
convert_cache: Maximum size of on-disk converted models cache (GB).
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda:0`, `cuda:1`, `cuda:2`, `cuda:3`, `cuda:4`, `cuda:5`, `mps`
devices: List of execution devices; will override default device selected.
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
@@ -108,6 +109,7 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: Maximum number of items in the session queue.
max_threads: Maximum number of session queue execution threads.
allow_nodes: List of nodes to allow. Omit to allow all.
deny_nodes: List of nodes to deny. Omit to deny none.
node_cache_size: How many cached nodes to keep in memory.

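The new `devices` and `max_threads` settings above pair naturally: with one queue worker thread per listed GPU, several sessions can render in parallel, each on its own reserved device. A minimal sketch of such a configuration, assuming `devices` accepts plain device strings and that `InvokeAIAppConfig` is importable from `invokeai.app.services.config` (both assumptions, not confirmed by this diff):

    from invokeai.app.services.config import InvokeAIAppConfig  # import path assumed

    # Hypothetical multi-GPU setup: two rendering devices, two queue worker threads,
    # so two sessions can execute concurrently, one per reserved GPU.
    config = InvokeAIAppConfig(
        devices=["cuda:0", "cuda:1"],  # assumption: list of device strings
        max_threads=2,                 # one session-processor thread per GPU
    )
    print(config.devices, config.max_threads)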

@@ -181,47 +181,51 @@ class DefaultSessionProcessor(SessionProcessorBase):
if profiler is not None:
profiler.start(profile_id=session.session_id)
# Prepare invocations and take the first
with self._process_lock:
invocation = session.session.next()
# reserve a GPU for this session - may block
with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device() as gpu:
print(f"DEBUG: session {session.item_id} has reserved gpu {gpu}")
# Loop over invocations until the session is complete or canceled
while invocation is not None:
if self._stop_event.is_set():
break
self._resume_event.wait()
# Prepare invocations and take the first
with self._process_lock:
invocation = session.session.next()
self._process_next_invocation(session, invocation, stats_service)
# Loop over invocations until the session is complete or canceled
while invocation is not None:
if self._stop_event.is_set():
break
self._resume_event.wait()
# The session is complete if all invocations are complete or there was an error
if session.session.is_complete():
# Send complete event
self._invoker.services.events.emit_graph_execution_complete(
queue_batch_id=session.batch_id,
queue_item_id=session.item_id,
queue_id=session.queue_id,
graph_execution_state_id=session.session.id,
)
# Log stats
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
stats_service.log_stats(session.session.id)
stats_service.reset_stats()
self._process_next_invocation(session, invocation, stats_service)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
stats_service.dump_stats(
graph_execution_state_id=session.session.id, output_path=stats_path
# The session is complete if all invocations are complete or there was an error
if session.session.is_complete():
# Send complete event
self._invoker.services.events.emit_graph_execution_complete(
queue_batch_id=session.batch_id,
queue_item_id=session.item_id,
queue_id=session.queue_id,
graph_execution_state_id=session.session.id,
)
self._queue_items.remove(session.item_id)
invocation = None
else:
# Prepare the next invocation
with self._process_lock:
invocation = session.session.next()
# Log stats
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
stats_service.log_stats(session.session.id)
stats_service.reset_stats()
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
stats_service.dump_stats(
graph_execution_state_id=session.session.id, output_path=stats_path
)
self._queue_items.remove(session.item_id)
invocation = None
else:
# Prepare the next invocation
with self._process_lock:
invocation = session.session.next()
except Exception:
# Non-fatal error in processor

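The processor hunk above moves GPU selection from per-model-load to per-session: the entire invocation loop now runs inside `reserve_execution_device()`, so one session holds one GPU from its first invocation to its last. A minimal sketch of that flow, assuming a `cache` implementing `ModelCacheBase` and a `session` with the `next()` / `is_complete()` interface used above (the `run_invocation` helper is purely illustrative):

    def process_session(cache, session):
        # Reserve a GPU for the whole session; blocks if every GPU is busy
        # and releases the device (with a CUDA cache flush) when the block exits.
        with cache.reserve_execution_device() as gpu:
            invocation = session.next()
            while invocation is not None:
                # Anything called from here on this thread can recover the same
                # device via cache.get_execution_device().
                run_invocation(invocation, device=gpu)  # illustrative helper
                invocation = None if session.is_complete() else session.next()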

@@ -8,9 +8,10 @@ model will be cleared and (re)loaded from disk when next needed.
"""
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field
from logging import Logger
from typing import Dict, Generic, Optional, Set, TypeVar
from typing import Dict, Generator, Generic, Optional, Set, TypeVar
import torch
@@ -93,20 +94,23 @@ class ModelCacheBase(ABC, Generic[T]):
"""Return the set of available execution devices."""
pass
@contextmanager
@abstractmethod
def acquire_execution_device(self, timeout: int = 0) -> torch.device:
"""
Pick the next available execution device.
If all devices are currently engaged (locked), then
block until timeout seconds have passed and raise a
TimeoutError if no devices are available.
"""
def reserve_execution_device(self, timeout: int = 0) -> Generator[torch.device, None, None]:
"""Reserve an execution device (GPU) under the current thread id."""
pass
@abstractmethod
def release_execution_device(self, device: torch.device) -> None:
"""Release a previously-acquired execution device."""
def get_execution_device(self) -> torch.device:
"""
Return an execution device that has been reserved for the current thread.
Note that reservations are done using the current thread's TID.
It would be better to do this using the session ID, but that involves
too many detailed changes to model manager calls.
May raise a ValueError if no GPU has been reserved.
"""
pass
@property

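Taken together, the two abstract methods define a per-thread contract: `get_execution_device()` only succeeds while the same thread is inside a `reserve_execution_device()` block. A small sketch of that contract, where `cache` stands for any concrete `ModelCacheBase` implementation:

    def show_contract(cache):
        # Outside a reservation, the current thread owns no GPU.
        try:
            cache.get_execution_device()
        except ValueError:
            print("no GPU reserved for this thread yet")
        # Inside the block, code anywhere down the call stack sees the same device.
        with cache.reserve_execution_device() as gpu:
            assert cache.get_execution_device() == gpu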

@@ -21,13 +21,12 @@ context. Use like this:
import gc
import sys
import threading
from contextlib import suppress
from contextlib import contextmanager, suppress
from logging import Logger
from threading import BoundedSemaphore, Lock
from typing import Dict, List, Optional, Set
from threading import BoundedSemaphore
from typing import Dict, Generator, List, Optional, Set
import torch
from pydantic import BaseModel
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
@@ -51,26 +50,6 @@ GIG = 1073741824
MB = 2**20
# GPU device can only be used by one thread at a time.
# The refcount indicates the number of models stored
# in it.
class GPUDeviceStatus(BaseModel):
"""Track of which threads are using the GPU(s) on this system."""
device: torch.device
thread_id: int = 0
refcount: int = 0
class Config:
"""Configure the base model."""
arbitrary_types_allowed = True
def __hash__(self) -> int:
"""Allow to be added to a set."""
return hash(str(torch.device))
class ModelCache(ModelCacheBase[AnyModel]):
"""Implementation of ModelCacheBase."""
@@ -100,9 +79,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
"""
self._precision: torch.dtype = precision
self._max_cache_size: float = max_cache_size
self._execution_devices: Set[GPUDeviceStatus] = self._get_execution_devices(execution_devices)
self._storage_device: torch.device = storage_device
self._lock = threading.Lock()
self._ram_lock = threading.Lock()
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
self._log_memory_usage = log_memory_usage
self._stats: Optional[CacheStats] = None
@@ -110,11 +88,15 @@ class ModelCache(ModelCacheBase[AnyModel]):
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
self._cache_stack: List[str] = []
self._lock = Lock()
# device to thread id
self._device_lock = threading.Lock()
self._execution_devices: Dict[torch.device, int] = {
x: 0 for x in execution_devices or self._get_execution_devices()
}
self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
self.logger.info(
f"Using rendering device(s): {', '.join(sorted([str(x.device) for x in self._execution_devices]))}"
f"Using rendering device(s): {', '.join(sorted([str(x) for x in self._execution_devices.keys()]))}"
)
@property
@@ -130,34 +112,61 @@ class ModelCache(ModelCacheBase[AnyModel]):
@property
def execution_devices(self) -> Set[torch.device]:
"""Return the set of available execution devices."""
return {x.device for x in self._execution_devices}
devices = self._execution_devices.keys()
return set(devices)
def acquire_execution_device(self, timeout: int = 0) -> torch.device:
"""Acquire and return an execution device (e.g. "cuda" for VRAM)."""
def get_execution_device(self) -> torch.device:
"""
Return an execution device that has been reserved for the current thread.
Note that reservations are done using the current thread's TID.
It would be better to do this using the session ID, but that involves
too many detailed changes to model manager calls.
May raise a ValueError if no GPU has been reserved.
"""
current_thread = threading.current_thread().ident
assert current_thread is not None
assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
if not assigned:
raise ValueError(f"No GPU has been reserved for use by thread {current_thread}")
return assigned[0]
# first try to assign a device that is already executing on this thread
if claimed_devices := [x for x in self._execution_devices if x.thread_id == current_thread]:
claimed_devices[0].refcount += 1
return claimed_devices[0].device
else:
# this thread is not currently using any gpu. Wait for a free one
@contextmanager
def reserve_execution_device(self, timeout: Optional[int] = None) -> Generator[torch.device, None, None]:
"""Reserve an execution device (e.g. GPU) for exclusive use by a generation thread.
Note that the reservation is done using the current thread's TID.
It would be better to do this using the session ID, but that involves
too many detailed changes to model manager calls.
"""
device = None
with self._device_lock:
current_thread = threading.current_thread().ident
assert current_thread is not None
# look for a device that has already been assigned to this thread
assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
if assigned:
device = assigned[0]
# no device already assigned. Get one.
if device is None:
self._free_execution_device.acquire(timeout=timeout)
unclaimed_devices = [x for x in self._execution_devices if x.refcount == 0]
unclaimed_devices[0].thread_id = current_thread
unclaimed_devices[0].refcount += 1
return unclaimed_devices[0].device
with self._device_lock:
free_device = [x for x, tid in self._execution_devices.items() if tid == 0]
print(f"DEBUG: execution devices = {self._execution_devices}")
self._execution_devices[free_device[0]] = current_thread
device = free_device[0]
def release_execution_device(self, device: torch.device) -> None:
"""Mark this execution device as unused."""
current_thread = threading.current_thread().ident
for x in self._execution_devices:
if x.thread_id == current_thread and x.device == device:
x.refcount -= 1
if x.refcount == 0:
x.thread_id = 0
self._free_execution_device.release()
# we are outside the lock region now
try:
yield device
finally:
with self._device_lock:
self._execution_devices[device] = 0
self._free_execution_device.release()
torch.cuda.empty_cache()
@property
def max_cache_size(self) -> float:
@@ -203,7 +212,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""
with self._lock:
with self._ram_lock:
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
return
@@ -228,7 +237,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
This may raise an IndexError if the model is not in the cache.
"""
with self._lock:
with self._ram_lock:
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
if self.stats:
@@ -395,7 +404,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
raise torch.cuda.OutOfMemoryError
@staticmethod
def _get_execution_devices(devices: Optional[Set[torch.device]] = None) -> Set[GPUDeviceStatus]:
def _get_execution_devices(devices: Optional[Set[torch.device]] = None) -> Set[torch.device]:
if not devices:
default_device = choose_torch_device()
if default_device != torch.device("cuda"):
@@ -403,7 +412,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
else:
# we get here if the default device is cuda, and return each of the cuda devices.
devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
return {GPUDeviceStatus(device=x) for x in devices}
return devices
@staticmethod
def _device_name(device: torch.device) -> str:

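The concrete implementation above replaces the old `GPUDeviceStatus` refcounting with a `BoundedSemaphore` sized to the number of GPUs plus a device-to-thread-id mapping guarded by a lock. A self-contained toy version of that mechanism (string device names instead of `torch.device`; every name here is illustrative):

    import threading
    from contextlib import contextmanager

    class ToyDevicePool:
        """Toy model of ModelCache's per-thread GPU reservation."""

        def __init__(self, devices):
            self._devices = {d: 0 for d in devices}      # device -> owning thread id (0 = free)
            self._lock = threading.Lock()
            self._free = threading.BoundedSemaphore(len(devices))

        @contextmanager
        def reserve(self, timeout=None):
            tid = threading.current_thread().ident
            with self._lock:
                mine = [d for d, owner in self._devices.items() if owner == tid]
            if mine:                                     # re-entering on the same thread
                yield mine[0]
                return
            if not self._free.acquire(timeout=timeout):  # block until a device frees up
                raise TimeoutError("no free device")
            with self._lock:
                device = next(d for d, owner in self._devices.items() if owner == 0)
                self._devices[device] = tid
            try:
                yield device
            finally:
                with self._lock:
                    self._devices[device] = 0
                self._free.release()

        def current(self):
            tid = threading.current_thread().ident
            with self._lock:
                mine = [d for d, owner in self._devices.items() if owner == tid]
            if not mine:
                raise ValueError(f"no device reserved for thread {tid}")
            return mine[0]

    def worker(pool, n):
        with pool.reserve() as device:
            print(f"worker {n} running on {device} == {pool.current()}")

    pool = ToyDevicePool(["cuda:0", "cuda:1"])
    threads = [threading.Thread(target=worker, args=(pool, i)) for i in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

Threads beyond the number of devices simply block in `reserve()` until another session finishes, which is the behaviour the session processor relies on.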

@@ -56,8 +56,8 @@ class ModelLocker(ModelLockerBase):
self._cache_entry.lock()
try:
# We wait for a gpu to be free - may raise a TimeoutError
self._execution_device = self._cache.acquire_execution_device(MAX_GPU_WAIT)
# We wait for a gpu to be free - may raise a ValueError
self._execution_device = self._cache.get_execution_device()
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._execution_device}")
model_in_gpu = copy.deepcopy(self._cache_entry.model)
if hasattr(model_in_gpu, "to"):
@@ -77,14 +77,5 @@ class ModelLocker(ModelLockerBase):
"""Call upon exit from context."""
if not hasattr(self.model, "to"):
return
self._cache_entry.unlock()
if self._execution_device:
self._cache.release_execution_device(self._execution_device)
try:
torch.cuda.empty_cache()
torch.mps.empty_cache()
except Exception:
pass
self._cache.print_cuda_stats()


@@ -15,7 +15,15 @@ MPS_DEVICE = torch.device("mps")
def choose_torch_device() -> torch.device:
"""Convenience routine for guessing which GPU device to run model on"""
"""Convenience routine for guessing which GPU device to run model on."""
# """Temporarily modified to use the model manager's get_execution_device()"""
# try:
# from invokeai.app.api.dependencies import ApiDependencies
# model_manager = ApiDependencies.invoker.services.model_manager
# device = model_manager.load.ram_cache.acquire_execution_device()
# print(f'DEBUG choose_torch_device returning {device}')
# return device
# except Exception:
config = get_config()
if config.device == "auto":
if torch.cuda.is_available():
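
The commented-out block above hints at the intended follow-up: callers would first ask the model cache for the GPU reserved to the current thread and only fall back to the config-driven guess. A hedged sketch of that fallback as it might look in this module after the rename to `get_execution_device()` (the helper name is illustrative; it assumes the module-level `torch` import and `choose_torch_device` already present here):

    def reserved_or_default_device() -> torch.device:
        """Illustrative helper: prefer the GPU reserved for this thread, else fall back."""
        try:
            # Accessor path taken from the commented-out block above.
            from invokeai.app.api.dependencies import ApiDependencies
            cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
            return cache.get_execution_device()
        except Exception:
            # No running invoker, or no reservation for this thread: use the config heuristic.
            return choose_torch_device()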