Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
implement session-level reservation of gpus
This commit is contained in: parent eca29c41d0, commit 3d69372785
@@ -100,7 +100,8 @@ class InvokeAIAppConfig(BaseSettings):
         ram: Maximum memory amount used by memory model cache for rapid switching (GB).
         convert_cache: Maximum size of on-disk converted models cache (GB).
         log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
-        device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
+        device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda:0`, `cuda:1`, `cuda:2`, `cuda:3`, `cuda:4`, `cuda:5`, `mps`
+        devices: List of execution devices; will override default device selected.
         precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
         sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
         attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
@@ -108,6 +109,7 @@ class InvokeAIAppConfig(BaseSettings):
         force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
         pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
         max_queue_size: Maximum number of items in the session queue.
+        max_threads: Maximum number of session queue execution threads.
         allow_nodes: List of nodes to allow. Omit to allow all.
         deny_nodes: List of nodes to deny. Omit to deny none.
         node_cache_size: How many cached nodes to keep in memory.
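For orientation, here is a hypothetical way the new settings could be exercised from Python. The field names come from the docstring above; the accepted value types for `devices` and the intent of `max_threads` are assumptions, not taken from the commit.

    # Hypothetical sketch (not from the commit): populating the new settings in code.
    from invokeai.app.services.config import InvokeAIAppConfig  # import path assumed

    config = InvokeAIAppConfig(
        device="auto",                 # keep automatic per-platform selection
        devices=["cuda:0", "cuda:1"],  # or pin an explicit pool of execution devices (type assumed)
        max_threads=2,                 # assumed intent: one queue thread per reserved GPU
    )
    print(config.device, config.devices)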
@@ -181,47 +181,51 @@ class DefaultSessionProcessor(SessionProcessorBase):
             if profiler is not None:
                 profiler.start(profile_id=session.session_id)

-            # Prepare invocations and take the first
-            with self._process_lock:
-                invocation = session.session.next()
+            # reserve a GPU for this session - may block
+            with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device() as gpu:
+                print(f"DEBUG: session {session.item_id} has reserved gpu {gpu}")

-            # Loop over invocations until the session is complete or canceled
-            while invocation is not None:
-                if self._stop_event.is_set():
-                    break
-                self._resume_event.wait()
+                # Prepare invocations and take the first
+                with self._process_lock:
+                    invocation = session.session.next()

-                self._process_next_invocation(session, invocation, stats_service)
+                # Loop over invocations until the session is complete or canceled
+                while invocation is not None:
+                    if self._stop_event.is_set():
+                        break
+                    self._resume_event.wait()

-                # The session is complete if all invocations are complete or there was an error
-                if session.session.is_complete():
-                    # Send complete event
-                    self._invoker.services.events.emit_graph_execution_complete(
-                        queue_batch_id=session.batch_id,
-                        queue_item_id=session.item_id,
-                        queue_id=session.queue_id,
-                        graph_execution_state_id=session.session.id,
-                    )
-                    # Log stats
-                    # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
-                    # we don't care about that - suppress the error.
-                    with suppress(GESStatsNotFoundError):
-                        stats_service.log_stats(session.session.id)
-                        stats_service.reset_stats()
+                    self._process_next_invocation(session, invocation, stats_service)

-                    # If we are profiling, stop the profiler and dump the profile & stats
-                    if self._profiler:
-                        profile_path = self._profiler.stop()
-                        stats_path = profile_path.with_suffix(".json")
-                        stats_service.dump_stats(
-                            graph_execution_state_id=session.session.id, output_path=stats_path
+                    # The session is complete if all invocations are complete or there was an error
+                    if session.session.is_complete():
+                        # Send complete event
+                        self._invoker.services.events.emit_graph_execution_complete(
+                            queue_batch_id=session.batch_id,
+                            queue_item_id=session.item_id,
+                            queue_id=session.queue_id,
+                            graph_execution_state_id=session.session.id,
+                        )
-                    self._queue_items.remove(session.item_id)
-                    invocation = None
-                else:
-                    # Prepare the next invocation
-                    with self._process_lock:
-                        invocation = session.session.next()
+                        # Log stats
+                        # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
+                        # we don't care about that - suppress the error.
+                        with suppress(GESStatsNotFoundError):
+                            stats_service.log_stats(session.session.id)
+                            stats_service.reset_stats()

+                        # If we are profiling, stop the profiler and dump the profile & stats
+                        if self._profiler:
+                            profile_path = self._profiler.stop()
+                            stats_path = profile_path.with_suffix(".json")
+                            stats_service.dump_stats(
+                                graph_execution_state_id=session.session.id, output_path=stats_path
+                            )
+                        self._queue_items.remove(session.item_id)
+                        invocation = None
+                    else:
+                        # Prepare the next invocation
+                        with self._process_lock:
+                            invocation = session.session.next()

         except Exception:
             # Non-fatal error in processor
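The processor change above wraps the whole per-session invocation loop in a single GPU reservation. A minimal, self-contained sketch of that pattern follows; the toy `reserve_execution_device()` below stands in for the model manager's RAM-cache method and uses plain strings as devices, so none of the names beyond the overall shape come from the commit.

    import threading
    import time
    from contextlib import contextmanager

    # Toy device pool: 2 "GPUs" shared by 4 sessions.
    _devices = {"cuda:0": 0, "cuda:1": 0}          # device -> owning thread id (0 = free)
    _device_lock = threading.Lock()
    _free_device = threading.BoundedSemaphore(len(_devices))

    @contextmanager
    def reserve_execution_device():
        """Block until a device is free, hand it to the caller, release it on exit."""
        _free_device.acquire()
        with _device_lock:
            device = next(d for d, tid in _devices.items() if tid == 0)
            _devices[device] = threading.get_ident()
        try:
            yield device
        finally:
            with _device_lock:
                _devices[device] = 0
            _free_device.release()

    def run_session(session_id: int) -> None:
        with reserve_execution_device() as gpu:    # may block until a GPU frees up
            print(f"session {session_id} running on {gpu}")
            time.sleep(0.1)                        # pretend to run the invocation loop

    threads = [threading.Thread(target=run_session, args=(i,)) for i in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

With two devices in the pool and four worker threads, at most two sessions run at once; the others block inside the context manager until a device is handed back.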
@@ -8,9 +8,10 @@ model will be cleared and (re)loaded from disk when next needed.
 """

 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from dataclasses import dataclass, field
 from logging import Logger
-from typing import Dict, Generic, Optional, Set, TypeVar
+from typing import Dict, Generator, Generic, Optional, Set, TypeVar

 import torch

@@ -93,20 +94,23 @@ class ModelCacheBase(ABC, Generic[T]):
         """Return the set of available execution devices."""
         pass

+    @contextmanager
     @abstractmethod
-    def acquire_execution_device(self, timeout: int = 0) -> torch.device:
-        """
-        Pick the next available execution device.
-
-        If all devices are currently engaged (locked), then
-        block until timeout seconds have passed and raise a
-        TimeoutError if no devices are available.
-        """
+    def reserve_execution_device(self, timeout: int = 0) -> Generator[torch.device, None, None]:
+        """Reserve an execution device (GPU) under the current thread id."""
         pass

     @abstractmethod
-    def release_execution_device(self, device: torch.device) -> None:
-        """Release a previously-acquired execution device."""
+    def get_execution_device(self) -> torch.device:
+        """
+        Return an execution device that has been reserved for current thread.
+
+        Note that reservations are done using the current thread's TID.
+        It would be better to do this using the session ID, but that involves
+        too many detailed changes to model manager calls.
+
+        May generate a ValueError if no GPU has been reserved.
+        """
         pass

     @property
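Of the two abstract methods above, `reserve_execution_device` blocks until a device is free, while `get_execution_device` is a plain lookup keyed on the calling thread's id that raises `ValueError` when nothing has been reserved. A small standalone sketch of that lookup contract (the dict and function here are illustrative, not the class's real attributes):

    import threading

    # device -> reserving thread id (0 means unreserved); illustrative only
    _execution_devices = {"cuda:0": 0, "cuda:1": 0}

    def get_execution_device() -> str:
        """Return the device reserved by the calling thread, or raise ValueError."""
        current_thread = threading.get_ident()
        assigned = [d for d, tid in _execution_devices.items() if tid == current_thread]
        if not assigned:
            raise ValueError(f"No GPU has been reserved for the use of thread {current_thread}")
        return assigned[0]

    # Without a prior reservation this raises ValueError:
    try:
        get_execution_device()
    except ValueError as err:
        print(err)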
@@ -21,13 +21,12 @@ context. Use like this:
 import gc
 import sys
 import threading
-from contextlib import suppress
+from contextlib import contextmanager, suppress
 from logging import Logger
-from threading import BoundedSemaphore, Lock
-from typing import Dict, List, Optional, Set
+from threading import BoundedSemaphore
+from typing import Dict, Generator, List, Optional, Set

 import torch
-from pydantic import BaseModel

 from invokeai.backend.model_manager import AnyModel, SubModelType
 from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
@@ -51,26 +50,6 @@ GIG = 1073741824
 MB = 2**20


-# GPU device can only be used by one thread at a time.
-# The refcount indicates the number of models stored
-# in it.
-class GPUDeviceStatus(BaseModel):
-    """Track of which threads are using the GPU(s) on this system."""
-
-    device: torch.device
-    thread_id: int = 0
-    refcount: int = 0
-
-    class Config:
-        """Configure the base model."""
-
-        arbitrary_types_allowed = True
-
-    def __hash__(self) -> int:
-        """Allow to be added to a set."""
-        return hash(str(torch.device))
-
-
 class ModelCache(ModelCacheBase[AnyModel]):
     """Implementation of ModelCacheBase."""

@@ -100,9 +79,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
         """
         self._precision: torch.dtype = precision
         self._max_cache_size: float = max_cache_size
-        self._execution_devices: Set[GPUDeviceStatus] = self._get_execution_devices(execution_devices)
         self._storage_device: torch.device = storage_device
-        self._lock = threading.Lock()
+        self._ram_lock = threading.Lock()
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
         self._log_memory_usage = log_memory_usage
         self._stats: Optional[CacheStats] = None
@@ -110,11 +88,15 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
         self._cache_stack: List[str] = []

-        self._lock = Lock()
+        # device to thread id
+        self._device_lock = threading.Lock()
+        self._execution_devices: Dict[torch.device, int] = {
+            x: 0 for x in execution_devices or self._get_execution_devices()
+        }
+        self._free_execution_device = BoundedSemaphore(len(self._execution_devices))

         self.logger.info(
-            f"Using rendering device(s): {', '.join(sorted([str(x.device) for x in self._execution_devices]))}"
+            f"Using rendering device(s): {', '.join(sorted([str(x) for x in self._execution_devices.keys()]))}"
         )

     @property
@@ -130,34 +112,61 @@ class ModelCache(ModelCacheBase[AnyModel]):
     @property
     def execution_devices(self) -> Set[torch.device]:
         """Return the set of available execution devices."""
-        return {x.device for x in self._execution_devices}
+        devices = self._execution_devices.keys()
+        return set(devices)

-    def acquire_execution_device(self, timeout: int = 0) -> torch.device:
-        """Acquire and return an execution device (e.g. "cuda" for VRAM)."""
+    def get_execution_device(self) -> torch.device:
+        """
+        Return an execution device that has been reserved for current thread.
+
+        Note that reservations are done using the current thread's TID.
+        It would be better to do this using the session ID, but that involves
+        too many detailed changes to model manager calls.
+
+        May generate a ValueError if no GPU has been reserved.
+        """
         current_thread = threading.current_thread().ident
         assert current_thread is not None
+        assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
+        if not assigned:
+            raise ValueError("No GPU has been reserved for the use of thread {current_thread}")
+        return assigned[0]

-        # first try to assign a device that is already executing on this thread
-        if claimed_devices := [x for x in self._execution_devices if x.thread_id == current_thread]:
-            claimed_devices[0].refcount += 1
-            return claimed_devices[0].device
-        else:
-            # this thread is not currently using any gpu. Wait for a free one
+    @contextmanager
+    def reserve_execution_device(self, timeout: Optional[int] = None) -> Generator[torch.device, None, None]:
+        """Reserve an execution device (e.g. GPU) for exclusive use by a generation thread.
+
+        Note that the reservation is done using the current thread's TID.
+        It would be better to do this using the session ID, but that involves
+        too many detailed changes to model manager calls.
+        """
+        device = None
+        with self._device_lock:
+            current_thread = threading.current_thread().ident
+            assert current_thread is not None
+
+            # look for a device that has already been assigned to this thread
+            assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
+            if assigned:
+                device = assigned[0]
+
+        # no device already assigned. Get one.
+        if device is None:
             self._free_execution_device.acquire(timeout=timeout)
-            unclaimed_devices = [x for x in self._execution_devices if x.refcount == 0]
-            unclaimed_devices[0].thread_id = current_thread
-            unclaimed_devices[0].refcount += 1
-            return unclaimed_devices[0].device
+            with self._device_lock:
+                free_device = [x for x, tid in self._execution_devices.items() if tid == 0]
+                print(f"DEBUG: execution devices = {self._execution_devices}")
+                self._execution_devices[free_device[0]] = current_thread
+                device = free_device[0]

-    def release_execution_device(self, device: torch.device) -> None:
-        """Mark this execution device as unused."""
-        current_thread = threading.current_thread().ident
-        for x in self._execution_devices:
-            if x.thread_id == current_thread and x.device == device:
-                x.refcount -= 1
-                if x.refcount == 0:
-                    x.thread_id = 0
-                    self._free_execution_device.release()
+        # we are outside the lock region now
+        try:
+            yield device
+        finally:
+            with self._device_lock:
+                self._execution_devices[device] = 0
+                self._free_execution_device.release()
+            torch.cuda.empty_cache()

     @property
     def max_cache_size(self) -> float:
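A note on the choice of `BoundedSemaphore` in the cache above: unlike a plain `Semaphore`, it refuses to be released more times than its initial count, so a double release of a device (for example, from an unbalanced cleanup path) surfaces immediately as a `ValueError` instead of silently growing the device pool. A quick standard-library illustration:

    from threading import BoundedSemaphore, Semaphore

    free_devices = BoundedSemaphore(2)   # two GPUs in the pool
    free_devices.acquire()
    free_devices.release()
    try:
        free_devices.release()           # one release too many
    except ValueError as err:
        print("caught:", err)            # BoundedSemaphore guards the pool size

    plain = Semaphore(2)
    plain.acquire()
    plain.release()
    plain.release()                      # silently bumps the counter to 3 - no error

Since the semaphore's initial value equals the size of the device pool, the bound doubles as a sanity check that a device is never returned twice.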
@@ -203,7 +212,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         submodel_type: Optional[SubModelType] = None,
     ) -> None:
         """Store model under key and optional submodel_type."""
-        with self._lock:
+        with self._ram_lock:
             key = self._make_cache_key(key, submodel_type)
             if key in self._cached_models:
                 return
@@ -228,7 +237,7 @@ class ModelCache(ModelCacheBase[AnyModel]):

         This may raise an IndexError if the model is not in the cache.
         """
-        with self._lock:
+        with self._ram_lock:
             key = self._make_cache_key(key, submodel_type)
             if key in self._cached_models:
                 if self.stats:
@@ -395,7 +404,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             raise torch.cuda.OutOfMemoryError

     @staticmethod
-    def _get_execution_devices(devices: Optional[Set[torch.device]] = None) -> Set[GPUDeviceStatus]:
+    def _get_execution_devices(devices: Optional[Set[torch.device]] = None) -> Set[torch.device]:
         if not devices:
             default_device = choose_torch_device()
             if default_device != torch.device("cuda"):
@@ -403,7 +412,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             else:
                 # we get here if the default device is cuda, and return each of the cuda devices.
                 devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
-        return {GPUDeviceStatus(device=x) for x in devices}
+        return devices

     @staticmethod
     def _device_name(device: torch.device) -> str:
@@ -56,8 +56,8 @@ class ModelLocker(ModelLockerBase):
         self._cache_entry.lock()

         try:
-            # We wait for a gpu to be free - may raise a TimeoutError
-            self._execution_device = self._cache.acquire_execution_device(MAX_GPU_WAIT)
+            # We wait for a gpu to be free - may raise a ValueError
+            self._execution_device = self._cache.get_execution_device()
             self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._execution_device}")
             model_in_gpu = copy.deepcopy(self._cache_entry.model)
             if hasattr(model_in_gpu, "to"):
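The lock path above copies the RAM-cached model and moves the copy onto the thread's reserved device, leaving the original in system RAM for other threads. A minimal sketch of that move pattern with a dummy module (assumes torch is installed; falls back to CPU when no GPU is present):

    import copy

    import torch

    cached_model = torch.nn.Linear(4, 4)          # stands in for a RAM-cached model
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    model_in_gpu = copy.deepcopy(cached_model)    # the RAM copy stays untouched
    if hasattr(model_in_gpu, "to"):
        model_in_gpu = model_in_gpu.to(device)    # only the copy moves to the reserved device

    print(next(cached_model.parameters()).device)   # still on CPU
    print(next(model_in_gpu.parameters()).device)   # on the selected device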
@@ -77,14 +77,5 @@ class ModelLocker(ModelLockerBase):
         """Call upon exit from context."""
         if not hasattr(self.model, "to"):
             return

         self._cache_entry.unlock()
-        if self._execution_device:
-            self._cache.release_execution_device(self._execution_device)
-
-        try:
-            torch.cuda.empty_cache()
-            torch.mps.empty_cache()
-        except Exception:
-            pass
-        self._cache.print_cuda_stats()
@@ -15,7 +15,15 @@ MPS_DEVICE = torch.device("mps")


 def choose_torch_device() -> torch.device:
-    """Convenience routine for guessing which GPU device to run model on"""
+    """Convenience routine for guessing which GPU device to run model on."""
+    # """Temporarily modified to use the model manager's get_execution_device()"""
+    # try:
+    #     from invokeai.app.api.dependencies import ApiDependencies
+    #     model_manager = ApiDependencies.invoker.services.model_manager
+    #     device = model_manager.load.ram_cache.acquire_execution_device()
+    #     print(f'DEBUG choose_torch_device returning {device}')
+    #     return device
+    # except Exception:
     config = get_config()
     if config.device == "auto":
         if torch.cuda.is_available():
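For reference, the `auto` branch that follows the commented-out block typically resolves to CUDA when available, then Apple MPS, then CPU. A standalone sketch of that fallback order (not the function's exact body):

    import torch

    def guess_torch_device() -> torch.device:
        """Pick CUDA if present, else Apple MPS, else CPU (illustrative fallback order)."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    print(guess_torch_device())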