add draft multi-gpu support

Lincoln Stein 2024-02-19 15:21:55 -05:00
parent 74a51571a0
commit 6b991a5269
9 changed files with 94 additions and 20 deletions

View File

@ -38,3 +38,8 @@ class ModelLoadServiceBase(ABC):
@abstractmethod
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""
@property
@abstractmethod
def gpu_count(self) -> int:
"""Return the number of GPUs we are configured to use."""

View File

@ -39,6 +39,7 @@ class ModelLoadService(ModelLoadServiceBase):
self._registry = registry
def start(self, invoker: Invoker) -> None:
"""Start the service."""
self._invoker = invoker
@property
@ -46,6 +47,11 @@ class ModelLoadService(ModelLoadServiceBase):
"""Return the RAM cache used by this loader."""
return self._ram_cache
@property
def gpu_count(self) -> int:
"""Return the number of GPUs available for our uses."""
return len(self._ram_cache.execution_devices)
@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""

View File

@ -1,6 +1,7 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from abc import ABC, abstractmethod
from typing import Optional, Set
import torch
from typing_extensions import Self
@ -31,7 +32,7 @@ class ModelManagerServiceBase(ABC):
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
execution_device: torch.device,
execution_devices: Optional[Set[torch.device]] = None,
) -> Self:
"""
Construct the model manager service instance.

View File

@ -1,12 +1,13 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""
from typing import Optional, Set
import torch
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from ..config import InvokeAIAppConfig
@ -67,7 +68,7 @@ class ModelManagerService(ModelManagerServiceBase):
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
execution_device: torch.device = choose_torch_device(),
execution_devices: Optional[Set[torch.device]] = None,
) -> Self:
"""
Construct the model manager service instance.
@ -81,7 +82,7 @@ class ModelManagerService(ModelManagerServiceBase):
max_cache_size=app_config.ram,
max_vram_cache_size=app_config.vram,
logger=logger,
execution_device=execution_device,
execution_devices=execution_devices,
)
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
loader = ModelLoadService(
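
A hedged construction sketch follows; the classmethod name build_model_manager, the position of the app config argument, and the surrounding service objects are assumptions rather than something shown in this hunk:

# Sketch: pin the model manager to two specific GPUs instead of auto-detection.
# app_config, record_store, download_queue and event_bus are assumed to exist.
import torch

manager = ModelManagerService.build_model_manager(
    app_config,
    model_record_service=record_store,
    download_queue=download_queue,
    events=event_bus,
    execution_devices={torch.device("cuda:0"), torch.device("cuda:1")},
)

Leaving execution_devices as None keeps the auto-detection added to ModelCache later in this commit.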

View File

@ -21,7 +21,7 @@ from .session_processor_common import SessionProcessorStatus
class DefaultSessionProcessor(SessionProcessorBase):
def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
def start(self, invoker: Invoker, polling_interval: int = 1) -> None:
self._invoker: Invoker = invoker
self._queue_item: Optional[SessionQueueItem] = None
self._invocation: Optional[BaseInvocation] = None
@ -33,8 +33,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
self._thread_limit = thread_limit
self._thread_semaphore = BoundedSemaphore(thread_limit)
self._thread_limit = self._invoker.services.model_manager.load.gpu_count
self._thread_semaphore = BoundedSemaphore(self._thread_limit)
self._polling_interval = polling_interval
# If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,
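
The effect of this change, sketched as standalone code (DefaultSessionProcessor's real loop is more involved, and process_session here is a hypothetical stand-in for the per-session work):

# Sketch: bound concurrent sessions to the number of GPUs reported by the loader.
from threading import BoundedSemaphore, Thread

def run_sessions(invoker, sessions):
    limit = invoker.services.model_manager.load.gpu_count
    gpu_slots = BoundedSemaphore(limit)

    def worker(session):
        with gpu_slots:                  # at most `limit` sessions execute at once
            process_session(session)     # hypothetical per-session work

    threads = [Thread(target=worker, args=(s,)) for s in sessions]
    for t in threads:
        t.start()
    for t in threads:
        t.join()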

View File

@ -10,7 +10,7 @@ model will be cleared and (re)loaded from disk when next needed.
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from logging import Logger
from typing import Dict, Generic, Optional, TypeVar
from typing import Dict, Generic, Optional, Set, TypeVar
import torch
@ -89,8 +89,24 @@ class ModelCacheBase(ABC, Generic[T]):
@property
@abstractmethod
def execution_device(self) -> torch.device:
"""Return the exection device (e.g. "cuda" for VRAM)."""
def execution_devices(self) -> Set[torch.device]:
"""Return the set of available execution devices."""
pass
@abstractmethod
def acquire_execution_device(self, timeout: int = 0) -> torch.device:
"""
Pick the next available execution device.
If all devices are currently engaged (locked), block for up
to timeout seconds, then raise a TimeoutError if no device
has become available.
"""
pass
@abstractmethod
def release_execution_device(self, device: torch.device) -> None:
"""Release a previously-acquired execution device."""
pass
@property
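
A usage sketch for the new acquire/release pair; the caller code below is illustrative and not part of this commit:

# Sketch: always pair acquisition with release, even if inference fails.
device = cache.acquire_execution_device(timeout=600)   # may raise TimeoutError
try:
    run_inference_on(device)                            # hypothetical work on the device
finally:
    cache.release_execution_device(device)              # return the GPU to the pool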

View File

@ -24,7 +24,8 @@ import sys
import time
from contextlib import suppress
from logging import Logger
from typing import Dict, List, Optional
from threading import BoundedSemaphore, Lock
from typing import Dict, List, Optional, Set
import torch
@ -60,8 +61,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
self,
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
execution_devices: Optional[Set[torch.device]] = None,
precision: torch.dtype = torch.float16,
sequential_offload: bool = False,
lazy_offloading: bool = True,
@ -73,7 +74,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
Initialize the model RAM cache.
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param execution_devices: Set of torch devices to load the active model into [calculated]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
@ -88,7 +89,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
self._precision: torch.dtype = precision
self._max_cache_size: float = max_cache_size
self._max_vram_cache_size: float = max_vram_cache_size
self._execution_device: torch.device = execution_device
self._execution_devices: Set[torch.device] = execution_devices or self._get_execution_devices()
self._storage_device: torch.device = storage_device
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
self._log_memory_usage = log_memory_usage
@ -97,6 +98,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
self._cache_stack: List[str] = []
self._lock = Lock()
self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
self._busy_execution_devices: Set[torch.device] = set()
self.logger.info(f"Using rendering device(s) {[self._device_name(x) for x in self._execution_devices]}")
@property
def logger(self) -> Logger:
"""Return the logger used by the cache."""
@ -113,9 +120,24 @@ class ModelCache(ModelCacheBase[AnyModel]):
return self._storage_device
@property
def execution_device(self) -> torch.device:
"""Return the exection device (e.g. "cuda" for VRAM)."""
return self._execution_device
def execution_devices(self) -> Set[torch.device]:
"""Return the set of available execution devices."""
return self._execution_devices
def acquire_execution_device(self, timeout: int = 0) -> torch.device:
"""Acquire and return an execution device (e.g. "cuda" for VRAM)."""
with self._lock:
self._free_execution_device.acquire(timeout=timeout)
free_devices = self.execution_devices - self._busy_execution_devices
chosen_device = list(free_devices)[0]
self._busy_execution_devices.add(chosen_device)
return chosen_device
def release_execution_device(self, device: torch.device) -> None:
"""Mark this execution device as unused."""
with self._lock:
self._free_execution_device.release()
self._busy_execution_devices.remove(device)
@property
def max_cache_size(self) -> float:
@ -422,3 +444,17 @@ class ModelCache(ModelCacheBase[AnyModel]):
free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
if needed_size > free_mem:
raise torch.cuda.OutOfMemoryError
@staticmethod
def _get_execution_devices() -> Set[torch.device]:
default_device = choose_torch_device()
if default_device != torch.device("cuda"):
return {default_device}
# we get here if the default device is cuda, and return each of the
# cuda devices.
return {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
@staticmethod
def _device_name(device: torch.device) -> str:
return f"{device.type}:{device.index}"

View File

@ -2,12 +2,16 @@
Base class and implementation of a class that moves models in and out of VRAM.
"""
from typing import Optional
import torch
from invokeai.backend.model_manager import AnyModel
from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase
MAX_GPU_WAIT = 600 # wait up to 10 minutes for a GPU to become free
class ModelLocker(ModelLockerBase):
"""Internal class that mediates movement in and out of GPU."""
@ -21,6 +25,7 @@ class ModelLocker(ModelLockerBase):
"""
self._cache = cache
self._cache_entry = cache_entry
self._execution_device: Optional[torch.device] = None
@property
def model(self) -> AnyModel:
@ -39,10 +44,12 @@ class ModelLocker(ModelLockerBase):
if self._cache.lazy_offloading:
self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
# Wait for a GPU to become free; may raise a TimeoutError.
self._execution_device = self._cache.acquire_execution_device(MAX_GPU_WAIT)
self._cache.move_model_to_device(self._cache_entry, self._execution_device)
self._cache_entry.loaded = True
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._execution_device}")
self._cache.print_cuda_stats()
except torch.cuda.OutOfMemoryError:
self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
@ -59,6 +66,8 @@ class ModelLocker(ModelLockerBase):
return
self._cache_entry.unlock()
if self._execution_device:
self._cache.release_execution_device(self._execution_device)
if not self._cache.lazy_offloading:
self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.print_cuda_stats()
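
A usage sketch of the locker protocol; lockers are normally created by the cache itself, so the construction and the inference call below are assumptions:

# Sketch: balanced lock()/unlock() guarantees the acquired GPU goes back to the pool.
locker = ModelLocker(cache, cache_entry)   # normally produced by the cache internals
locker.lock()                              # may wait up to MAX_GPU_WAIT seconds for a free GPU
try:
    run_inference(locker.model)            # hypothetical work on the locked model
finally:
    locker.unlock()                        # releases the acquired execution device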

View File

@ -111,7 +111,7 @@ def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> Downloa
@pytest.fixture
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase:
ram_cache = ModelCache(
logger=InvokeAILogger.get_logger(),
max_cache_size=mm2_app_config.ram,