add draft multi-gpu support

Lincoln Stein 2024-02-19 15:21:55 -05:00
parent 74a51571a0
commit 6b991a5269
9 changed files with 94 additions and 20 deletions

View File

@@ -38,3 +38,8 @@ class ModelLoadServiceBase(ABC):
     @abstractmethod
     def convert_cache(self) -> ModelConvertCacheBase:
         """Return the checkpoint convert cache used by this loader."""
+
+    @property
+    @abstractmethod
+    def gpu_count(self) -> int:
+        """Return the number of GPUs we are configured to use."""

View File

@@ -39,6 +39,7 @@ class ModelLoadService(ModelLoadServiceBase):
         self._registry = registry

     def start(self, invoker: Invoker) -> None:
+        """Start the service."""
         self._invoker = invoker

     @property
@@ -46,6 +47,11 @@ class ModelLoadService(ModelLoadServiceBase):
         """Return the RAM cache used by this loader."""
         return self._ram_cache

+    @property
+    def gpu_count(self) -> int:
+        """Return the number of GPUs available for our uses."""
+        return len(self._ram_cache.execution_devices)
+
     @property
     def convert_cache(self) -> ModelConvertCacheBase:
         """Return the checkpoint convert cache used by this loader."""

View File

@@ -1,6 +1,7 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team

 from abc import ABC, abstractmethod
+from typing import Optional, Set

 import torch
 from typing_extensions import Self
@@ -31,7 +32,7 @@ class ModelManagerServiceBase(ABC):
         model_record_service: ModelRecordServiceBase,
         download_queue: DownloadQueueServiceBase,
         events: EventServiceBase,
-        execution_device: torch.device,
+        execution_devices: Optional[Set[torch.device]] = None,
     ) -> Self:
         """
         Construct the model manager service instance.

View File

@@ -1,12 +1,13 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
 """Implementation of ModelManagerServiceBase."""

+from typing import Optional, Set
+
 import torch
 from typing_extensions import Self

 from invokeai.app.services.invoker import Invoker
 from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
-from invokeai.backend.util.devices import choose_torch_device
 from invokeai.backend.util.logging import InvokeAILogger

 from ..config import InvokeAIAppConfig
@@ -67,7 +68,7 @@ class ModelManagerService(ModelManagerServiceBase):
         model_record_service: ModelRecordServiceBase,
         download_queue: DownloadQueueServiceBase,
         events: EventServiceBase,
-        execution_device: torch.device = choose_torch_device(),
+        execution_devices: Optional[Set[torch.device]] = None,
     ) -> Self:
         """
         Construct the model manager service instance.
@@ -81,7 +82,7 @@ class ModelManagerService(ModelManagerServiceBase):
             max_cache_size=app_config.ram,
             max_vram_cache_size=app_config.vram,
             logger=logger,
-            execution_device=execution_device,
+            execution_devices=execution_devices,
         )
         convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
         loader = ModelLoadService(
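
The new `execution_devices` parameter flows from `build_model_manager()` straight into `ModelCache`. A hedged sketch of how a caller might pin the cache to particular GPUs (the device names are illustrative; leaving `execution_devices=None`, the new default, lets the cache discover devices itself):

```python
import torch

from invokeai.backend.model_manager.load import ModelCache

# Pin the cache to two specific GPUs; all other constructor arguments
# keep their defaults as shown in the signature above.
cache = ModelCache(
    execution_devices={torch.device("cuda:0"), torch.device("cuda:1")},
)
```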

View File

@@ -21,7 +21,7 @@ from .session_processor_common import SessionProcessorStatus

 class DefaultSessionProcessor(SessionProcessorBase):
-    def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
+    def start(self, invoker: Invoker, polling_interval: int = 1) -> None:
         self._invoker: Invoker = invoker
         self._queue_item: Optional[SessionQueueItem] = None
         self._invocation: Optional[BaseInvocation] = None
@@ -33,8 +33,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
         local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)

-        self._thread_limit = thread_limit
-        self._thread_semaphore = BoundedSemaphore(thread_limit)
+        self._thread_limit = self._invoker.services.model_manager.load.gpu_count
+        self._thread_semaphore = BoundedSemaphore(self._thread_limit)
         self._polling_interval = polling_interval

         # If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,
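
The intent here is that the number of concurrently executing sessions tracks the number of rendering GPUs rather than a fixed `thread_limit`. A minimal sketch of that concurrency shape, not the actual `DefaultSessionProcessor` internals; `sessions` is assumed to be a list of plain callables:

```python
from threading import BoundedSemaphore, Thread
from typing import Callable, List


def run_sessions(gpu_count: int, sessions: List[Callable[[], None]]) -> None:
    # One in-flight session per GPU: a worker must hold the semaphore
    # (and, further down the stack, an execution device) while it renders.
    semaphore = BoundedSemaphore(gpu_count)

    def worker(session: Callable[[], None]) -> None:
        with semaphore:
            session()  # stand-in for running one queue item's graph

    threads = [Thread(target=worker, args=(s,)) for s in sessions]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
```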

View File

@@ -10,7 +10,7 @@ model will be cleared and (re)loaded from disk when next needed.
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from logging import Logger
-from typing import Dict, Generic, Optional, TypeVar
+from typing import Dict, Generic, Optional, Set, TypeVar

 import torch
@@ -89,8 +89,24 @@ class ModelCacheBase(ABC, Generic[T]):
     @property
     @abstractmethod
-    def execution_device(self) -> torch.device:
-        """Return the exection device (e.g. "cuda" for VRAM)."""
+    def execution_devices(self) -> Set[torch.device]:
+        """Return the set of available execution devices."""
+        pass
+
+    @abstractmethod
+    def acquire_execution_device(self, timeout: int = 0) -> torch.device:
+        """
+        Pick the next available execution device.
+
+        If all devices are currently engaged (locked), then
+        block until timeout seconds have passed and raise a
+        TimeoutError if no devices are available.
+        """
+        pass
+
+    @abstractmethod
+    def release_execution_device(self, device: torch.device) -> None:
+        """Release a previously-acquired execution device."""
         pass

     @property
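
The new abstract methods define a check-out/check-in protocol for execution devices. A hedged usage sketch of how a caller is expected to pair them; `cache` stands in for any `ModelCacheBase` implementation and `work` is a placeholder callable:

```python
import torch


def run_on_next_free_gpu(cache, work, timeout: int = 600) -> None:
    # Block (up to `timeout` seconds) for a free device, use it, and
    # always return it to the pool, even if the work raises.
    device: torch.device = cache.acquire_execution_device(timeout)
    try:
        work(device)
    finally:
        cache.release_execution_device(device)
```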

View File

@@ -24,7 +24,8 @@ import sys
 import time
 from contextlib import suppress
 from logging import Logger
-from typing import Dict, List, Optional
+from threading import BoundedSemaphore, Lock
+from typing import Dict, List, Optional, Set

 import torch
@@ -60,8 +61,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self,
         max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
         max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
-        execution_device: torch.device = torch.device("cuda"),
         storage_device: torch.device = torch.device("cpu"),
+        execution_devices: Optional[Set[torch.device]] = None,
         precision: torch.dtype = torch.float16,
         sequential_offload: bool = False,
         lazy_offloading: bool = True,
@@ -73,7 +74,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         Initialize the model RAM cache.

         :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
-        :param execution_device: Torch device to load active model into [torch.device('cuda')]
+        :param execution_devices: Set of torch device to load active model into [calculated]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
         :param precision: Precision for loaded models [torch.float16]
         :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
@@ -88,7 +89,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._precision: torch.dtype = precision
         self._max_cache_size: float = max_cache_size
         self._max_vram_cache_size: float = max_vram_cache_size
-        self._execution_device: torch.device = execution_device
+        self._execution_devices: Set[torch.device] = execution_devices or self._get_execution_devices()
         self._storage_device: torch.device = storage_device
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
         self._log_memory_usage = log_memory_usage
@@ -97,6 +98,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
         self._cache_stack: List[str] = []

+        self._lock = Lock()
+        self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
+        self._busy_execution_devices: Set[torch.device] = set()
+
+        self.logger.info(f"Using rendering device(s) {[self._device_name(x) for x in self._execution_devices]}")
+
     @property
     def logger(self) -> Logger:
         """Return the logger used by the cache."""
@@ -113,9 +120,24 @@ class ModelCache(ModelCacheBase[AnyModel]):
         return self._storage_device

     @property
-    def execution_device(self) -> torch.device:
-        """Return the exection device (e.g. "cuda" for VRAM)."""
-        return self._execution_device
+    def execution_devices(self) -> Set[torch.device]:
+        """Return the set of available execution devices."""
+        return self._execution_devices
+
+    def acquire_execution_device(self, timeout: int = 0) -> torch.device:
+        """Acquire and return an execution device (e.g. "cuda" for VRAM)."""
+        with self._lock:
+            self._free_execution_device.acquire(timeout=timeout)
+            free_devices = self.execution_devices - self._busy_execution_devices
+            chosen_device = list(free_devices)[0]
+            self._busy_execution_devices.add(chosen_device)
+        return chosen_device
+
+    def release_execution_device(self, device: torch.device) -> None:
+        """Mark this execution device as unused."""
+        with self._lock:
+            self._free_execution_device.release()
+            self._busy_execution_devices.remove(device)

     @property
     def max_cache_size(self) -> float:
@@ -422,3 +444,17 @@ class ModelCache(ModelCacheBase[AnyModel]):
         free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
         if needed_size > free_mem:
             raise torch.cuda.OutOfMemoryError
+
+    @staticmethod
+    def _get_execution_devices() -> Set[torch.device]:
+        default_device = choose_torch_device()
+        if default_device != torch.device("cuda"):
+            return {default_device}
+
+        # we get here if the default device is cuda, and return each of the
+        # cuda devices.
+        return {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
+
+    @staticmethod
+    def _device_name(device: torch.device) -> str:
+        return f"{device.type}:{device.index}"
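
The acquire/release implementation pairs a counting semaphore with a busy set guarded by a lock. A self-contained sketch of the same bookkeeping, pulled out of the cache for clarity; the explicit `TimeoutError` and the ordering of semaphore vs. lock here are choices of this sketch, not necessarily the commit's:

```python
from threading import BoundedSemaphore, Lock
from typing import Set

import torch


class DevicePool:
    """Stand-alone illustration of the cache's device bookkeeping."""

    def __init__(self, devices: Set[torch.device]) -> None:
        self._devices = devices
        self._lock = Lock()
        self._free = BoundedSemaphore(len(devices))   # counts free devices
        self._busy: Set[torch.device] = set()         # records which are out

    def acquire(self, timeout: float = 10.0) -> torch.device:
        if not self._free.acquire(timeout=timeout):
            raise TimeoutError("no free execution device")
        with self._lock:
            device = next(iter(self._devices - self._busy))
            self._busy.add(device)
        return device

    def release(self, device: torch.device) -> None:
        with self._lock:
            self._busy.remove(device)
        self._free.release()


if __name__ == "__main__":
    pool = DevicePool({torch.device("cuda:0"), torch.device("cuda:1")})
    gpu = pool.acquire()
    print(f"rendering on {gpu}")  # a third concurrent acquire would block
    pool.release(gpu)
```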

View File

@@ -2,12 +2,16 @@
 Base class and implementation of a class that moves models in and out of VRAM.
 """

+from typing import Optional
+
 import torch

 from invokeai.backend.model_manager import AnyModel

 from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase

+MAX_GPU_WAIT = 600  # wait up to 10 minutes for a GPU to become free
+

 class ModelLocker(ModelLockerBase):
     """Internal class that mediates movement in and out of GPU."""
@@ -21,6 +25,7 @@ class ModelLocker(ModelLockerBase):
         """
         self._cache = cache
         self._cache_entry = cache_entry
+        self._execution_device: Optional[torch.device] = None

     @property
     def model(self) -> AnyModel:
@@ -39,10 +44,12 @@ class ModelLocker(ModelLockerBase):
             if self._cache.lazy_offloading:
                 self._cache.offload_unlocked_models(self._cache_entry.size)

-            self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
+            # We wait for a gpu to be free - may raise a TimeoutError
+            self._execution_device = self._cache.acquire_execution_device(MAX_GPU_WAIT)
+            self._cache.move_model_to_device(self._cache_entry, self._execution_device)
             self._cache_entry.loaded = True

-            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
+            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._execution_device}")
             self._cache.print_cuda_stats()

         except torch.cuda.OutOfMemoryError:
             self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
@@ -59,6 +66,8 @@ class ModelLocker(ModelLockerBase):
             return

         self._cache_entry.unlock()
+        if self._execution_device:
+            self._cache.release_execution_device(self._execution_device)
         if not self._cache.lazy_offloading:
             self._cache.offload_unlocked_models(self._cache_entry.size)
             self._cache.print_cuda_stats()
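
With this change, `lock()` is also where a GPU is checked out of the pool (and may raise `TimeoutError` after `MAX_GPU_WAIT` seconds), and `unlock()` is where it must be returned. A hedged sketch of the calling pattern; `locker` is any `ModelLocker` handed out by the cache, `invoke` is a placeholder, and the assumption that `lock()` returns the loaded model is not shown in this diff:

```python
def run_with_locked_model(locker, invoke) -> None:
    # lock() waits for a free execution device, moves the model onto it,
    # and records the device; unlock() hands the device back to the cache.
    model = locker.lock()  # assumed to return the model it just moved
    try:
        invoke(model)
    finally:
        locker.unlock()
```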

View File

@@ -111,7 +111,7 @@ def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> Downloa
 @pytest.fixture
-def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
+def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase:
     ram_cache = ModelCache(
         logger=InvokeAILogger.get_logger(),
         max_cache_size=mm2_app_config.ram,