Mirror of https://github.com/invoke-ai/InvokeAI
add draft multi-gpu support
Parent: 74a51571a0
Commit: 6b991a5269
@@ -38,3 +38,8 @@ class ModelLoadServiceBase(ABC):
     @abstractmethod
     def convert_cache(self) -> ModelConvertCacheBase:
         """Return the checkpoint convert cache used by this loader."""
+
+    @property
+    @abstractmethod
+    def gpu_count(self) -> int:
+        """Return the number of GPUs we are configured to use."""

@@ -39,6 +39,7 @@ class ModelLoadService(ModelLoadServiceBase):
         self._registry = registry

     def start(self, invoker: Invoker) -> None:
         """Start the service."""
         self._invoker = invoker

     @property

@@ -46,6 +47,11 @@ class ModelLoadService(ModelLoadServiceBase):
         """Return the RAM cache used by this loader."""
         return self._ram_cache

+    @property
+    def gpu_count(self) -> int:
+        """Return the number of GPUs available for our uses."""
+        return len(self._ram_cache.execution_devices)
+
     @property
     def convert_cache(self) -> ModelConvertCacheBase:
         """Return the checkpoint convert cache used by this loader."""

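Downstream code reaches the new property through the invoker's service registry, which is exactly what the session-processor hunk further down does. A minimal, hedged illustration (the invoker argument and the print are assumptions for the example; the attribute chain mirrors that hunk):

def report_gpu_count(invoker) -> int:
    # `invoker.services.model_manager.load` is the ModelLoadService shown above;
    # gpu_count is the size of the cache's execution-device set.
    count = invoker.services.model_manager.load.gpu_count
    print(f"model loader reports {count} execution device(s)")
    return count
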
@@ -1,6 +1,7 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team

 from abc import ABC, abstractmethod
+from typing import Optional, Set

 import torch
 from typing_extensions import Self

@@ -31,7 +32,7 @@ class ModelManagerServiceBase(ABC):
         model_record_service: ModelRecordServiceBase,
         download_queue: DownloadQueueServiceBase,
         events: EventServiceBase,
-        execution_device: torch.device,
+        execution_devices: Optional[Set[torch.device]] = None,
     ) -> Self:
         """
         Construct the model manager service instance.

@@ -1,12 +1,13 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
 """Implementation of ModelManagerServiceBase."""

 from typing import Optional, Set

 import torch
 from typing_extensions import Self

 from invokeai.app.services.invoker import Invoker
 from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
 from invokeai.backend.util.devices import choose_torch_device
 from invokeai.backend.util.logging import InvokeAILogger

 from ..config import InvokeAIAppConfig

@@ -67,7 +68,7 @@ class ModelManagerService(ModelManagerServiceBase):
         model_record_service: ModelRecordServiceBase,
         download_queue: DownloadQueueServiceBase,
         events: EventServiceBase,
-        execution_device: torch.device = choose_torch_device(),
+        execution_devices: Optional[Set[torch.device]] = None,
     ) -> Self:
         """
         Construct the model manager service instance.

@@ -81,7 +82,7 @@ class ModelManagerService(ModelManagerServiceBase):
             max_cache_size=app_config.ram,
             max_vram_cache_size=app_config.vram,
             logger=logger,
-            execution_device=execution_device,
+            execution_devices=execution_devices,
         )
         convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
         loader = ModelLoadService(

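For illustration, a caller could pin the cache to specific cards by passing an explicit device set through the new execution_devices keyword; omitting it falls back to automatic discovery. A minimal sketch, assuming the import paths shown in the hunks above and a default-constructed InvokeAIAppConfig (normally the app supplies this object):

import torch

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.load import ModelCache
from invokeai.backend.util.logging import InvokeAILogger

app_config = InvokeAIAppConfig()  # illustrative; the running app constructs this itself
ram_cache = ModelCache(
    max_cache_size=app_config.ram,
    max_vram_cache_size=app_config.vram,
    logger=InvokeAILogger.get_logger("ModelCache"),
    # Restrict rendering to the first two CUDA cards instead of every visible GPU.
    execution_devices={torch.device("cuda:0"), torch.device("cuda:1")},
)
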
@@ -21,7 +21,7 @@ from .session_processor_common import SessionProcessorStatus


 class DefaultSessionProcessor(SessionProcessorBase):
-    def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
+    def start(self, invoker: Invoker, polling_interval: int = 1) -> None:
         self._invoker: Invoker = invoker
         self._queue_item: Optional[SessionQueueItem] = None
         self._invocation: Optional[BaseInvocation] = None

@@ -33,8 +33,8 @@ class DefaultSessionProcessor(SessionProcessorBase):

         local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)

-        self._thread_limit = thread_limit
-        self._thread_semaphore = BoundedSemaphore(thread_limit)
+        self._thread_limit = self._invoker.services.model_manager.load.gpu_count
+        self._thread_semaphore = BoundedSemaphore(self._thread_limit)
         self._polling_interval = polling_interval

         # If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,

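With the thread limit now derived from gpu_count, the processor can run one session per GPU while the BoundedSemaphore keeps any excess sessions queued. Stripped of the InvokeAI services, the concurrency pattern looks like the standalone sketch below (all names here are illustrative):

import threading
import time

gpu_count = 2  # stand-in for model_manager.load.gpu_count
thread_semaphore = threading.BoundedSemaphore(gpu_count)

def run_session(session_id: int) -> None:
    with thread_semaphore:  # at most gpu_count sessions execute at once
        print(f"session {session_id} running")
        time.sleep(0.1)     # placeholder for graph execution on one GPU

threads = [threading.Thread(target=run_session, args=(i,)) for i in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
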
@@ -10,7 +10,7 @@ model will be cleared and (re)loaded from disk when next needed.
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from logging import Logger
-from typing import Dict, Generic, Optional, TypeVar
+from typing import Dict, Generic, Optional, Set, TypeVar

 import torch

@@ -89,8 +89,24 @@ class ModelCacheBase(ABC, Generic[T]):

     @property
     @abstractmethod
-    def execution_device(self) -> torch.device:
-        """Return the exection device (e.g. "cuda" for VRAM)."""
+    def execution_devices(self) -> Set[torch.device]:
+        """Return the set of available execution devices."""
         pass

+    @abstractmethod
+    def acquire_execution_device(self, timeout: int = 0) -> torch.device:
+        """
+        Pick the next available execution device.
+
+        If all devices are currently engaged (locked), then
+        block until timeout seconds have passed and raise a
+        TimeoutError if no devices are available.
+        """
+        pass
+
+    @abstractmethod
+    def release_execution_device(self, device: torch.device) -> None:
+        """Release a previously-acquired execution device."""
+        pass
+
     @property

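The calling convention implied by these abstract methods is acquire-then-release around each unit of GPU work, with release guaranteed even on failure. A minimal sketch of that pattern, where cache is any ModelCacheBase implementation and run_inference is a placeholder for the caller's work:

def run_on_free_gpu(cache, run_inference, timeout: int = 600):
    device = cache.acquire_execution_device(timeout)  # may raise TimeoutError
    try:
        return run_inference(device)
    finally:
        # Hand the device back even if inference raises, so other threads can proceed.
        cache.release_execution_device(device)
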
@@ -24,7 +24,8 @@ import sys
 import time
 from contextlib import suppress
 from logging import Logger
-from typing import Dict, List, Optional
+from threading import BoundedSemaphore, Lock
+from typing import Dict, List, Optional, Set

 import torch

@@ -60,8 +61,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self,
         max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
         max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
-        execution_device: torch.device = torch.device("cuda"),
         storage_device: torch.device = torch.device("cpu"),
+        execution_devices: Optional[Set[torch.device]] = None,
         precision: torch.dtype = torch.float16,
         sequential_offload: bool = False,
         lazy_offloading: bool = True,

@@ -73,7 +74,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         Initialize the model RAM cache.

         :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
-        :param execution_device: Torch device to load active model into [torch.device('cuda')]
+        :param execution_devices: Set of torch device to load active model into [calculated]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
         :param precision: Precision for loaded models [torch.float16]
         :param lazy_offloading: Keep model in VRAM until another model needs to be loaded

@@ -88,7 +89,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._precision: torch.dtype = precision
         self._max_cache_size: float = max_cache_size
         self._max_vram_cache_size: float = max_vram_cache_size
-        self._execution_device: torch.device = execution_device
+        self._execution_devices: Set[torch.device] = execution_devices or self._get_execution_devices()
         self._storage_device: torch.device = storage_device
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
         self._log_memory_usage = log_memory_usage

@@ -97,6 +98,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
         self._cache_stack: List[str] = []

+        self._lock = Lock()
+        self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
+        self._busy_execution_devices: Set[torch.device] = set()
+
+        self.logger.info(f"Using rendering device(s) {[self._device_name(x) for x in self._execution_devices]}")
+
     @property
     def logger(self) -> Logger:
         """Return the logger used by the cache."""

@@ -113,9 +120,24 @@ class ModelCache(ModelCacheBase[AnyModel]):
         return self._storage_device

     @property
-    def execution_device(self) -> torch.device:
-        """Return the exection device (e.g. "cuda" for VRAM)."""
-        return self._execution_device
+    def execution_devices(self) -> Set[torch.device]:
+        """Return the set of available execution devices."""
+        return self._execution_devices
+
+    def acquire_execution_device(self, timeout: int = 0) -> torch.device:
+        """Acquire and return an execution device (e.g. "cuda" for VRAM)."""
+        with self._lock:
+            self._free_execution_device.acquire(timeout=timeout)
+            free_devices = self.execution_devices - self._busy_execution_devices
+            chosen_device = list(free_devices)[0]
+            self._busy_execution_devices.add(chosen_device)
+            return chosen_device
+
+    def release_execution_device(self, device: torch.device) -> None:
+        """Mark this execution device as unused."""
+        with self._lock:
+            self._free_execution_device.release()
+            self._busy_execution_devices.remove(device)

     @property
     def max_cache_size(self) -> float:

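Stripped of the model-cache machinery, the bookkeeping above amounts to a counting semaphore (one slot per GPU) plus a lock-guarded set of busy devices. Below is a standalone sketch of that idea using a hypothetical DevicePool class rather than InvokeAI code; as a design note, this sketch waits on the semaphore before taking the lock, so a thread blocked waiting for a device never holds the lock that release() needs:

import threading
from typing import Set

class DevicePool:
    """Minimal illustration of semaphore-plus-busy-set device bookkeeping."""

    def __init__(self, devices: Set[str]) -> None:
        self._devices = set(devices)
        self._busy: Set[str] = set()
        self._lock = threading.Lock()
        self._free = threading.BoundedSemaphore(len(devices))

    def acquire(self, timeout: float = 10.0) -> str:
        # Wait for a free slot, then pick any device not currently marked busy.
        if not self._free.acquire(timeout=timeout):
            raise TimeoutError("no execution device became free in time")
        with self._lock:
            device = next(iter(self._devices - self._busy))
            self._busy.add(device)
        return device

    def release(self, device: str) -> None:
        with self._lock:
            self._busy.remove(device)
        self._free.release()

pool = DevicePool({"cuda:0", "cuda:1"})
device = pool.acquire()
print(f"running on {device}")
pool.release(device)
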
@@ -422,3 +444,17 @@ class ModelCache(ModelCacheBase[AnyModel]):
         free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
         if needed_size > free_mem:
             raise torch.cuda.OutOfMemoryError
+
+    @staticmethod
+    def _get_execution_devices() -> Set[torch.device]:
+        default_device = choose_torch_device()
+        if default_device != torch.device("cuda"):
+            return {default_device}
+
+        # we get here if the default device is cuda, and return each of the
+        # cuda devices.
+        return {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
+
+    @staticmethod
+    def _device_name(device: torch.device) -> str:
+        return f"{device.type}:{device.index}"

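As a concrete check of the discovery logic, the snippet below reproduces the same enumeration without the private helper: every visible CUDA card, or a single fallback device (CPU here; the real helper defers to choose_torch_device) when CUDA is unavailable. It needs only torch:

import torch

if torch.cuda.is_available():
    devices = {torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())}
else:
    devices = {torch.device("cpu")}

print(sorted(str(d) for d in devices))  # e.g. ['cuda:0', 'cuda:1'] on a two-GPU host
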
@@ -2,12 +2,16 @@
 Base class and implementation of a class that moves models in and out of VRAM.
 """

+from typing import Optional
+
 import torch

 from invokeai.backend.model_manager import AnyModel

 from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase

+MAX_GPU_WAIT = 600  # wait up to 10 minutes for a GPU to become free
+

 class ModelLocker(ModelLockerBase):
     """Internal class that mediates movement in and out of GPU."""

@@ -21,6 +25,7 @@ class ModelLocker(ModelLockerBase):
         """
         self._cache = cache
         self._cache_entry = cache_entry
+        self._execution_device: Optional[torch.device] = None

     @property
     def model(self) -> AnyModel:

@@ -39,10 +44,12 @@ class ModelLocker(ModelLockerBase):
             if self._cache.lazy_offloading:
                 self._cache.offload_unlocked_models(self._cache_entry.size)

-            self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
+            # We wait for a gpu to be free - may raise a TimeoutError
+            self._execution_device = self._cache.acquire_execution_device(MAX_GPU_WAIT)
+            self._cache.move_model_to_device(self._cache_entry, self._execution_device)
             self._cache_entry.loaded = True

-            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
+            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._execution_device}")
             self._cache.print_cuda_stats()
         except torch.cuda.OutOfMemoryError:
             self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")

@@ -59,6 +66,8 @@ class ModelLocker(ModelLockerBase):
             return

         self._cache_entry.unlock()
+        if self._execution_device:
+            self._cache.release_execution_device(self._execution_device)
         if not self._cache.lazy_offloading:
             self._cache.offload_unlocked_models(self._cache_entry.size)
             self._cache.print_cuda_stats()

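Taken together, the two ModelLocker hunks define the per-model lifecycle: lock() borrows a free device from the cache (blocking up to MAX_GPU_WAIT seconds) and moves the model onto it, and unlock() hands the device back. A hedged sketch of how calling code might drive that lifecycle; the locker and run_pipeline names are placeholders, not the exact InvokeAI API:

def run_with_locked_model(locker, run_pipeline):
    model = locker.lock()  # acquires a free GPU and moves the model onto it
    try:
        return run_pipeline(model)
    finally:
        locker.unlock()    # releases the GPU so the next session can use it
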
@@ -111,7 +111,7 @@ def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> Downloa


 @pytest.fixture
-def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
+def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase:
     ram_cache = ModelCache(
         logger=InvokeAILogger.get_logger(),
         max_cache_size=mm2_app_config.ram,