Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00

Commit 6b991a5269 (parent 74a51571a0): add draft multi-gpu support
@@ -38,3 +38,8 @@ class ModelLoadServiceBase(ABC):
     @abstractmethod
     def convert_cache(self) -> ModelConvertCacheBase:
         """Return the checkpoint convert cache used by this loader."""
+
+    @property
+    @abstractmethod
+    def gpu_count(self) -> int:
+        """Return the number of GPUs we are configured to use."""
@@ -39,6 +39,7 @@ class ModelLoadService(ModelLoadServiceBase):
         self._registry = registry

     def start(self, invoker: Invoker) -> None:
+        """Start the service."""
         self._invoker = invoker

     @property
@@ -46,6 +47,11 @@ class ModelLoadService(ModelLoadServiceBase):
         """Return the RAM cache used by this loader."""
         return self._ram_cache

+    @property
+    def gpu_count(self) -> int:
+        """Return the number of GPUs available for our uses."""
+        return len(self._ram_cache.execution_devices)
+
     @property
     def convert_cache(self) -> ModelConvertCacheBase:
         """Return the checkpoint convert cache used by this loader."""
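The new gpu_count property is just the size of the cache's execution-device set, which gives callers a single integer to size worker pools against. A minimal illustrative sketch (not from this commit; the device set is hypothetical):

# Illustration of what ModelLoadService.gpu_count reports.
from typing import Set

import torch

execution_devices: Set[torch.device] = {torch.device("cuda:0"), torch.device("cuda:1")}
gpu_count = len(execution_devices)  # what the property returns for this cache
print(gpu_count)  # 2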
@@ -1,6 +1,7 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team

 from abc import ABC, abstractmethod
+from typing import Optional, Set

 import torch
 from typing_extensions import Self
@@ -31,7 +32,7 @@ class ModelManagerServiceBase(ABC):
         model_record_service: ModelRecordServiceBase,
         download_queue: DownloadQueueServiceBase,
         events: EventServiceBase,
-        execution_device: torch.device,
+        execution_devices: Optional[Set[torch.device]] = None,
     ) -> Self:
         """
         Construct the model manager service instance.
@@ -1,12 +1,13 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
 """Implementation of ModelManagerServiceBase."""

+from typing import Optional, Set
+
 import torch
 from typing_extensions import Self

 from invokeai.app.services.invoker import Invoker
 from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
-from invokeai.backend.util.devices import choose_torch_device
 from invokeai.backend.util.logging import InvokeAILogger

 from ..config import InvokeAIAppConfig
@@ -67,7 +68,7 @@ class ModelManagerService(ModelManagerServiceBase):
         model_record_service: ModelRecordServiceBase,
         download_queue: DownloadQueueServiceBase,
         events: EventServiceBase,
-        execution_device: torch.device = choose_torch_device(),
+        execution_devices: Optional[Set[torch.device]] = None,
     ) -> Self:
         """
         Construct the model manager service instance.
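Because the constructor now takes an optional device set instead of a single device, a caller can either omit the argument and let the cache discover the devices itself, or pin the service to specific GPUs. A hedged sketch of a call site, assuming the surrounding services already exist and that the classmethod is build_model_manager as in the repository (neither is shown in this hunk):

import torch

# Pin the manager to two specific GPUs; pass execution_devices=None (or omit it)
# to let the cache enumerate every visible CUDA device instead.
model_manager = ModelManagerService.build_model_manager(
    app_config=app_config,
    model_record_service=model_record_service,
    download_queue=download_queue,
    events=events,
    execution_devices={torch.device("cuda:0"), torch.device("cuda:1")},
)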
@@ -81,7 +82,7 @@ class ModelManagerService(ModelManagerServiceBase):
             max_cache_size=app_config.ram,
             max_vram_cache_size=app_config.vram,
             logger=logger,
-            execution_device=execution_device,
+            execution_devices=execution_devices,
         )
         convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
         loader = ModelLoadService(
@@ -21,7 +21,7 @@ from .session_processor_common import SessionProcessorStatus


 class DefaultSessionProcessor(SessionProcessorBase):
-    def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
+    def start(self, invoker: Invoker, polling_interval: int = 1) -> None:
         self._invoker: Invoker = invoker
         self._queue_item: Optional[SessionQueueItem] = None
         self._invocation: Optional[BaseInvocation] = None
@@ -33,8 +33,8 @@ class DefaultSessionProcessor(SessionProcessorBase):

         local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)

-        self._thread_limit = thread_limit
-        self._thread_semaphore = BoundedSemaphore(thread_limit)
+        self._thread_limit = self._invoker.services.model_manager.load.gpu_count
+        self._thread_semaphore = BoundedSemaphore(self._thread_limit)
         self._polling_interval = polling_interval

         # If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,
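The effect of the two changed lines is that session concurrency now follows the hardware: each worker must hold the semaphore while executing a graph, so at most gpu_count sessions run at once. A self-contained sketch of the same pattern (device count and session bodies are illustrative, not from this commit):

from threading import BoundedSemaphore, Thread

gpu_count = 2  # e.g. invoker.services.model_manager.load.gpu_count
thread_semaphore = BoundedSemaphore(gpu_count)

def run_session(session_id: int) -> None:
    # At most `gpu_count` sessions hold the semaphore (and hence a GPU) at a time.
    with thread_semaphore:
        print(f"session {session_id} running")

threads = [Thread(target=run_session, args=(i,)) for i in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()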
@@ -10,7 +10,7 @@ model will be cleared and (re)loaded from disk when next needed.
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from logging import Logger
-from typing import Dict, Generic, Optional, TypeVar
+from typing import Dict, Generic, Optional, Set, TypeVar

 import torch

@@ -89,8 +89,24 @@ class ModelCacheBase(ABC, Generic[T]):

     @property
     @abstractmethod
-    def execution_device(self) -> torch.device:
-        """Return the exection device (e.g. "cuda" for VRAM)."""
+    def execution_devices(self) -> Set[torch.device]:
+        """Return the set of available execution devices."""
+        pass
+
+    @abstractmethod
+    def acquire_execution_device(self, timeout: int = 0) -> torch.device:
+        """
+        Pick the next available execution device.
+
+        If all devices are currently engaged (locked), then
+        block until timeout seconds have passed and raise a
+        TimeoutError if no devices are available.
+        """
+        pass
+
+    @abstractmethod
+    def release_execution_device(self, device: torch.device) -> None:
+        """Release a previously-acquired execution device."""
         pass

     @property
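The contract implied by these abstract methods is acquire-then-release with a bounded wait. A hedged usage sketch against any ModelCacheBase implementation; `cache` and `run_inference_on` are placeholders, not names from this commit:

# Acquire a device, do work on it, and always hand it back, even on error.
device = cache.acquire_execution_device(timeout=600)  # may raise TimeoutError
try:
    run_inference_on(device)  # placeholder for real work on that GPU
finally:
    cache.release_execution_device(device)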
@@ -24,7 +24,8 @@ import sys
 import time
 from contextlib import suppress
 from logging import Logger
-from typing import Dict, List, Optional
+from threading import BoundedSemaphore, Lock
+from typing import Dict, List, Optional, Set

 import torch

@@ -60,8 +61,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self,
         max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
         max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
-        execution_device: torch.device = torch.device("cuda"),
         storage_device: torch.device = torch.device("cpu"),
+        execution_devices: Optional[Set[torch.device]] = None,
         precision: torch.dtype = torch.float16,
         sequential_offload: bool = False,
         lazy_offloading: bool = True,
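With this signature change, callers that passed `execution_device=` must switch to the optional set (or omit it and let the cache probe for devices). A minimal construction sketch; the cache sizes are illustrative values, not recommendations from this commit:

import torch

cache = ModelCache(
    max_cache_size=6.0,        # GB of RAM for inactive models (illustrative)
    max_vram_cache_size=0.25,  # GB of VRAM for the cache (illustrative)
    storage_device=torch.device("cpu"),
    execution_devices={torch.device("cuda:0"), torch.device("cuda:1")},
    precision=torch.float16,
)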
@@ -73,7 +74,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         Initialize the model RAM cache.

         :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
-        :param execution_device: Torch device to load active model into [torch.device('cuda')]
+        :param execution_devices: Set of torch device to load active model into [calculated]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
         :param precision: Precision for loaded models [torch.float16]
         :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
@@ -88,7 +89,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._precision: torch.dtype = precision
         self._max_cache_size: float = max_cache_size
         self._max_vram_cache_size: float = max_vram_cache_size
-        self._execution_device: torch.device = execution_device
+        self._execution_devices: Set[torch.device] = execution_devices or self._get_execution_devices()
         self._storage_device: torch.device = storage_device
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
         self._log_memory_usage = log_memory_usage
@@ -97,6 +98,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
         self._cache_stack: List[str] = []

+        self._lock = Lock()
+        self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
+        self._busy_execution_devices: Set[torch.device] = set()
+
+        self.logger.info(f"Using rendering device(s) {[self._device_name(x) for x in self._execution_devices]}")
+
     @property
     def logger(self) -> Logger:
         """Return the logger used by the cache."""
@@ -113,9 +120,24 @@ class ModelCache(ModelCacheBase[AnyModel]):
         return self._storage_device

     @property
-    def execution_device(self) -> torch.device:
-        """Return the exection device (e.g. "cuda" for VRAM)."""
-        return self._execution_device
+    def execution_devices(self) -> Set[torch.device]:
+        """Return the set of available execution devices."""
+        return self._execution_devices
+
+    def acquire_execution_device(self, timeout: int = 0) -> torch.device:
+        """Acquire and return an execution device (e.g. "cuda" for VRAM)."""
+        with self._lock:
+            self._free_execution_device.acquire(timeout=timeout)
+            free_devices = self.execution_devices - self._busy_execution_devices
+            chosen_device = list(free_devices)[0]
+            self._busy_execution_devices.add(chosen_device)
+            return chosen_device
+
+    def release_execution_device(self, device: torch.device) -> None:
+        """Mark this execution device as unused."""
+        with self._lock:
+            self._free_execution_device.release()
+            self._busy_execution_devices.remove(device)

     @property
     def max_cache_size(self) -> float:
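The implementation pairs a counting primitive with explicit bookkeeping: the BoundedSemaphore counts how many devices are free, while the busy set records which ones are in use. The toy below reproduces the same idea in isolation, with two deliberate variations that are my assumptions rather than part of the commit: it waits on the semaphore before taking the lock, and it treats a False return from acquire() as a timeout, since threading.Semaphore.acquire returns False rather than raising when its timeout expires.

from threading import BoundedSemaphore, Lock
from typing import Set

devices: Set[str] = {"cuda:0", "cuda:1"}
free_slots = BoundedSemaphore(len(devices))
busy: Set[str] = set()
lock = Lock()

def acquire(timeout: float = 600.0) -> str:
    # Wait for a free slot first, then update the busy set under the lock.
    if not free_slots.acquire(timeout=timeout):
        raise TimeoutError("no execution device became free")
    with lock:
        chosen = next(iter(devices - busy))
        busy.add(chosen)
        return chosen

def release(device: str) -> None:
    with lock:
        busy.remove(device)
    free_slots.release()

d = acquire()
print(f"got {d}")
release(d)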
@@ -422,3 +444,17 @@ class ModelCache(ModelCacheBase[AnyModel]):
         free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
         if needed_size > free_mem:
             raise torch.cuda.OutOfMemoryError
+
+    @staticmethod
+    def _get_execution_devices() -> Set[torch.device]:
+        default_device = choose_torch_device()
+        if default_device != torch.device("cuda"):
+            return {default_device}
+
+        # we get here if the default device is cuda, and return each of the
+        # cuda devices.
+        return {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
+
+    @staticmethod
+    def _device_name(device: torch.device) -> str:
+        return f"{device.type}:{device.index}"
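As a rough illustration of what the helper above yields on different machines (this snippet is not from the repository and assumes choose_torch_device() prefers CUDA when it is available):

import torch

if torch.cuda.is_available():
    # One entry per visible CUDA device, e.g. {cuda:0, cuda:1} on a 2-GPU box.
    expected = {torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())}
else:
    # Otherwise the set collapses to the single default device (cpu or mps).
    expected = {torch.device("cpu")}
print(expected)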
@@ -2,12 +2,16 @@
 Base class and implementation of a class that moves models in and out of VRAM.
 """

+from typing import Optional
+
 import torch

 from invokeai.backend.model_manager import AnyModel

 from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase

+MAX_GPU_WAIT = 600  # wait up to 10 minutes for a GPU to become free
+

 class ModelLocker(ModelLockerBase):
     """Internal class that mediates movement in and out of GPU."""
@@ -21,6 +25,7 @@ class ModelLocker(ModelLockerBase):
         """
         self._cache = cache
         self._cache_entry = cache_entry
+        self._execution_device: Optional[torch.device] = None

     @property
     def model(self) -> AnyModel:
@@ -39,10 +44,12 @@ class ModelLocker(ModelLockerBase):
             if self._cache.lazy_offloading:
                 self._cache.offload_unlocked_models(self._cache_entry.size)

-            self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
+            # We wait for a gpu to be free - may raise a TimeoutError
+            self._execution_device = self._cache.acquire_execution_device(MAX_GPU_WAIT)
+            self._cache.move_model_to_device(self._cache_entry, self._execution_device)
             self._cache_entry.loaded = True

-            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
+            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._execution_device}")
             self._cache.print_cuda_stats()
         except torch.cuda.OutOfMemoryError:
             self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
@@ -59,6 +66,8 @@ class ModelLocker(ModelLockerBase):
             return

         self._cache_entry.unlock()
+        if self._execution_device:
+            self._cache.release_execution_device(self._execution_device)
         if not self._cache.lazy_offloading:
             self._cache.offload_unlocked_models(self._cache_entry.size)
             self._cache.print_cuda_stats()
@@ -111,7 +111,7 @@ def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> DownloadQueueServiceBase:


 @pytest.fixture
-def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
+def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase:
     ram_cache = ModelCache(
         logger=InvokeAILogger.get_logger(),
         max_cache_size=mm2_app_config.ram,