Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
fixup code broken by merge with main
Commit 6932f27b43 (parent 0df018bd4e)

@@ -60,4 +60,3 @@ class ModelLoadServiceBase(ABC):
         Returns:
             A LoadedModel object.
         """
-
@@ -76,6 +76,7 @@ class ModelManagerService(ModelManagerServiceBase):
 
         ram_cache = ModelCache(
             max_cache_size=app_config.ram,
+            max_vram_cache_size=app_config.vram,
             logger=logger,
         )
         convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
@@ -19,8 +19,10 @@ context. Use like this:
 """
 
 import gc
 import math
 import sys
+import threading
 import time
 from contextlib import contextmanager, suppress
 from logging import Logger
+from threading import BoundedSemaphore
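
The added threading imports suggest the cache coordinates access to a limited pool of execution devices across worker threads. As a hedged illustration of that general technique only (a toy DevicePool, not InvokeAI's implementation), a BoundedSemaphore caps the number of concurrent holders while a plain lock guards the set of free devices:

import threading
from contextlib import contextmanager
from threading import BoundedSemaphore
from typing import Iterator, Set


class DevicePool:
    """Toy device-reservation pool; illustrative only."""

    def __init__(self, devices: Set[str]) -> None:
        self._free = set(devices)                    # devices not currently reserved
        self._lock = threading.Lock()                # guards _free
        self._sem = BoundedSemaphore(len(devices))   # at most len(devices) concurrent holders

    @contextmanager
    def reserve(self) -> Iterator[str]:
        self._sem.acquire()                          # block until some device is free
        with self._lock:
            device = self._free.pop()
        try:
            yield device                             # caller runs its work on this device
        finally:
            with self._lock:
                self._free.add(device)
            self._sem.release()


pool = DevicePool({"cuda:0", "cuda:1"})
with pool.reserve() as device:
    print(f"reserved {device} for this thread")
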
@@ -40,6 +42,7 @@ from .model_locker import ModelLocker
 # Maximum size of the cache, in gigs
 # Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
 DEFAULT_MAX_CACHE_SIZE = 6.0
+DEFAULT_MAX_VRAM_CACHE_SIZE = 0.25
 
 # actual size of a gig
 GIG = 1073741824
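
For reference, the budget arithmetic implied by these constants: cache sizes are configured in gigabytes but, as the offload_unlocked_models hunk further down shows, they are compared against byte counts from torch.cuda.memory_allocated(), so they are multiplied by GIG first.

GIG = 1073741824  # 2**30 bytes
DEFAULT_MAX_CACHE_SIZE = 6.0        # GB of model weights kept in RAM
DEFAULT_MAX_VRAM_CACHE_SIZE = 0.25  # GB of model weights left resident in VRAM

print(DEFAULT_MAX_CACHE_SIZE * GIG)       # 6442450944.0 bytes
print(DEFAULT_MAX_VRAM_CACHE_SIZE * GIG)  # 268435456.0 bytes (256 MiB)
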
@@ -54,6 +57,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
     def __init__(
         self,
         max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
+        max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
         storage_device: torch.device = torch.device("cpu"),
         execution_devices: Optional[Set[torch.device]] = None,
         precision: torch.dtype = torch.float16,
@@ -76,6 +80,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         """
         self._precision: torch.dtype = precision
         self._max_cache_size: float = max_cache_size
+        self._max_vram_cache_size: float = max_vram_cache_size
         self._storage_device: torch.device = storage_device
         self._ram_lock = threading.Lock()
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
@@ -281,14 +286,17 @@ class ModelCache(ModelCacheBase[AnyModel]):
 
     def offload_unlocked_models(self, size_required: int) -> None:
         """Move any unused models from VRAM."""
+        device = self.get_execution_device()
         reserved = self._max_vram_cache_size * GIG
-        vram_in_use = torch.cuda.memory_allocated() + size_required
+        vram_in_use = torch.cuda.memory_allocated(device) + size_required
         self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
         for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
             if vram_in_use <= reserved:
                 break
             if not cache_entry.loaded:
                 continue
+            if cache_entry.device is not device:
+                continue
             if not cache_entry.locked:
                 self.move_model_to_device(cache_entry, self.storage_device)
                 cache_entry.loaded = False
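
As a self-contained sketch of the offload strategy shown above (hypothetical CacheEntry/ToyCache classes, not InvokeAI's), the loop visits cached models smallest-first and moves unlocked entries off the execution device until projected VRAM use fits under the reserved budget:

from dataclasses import dataclass
from typing import Dict

GIG = 2**30  # bytes per gigabyte, matching the constant above


@dataclass
class CacheEntry:
    size: int      # model size in bytes
    loaded: bool   # resident on an execution device
    locked: bool   # pinned by an active invocation
    device: str    # device the weights currently live on


class ToyCache:
    def __init__(self, max_vram_gb: float, execution_device: str = "cuda:0") -> None:
        self.max_vram_gb = max_vram_gb
        self.execution_device = execution_device
        self.storage_device = "cpu"
        self.entries: Dict[str, CacheEntry] = {}
        self.vram_in_use = 0  # stand-in for torch.cuda.memory_allocated(device)

    def offload_unlocked_models(self, size_required: int) -> None:
        reserved = self.max_vram_gb * GIG
        needed = self.vram_in_use + size_required
        # Visit entries smallest-first, as in the diff above.
        for _, entry in sorted(self.entries.items(), key=lambda x: x[1].size):
            if needed <= reserved:
                break
            if not entry.loaded or entry.device != self.execution_device:
                continue
            if not entry.locked:
                entry.device = self.storage_device  # stand-in for move_model_to_device()
                entry.loaded = False
                self.vram_in_use -= entry.size
                needed -= entry.size


cache = ToyCache(max_vram_gb=0.25)
cache.entries["sd-1.5"] = CacheEntry(size=2 * GIG, loaded=True, locked=False, device="cuda:0")
cache.vram_in_use = 2 * GIG
cache.offload_unlocked_models(size_required=1 * GIG)
print(cache.entries["sd-1.5"].device)  # cpu
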
@@ -39,11 +39,11 @@ class ModelLocker(ModelLockerBase):
         """Move the model into the execution device (GPU) and lock it."""
         self._cache_entry.lock()
         try:
-            if self._cache.lazy_offloading:
-                self._cache.offload_unlocked_models(self._cache_entry.size)
-            self._cache.move_model_to_device(self._cache_entry, self._cache.get_execution_device())
+            device = self._cache.get_execution_device()
+            self._cache.offload_unlocked_models(self._cache_entry.size)
+            self._cache.move_model_to_device(self._cache_entry, device)
             self._cache_entry.loaded = True
-            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
+            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {device}")
             self._cache.print_cuda_stats()
         except torch.cuda.OutOfMemoryError:
             self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
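
The except branch above only shows the warning; the rest of the handler falls outside the hunk. As a hypothetical sketch of the surrounding pattern (ToyEntry and lock_for_inference are illustrations, not InvokeAI APIs), a lock-then-load flow generally unpins the entry again when the load fails, so an out-of-memory error never leaves a model permanently locked:

from dataclasses import dataclass


@dataclass
class ToyEntry:
    key: str
    size: int
    locks: int = 0
    device: str = "cpu"

    def lock(self) -> None:
        self.locks += 1

    def unlock(self) -> None:
        self.locks -= 1


def lock_for_inference(entry: ToyEntry, device: str, vram_free: int) -> ToyEntry:
    entry.lock()                      # pin so the cache cannot offload it mid-load
    try:
        if entry.size > vram_free:    # stand-in for torch.cuda.OutOfMemoryError
            raise MemoryError(f"not enough VRAM for {entry.key}")
        entry.device = device         # stand-in for move_model_to_device()
        return entry
    except MemoryError:
        entry.unlock()                # a failed load must not stay pinned
        raise


entry = lock_for_inference(ToyEntry("test-model", size=2), device="cuda:0", vram_free=4)
print(entry.device, entry.locks)      # cuda:0 1
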
@@ -14,13 +14,14 @@ def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Pat
     matches = store.search_by_attr(model_name="test_embedding")
     assert len(matches) == 0
     key = mm2_model_manager.install.register_path(embedding_file)
-    loaded_model = mm2_model_manager.load.load_model(store.get_model(key))
-    assert loaded_model is not None
-    assert loaded_model.config.key == key
-    with loaded_model as model:
-        assert isinstance(model, TextualInversionModelRaw)
+    with mm2_model_manager.load.ram_cache.reserve_execution_device():
+        loaded_model = mm2_model_manager.load.load_model(store.get_model(key))
+        assert loaded_model is not None
+        assert loaded_model.config.key == key
+        with loaded_model as model:
+            assert isinstance(model, TextualInversionModelRaw)
 
-    config = mm2_model_manager.store.get_model(key)
-    loaded_model_2 = mm2_model_manager.load.load_model(config)
+        config = mm2_model_manager.store.get_model(key)
+        loaded_model_2 = mm2_model_manager.load.load_model(config)
 
-    assert loaded_model.config.key == loaded_model_2.config.key
+        assert loaded_model.config.key == loaded_model_2.config.key
@@ -10,6 +10,7 @@ import torch
 from invokeai.app.services.config import get_config
+from invokeai.backend.model_manager.load import ModelCache
 from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype
 from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403
 
 devices = ["cpu", "cuda:0", "cuda:1", "mps"]
 device_types_cpu = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float32)]
@ -21,6 +22,7 @@ device_types_mps = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", t
|
||||
def test_device_choice(device_name):
|
||||
config = get_config()
|
||||
config.device = device_name
|
||||
TorchDevice.set_model_cache(None) # disable dynamic selection of GPU device
|
||||
torch_device = TorchDevice.choose_torch_device()
|
||||
assert torch_device == torch.device(device_name)
|
||||
|
||||
|
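
The parametrized device list above includes cuda:1 and mps, which will not exist on every test machine. A hedged sketch (not taken from the test file above) of one way to keep such a parametrized test green everywhere is to skip the cases whose device is unavailable:

import pytest
import torch

devices = ["cpu", "cuda:0", "cuda:1", "mps"]


def device_available(name: str) -> bool:
    if name == "cpu":
        return True
    if name.startswith("cuda"):
        index = int(name.split(":")[1]) if ":" in name else 0
        return torch.cuda.is_available() and index < torch.cuda.device_count()
    if name == "mps":
        return torch.backends.mps.is_available()
    return False


@pytest.mark.parametrize("device_name", devices)
def test_device_is_constructible(device_name: str) -> None:
    if not device_available(device_name):
        pytest.skip(f"{device_name} not available on this machine")
    assert torch.device(device_name).type in ("cpu", "cuda", "mps")
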