add support for multi-gpu rendering

This commit adds speculative support for parallel rendering across
multiple GPUs. Parallelism operates at the level of a session: each
session is given access to a different GPU, and when all GPUs are busy,
execution of the session blocks until a GPU becomes available.

The code is currently untested and is being posted for comment.
Author:  Lincoln Stein
Date:    2024-02-19 15:21:55 -05:00
Parent:  b06d63fb34
Commit:  b85f2bc87d

10 changed files with 96 additions and 16 deletions
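To make the session-level arbitration concrete, below is a minimal sketch of the acquire/use/release pattern that the diffs add to ModelCache and ModelLocker. The `cache` argument and `render_session` helper are hypothetical stand-ins for illustration; only `acquire_execution_device`, `release_execution_device`, and the 600-second wait come from this commit.

# Hypothetical illustration only; `cache` stands in for a ModelCache whose
# execution_devices set contains one torch.device per GPU.
MAX_GPU_WAIT = 600  # seconds; mirrors the constant added in the locker diff below

def render_session(cache, do_work) -> None:
    # Block until a GPU is free; per the new docstrings, this may raise
    # TimeoutError after MAX_GPU_WAIT seconds.
    device = cache.acquire_execution_device(MAX_GPU_WAIT)
    try:
        do_work(device)  # the session runs entirely on the GPU it was handed
    finally:
        cache.release_execution_device(device)  # let a blocked session proceed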


@@ -24,8 +24,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
     __threadLimit: BoundedSemaphore

     def start(self, invoker: Invoker) -> None:
-        # if we do want multithreading at some point, we could make this configurable
-        self.__threadLimit = BoundedSemaphore(1)
+        # LS - this will probably break
+        # but the idea is to enable multithreading up to the number of available
+        # GPUs. Nodes will block on model loading if no GPU is free.
+        self.__threadLimit = BoundedSemaphore(invoker.services.model_manager.gpu_count)
         self.__invoker = invoker
         self.__stop_event = Event()
         self.__invoker_thread = Thread(


@@ -38,3 +38,8 @@ class ModelLoadServiceBase(ABC):
     @abstractmethod
     def convert_cache(self) -> ModelConvertCacheBase:
         """Return the checkpoint convert cache used by this loader."""
+
+    @property
+    @abstractmethod
+    def gpu_count(self) -> int:
+        """Return the number of GPUs we are configured to use."""


@@ -40,6 +40,7 @@ class ModelLoadService(ModelLoadServiceBase):
         self._registry = registry

     def start(self, invoker: Invoker) -> None:
+        """Start the service."""
         self._invoker = invoker

     @property
@@ -47,6 +48,11 @@ class ModelLoadService(ModelLoadServiceBase):
         """Return the RAM cache used by this loader."""
         return self._ram_cache

+    @property
+    def gpu_count(self) -> int:
+        """Return the number of GPUs available for our uses."""
+        return len(self._ram_cache.execution_devices)
+
     @property
     def convert_cache(self) -> ModelConvertCacheBase:
         """Return the checkpoint convert cache used by this loader."""


@@ -98,3 +98,8 @@ class ModelManagerServiceBase(ABC):
         context_data: Optional[InvocationContextData] = None,
     ) -> LoadedModel:
         pass
+
+    @property
+    @abstractmethod
+    def gpu_count(self) -> int:
+        """Return the number of GPUs we are configured to use."""


@@ -112,6 +112,11 @@ class ModelManagerService(ModelManagerServiceBase):
         else:
             return self.load.load_model(configs[0], submodel, context_data)

+    @property
+    def gpu_count(self) -> int:
+        """Return the number of GPUs we are using."""
+        return self.load.gpu_count
+
     @classmethod
     def build_model_manager(
         cls,


@@ -10,7 +10,7 @@ model will be cleared and (re)loaded from disk when next needed.
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from logging import Logger
-from typing import Dict, Generic, Optional, TypeVar
+from typing import Dict, Generic, Optional, Set, TypeVar

 import torch
@@ -89,8 +89,24 @@ class ModelCacheBase(ABC, Generic[T]):
     @property
     @abstractmethod
-    def execution_device(self) -> torch.device:
-        """Return the exection device (e.g. "cuda" for VRAM)."""
+    def execution_devices(self) -> Set[torch.device]:
+        """Return the set of available execution devices."""
+        pass
+
+    @abstractmethod
+    def acquire_execution_device(self, timeout: int = 0) -> torch.device:
+        """
+        Pick the next available execution device.
+
+        If all devices are currently engaged (locked), then
+        block until timeout seconds have passed and raise a
+        TimeoutError if no devices are available.
+        """
+        pass
+
+    @abstractmethod
+    def release_execution_device(self, device: torch.device) -> None:
+        """Release a previously-acquired execution device."""
         pass

     @property


@@ -25,7 +25,8 @@
 import sys
 import time
 from contextlib import suppress
 from logging import Logger
-from typing import Dict, List, Optional
+from threading import BoundedSemaphore, Lock
+from typing import Dict, List, Optional, Set

 import torch
@@ -61,8 +62,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self,
         max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
         max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
-        execution_device: torch.device = torch.device("cuda"),
         storage_device: torch.device = torch.device("cpu"),
+        execution_devices: Optional[Set[torch.device]] = None,
         precision: torch.dtype = torch.float16,
         sequential_offload: bool = False,
         lazy_offloading: bool = True,
@@ -74,7 +75,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         Initialize the model RAM cache.

         :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
-        :param execution_device: Torch device to load active model into [torch.device('cuda')]
+        :param execution_devices: Set of torch device to load active model into [calculated]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
         :param precision: Precision for loaded models [torch.float16]
         :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
@@ -89,7 +90,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._precision: torch.dtype = precision
         self._max_cache_size: float = max_cache_size
         self._max_vram_cache_size: float = max_vram_cache_size
-        self._execution_device: torch.device = execution_device
+        self._execution_devices: Set[torch.device] = execution_devices or self._get_execution_devices()
         self._storage_device: torch.device = storage_device
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
         self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG
@@ -99,6 +100,10 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
         self._cache_stack: List[str] = []

+        self._lock = Lock()
+        self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
+        self._busy_execution_devices: Set[torch.device] = set()
+
     @property
     def logger(self) -> Logger:
         """Return the logger used by the cache."""
@@ -115,9 +120,24 @@ class ModelCache(ModelCacheBase[AnyModel]):
         return self._storage_device

     @property
-    def execution_device(self) -> torch.device:
-        """Return the exection device (e.g. "cuda" for VRAM)."""
-        return self._execution_device
+    def execution_devices(self) -> Set[torch.device]:
+        """Return the set of available execution devices."""
+        return self._execution_devices
+
+    def acquire_execution_device(self, timeout: int = 0) -> torch.device:
+        """Acquire and return an execution device (e.g. "cuda" for VRAM)."""
+        with self._lock:
+            self._free_execution_device.acquire(timeout=timeout)
+            free_devices = self.execution_devices - self._busy_execution_devices
+            chosen_device = list(free_devices)[0]
+            self._busy_execution_devices.add(chosen_device)
+            return chosen_device
+
+    def release_execution_device(self, device: torch.device) -> None:
+        """Mark this execution device as unused."""
+        with self._lock:
+            self._free_execution_device.release()
+            self._busy_execution_devices.remove(device)

     @property
     def max_cache_size(self) -> float:
@@ -405,3 +425,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
             mps.empty_cache()

         self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
+
+    @staticmethod
+    def _get_execution_devices() -> Set[torch.device]:
+        default_device = choose_torch_device()
+        if default_device != torch.device("cuda"):
+            return {default_device}
+
+        # we get here if the default device is cuda, and return each of the
+        # cuda devices.
+        return {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
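As a usage note for the constructor change above: when `execution_devices` is omitted, `_get_execution_devices()` falls back to every visible CUDA device (or the single default device on non-CUDA systems). Below is a minimal sketch of passing an explicit device set instead; the sizes and device names are illustrative, not values from this commit, and the import of ModelCache is omitted since file paths are not shown in this view.

import torch

# Illustrative only: pin the cache to two specific GPUs instead of relying
# on the automatic enumeration in _get_execution_devices().
cache = ModelCache(
    max_cache_size=6.0,  # GB of RAM for inactive models (example value)
    storage_device=torch.device("cpu"),
    execution_devices={torch.device("cuda:0"), torch.device("cuda:1")},
)
assert len(cache.execution_devices) == 2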


@@ -2,10 +2,16 @@
 Base class and implementation of a class that moves models in and out of VRAM.
 """

+from typing import Optional
+
+import torch
+
 from invokeai.backend.model_manager import AnyModel

 from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase

+MAX_GPU_WAIT = 600  # wait up to 10 minutes for a GPU to become free
+

 class ModelLocker(ModelLockerBase):
     """Internal class that mediates movement in and out of GPU."""
@@ -19,6 +25,7 @@ class ModelLocker(ModelLockerBase):
         """
         self._cache = cache
         self._cache_entry = cache_entry
+        self._execution_device: Optional[torch.device] = None

     @property
     def model(self) -> AnyModel:
@@ -37,10 +44,12 @@ class ModelLocker(ModelLockerBase):
             if self._cache.lazy_offloading:
                 self._cache.offload_unlocked_models(self._cache_entry.size)

-            self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
+            # We wait for a gpu to be free - may raise a TimeoutError
+            self._execution_device = self._cache.acquire_execution_device(MAX_GPU_WAIT)
+            self._cache.move_model_to_device(self._cache_entry, self._execution_device)
             self._cache_entry.loaded = True

-            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
+            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._execution_device}")
             self._cache.print_cuda_stats()

         except Exception:
@@ -54,6 +63,8 @@ class ModelLocker(ModelLockerBase):
             return

         self._cache_entry.unlock()
+        if self._execution_device:
+            self._cache.release_execution_device(self._execution_device)
         if not self._cache.lazy_offloading:
             self._cache.offload_unlocked_models(self._cache_entry.size)
             self._cache.print_cuda_stats()


@@ -65,7 +65,7 @@ def mock_services() -> InvocationServices:
         images=None,  # type: ignore
         invocation_cache=MemoryInvocationCache(max_cache_size=0),
         logger=logging,  # type: ignore
-        model_manager=Mock(),  # type: ignore
+        model_manager=Mock(gpu_count=1),  # type: ignore
         download_queue=None,  # type: ignore
         names=None,  # type: ignore
         performance_statistics=InvocationStatsService(),
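A brief note on the test change above: keyword arguments passed to `Mock()` become plain attributes, so the processor's semaphore sizing sees a real integer rather than a child mock. A small sketch of that behavior (the `model_manager` name is just a stand-in for the fixture value):

from threading import BoundedSemaphore
from unittest.mock import Mock

model_manager = Mock(gpu_count=1)  # gpu_count is now the int 1, not a Mock
assert model_manager.gpu_count == 1
thread_limit = BoundedSemaphore(model_manager.gpu_count)  # what start() does with it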


@@ -112,7 +112,7 @@ def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadata
 @pytest.fixture
-def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
+def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase:
     ram_cache = ModelCache(
         logger=InvokeAILogger.get_logger(),
         max_cache_size=mm2_app_config.ram_cache_size,