fixup code broken by merge with main

This commit is contained in:
Lincoln Stein 2024-06-23 12:17:16 -04:00
parent 0df018bd4e
commit 6932f27b43
6 changed files with 25 additions and 14 deletions

View File

@ -60,4 +60,3 @@ class ModelLoadServiceBase(ABC):
Returns:
A LoadedModel object.
"""

View File

@ -76,6 +76,7 @@ class ModelManagerService(ModelManagerServiceBase):
ram_cache = ModelCache(
max_cache_size=app_config.ram,
max_vram_cache_size=app_config.vram,
logger=logger,
)
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)

View File

@ -19,8 +19,10 @@ context. Use like this:
"""
import gc
import math
import sys
import threading
import time
from contextlib import contextmanager, suppress
from logging import Logger
from threading import BoundedSemaphore
@ -40,6 +42,7 @@ from .model_locker import ModelLocker
# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0
DEFAULT_MAX_VRAM_CACHE_SIZE = 0.25
# actual size of a gig
GIG = 1073741824
@ -54,6 +57,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
def __init__(
self,
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
storage_device: torch.device = torch.device("cpu"),
execution_devices: Optional[Set[torch.device]] = None,
precision: torch.dtype = torch.float16,
@ -76,6 +80,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
"""
self._precision: torch.dtype = precision
self._max_cache_size: float = max_cache_size
self._max_vram_cache_size: float = max_vram_cache_size
self._storage_device: torch.device = storage_device
self._ram_lock = threading.Lock()
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
@ -281,14 +286,17 @@ class ModelCache(ModelCacheBase[AnyModel]):
def offload_unlocked_models(self, size_required: int) -> None:
"""Move any unused models from VRAM."""
device = self.get_execution_device()
reserved = self._max_vram_cache_size * GIG
vram_in_use = torch.cuda.memory_allocated() + size_required
vram_in_use = torch.cuda.memory_allocated(device) + size_required
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
if not cache_entry.loaded:
continue
if cache_entry.device is not device:
continue
if not cache_entry.locked:
self.move_model_to_device(cache_entry, self.storage_device)
cache_entry.loaded = False

View File

@ -39,11 +39,11 @@ class ModelLocker(ModelLockerBase):
"""Move the model into the execution device (GPU) and lock it."""
self._cache_entry.lock()
try:
if self._cache.lazy_offloading:
device = self._cache.get_execution_device()
self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.move_model_to_device(self._cache_entry, self._cache.get_execution_device())
self._cache.move_model_to_device(self._cache_entry, device)
self._cache_entry.loaded = True
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {device}")
self._cache.print_cuda_stats()
except torch.cuda.OutOfMemoryError:
self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")

View File

@ -14,6 +14,7 @@ def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Pat
matches = store.search_by_attr(model_name="test_embedding")
assert len(matches) == 0
key = mm2_model_manager.install.register_path(embedding_file)
with mm2_model_manager.load.ram_cache.reserve_execution_device():
loaded_model = mm2_model_manager.load.load_model(store.get_model(key))
assert loaded_model is not None
assert loaded_model.config.key == key

View File

@ -10,6 +10,7 @@ import torch
from invokeai.app.services.config import get_config
from invokeai.backend.model_manager.load import ModelCache
from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
devices = ["cpu", "cuda:0", "cuda:1", "mps"]
device_types_cpu = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float32)]
@ -21,6 +22,7 @@ device_types_mps = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", t
def test_device_choice(device_name):
    """Check that the configured device name is the device TorchDevice picks.

    Writes ``device_name`` into the config object returned by ``get_config()``,
    passes ``None`` to ``TorchDevice.set_model_cache`` and then asserts that
    ``TorchDevice.choose_torch_device()`` equals ``torch.device(device_name)``.
    """
    app_config = get_config()
    app_config.device = device_name
    # disable dynamic selection of GPU device
    TorchDevice.set_model_cache(None)
    selected = TorchDevice.choose_torch_device()
    expected = torch.device(device_name)
    assert selected == expected