mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

commit 6932f27b43: fixup code broken by merge with main
parent: 0df018bd4e
@@ -60,4 +60,3 @@ class ModelLoadServiceBase(ABC):
         Returns:
             A LoadedModel object.
         """
-
@@ -76,6 +76,7 @@ class ModelManagerService(ModelManagerServiceBase):
 
         ram_cache = ModelCache(
             max_cache_size=app_config.ram,
+            max_vram_cache_size=app_config.vram,
             logger=logger,
         )
         convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
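The new max_vram_cache_size argument threads the vram setting from the app config into the RAM cache, alongside the existing ram setting. Below is a minimal construction sketch, under the assumption that the constructor accepts exactly the keyword arguments shown in this diff; the literal sizes and the logger name are illustrative, not taken from the source.

import logging

from invokeai.backend.model_manager.load import ModelCache  # import path as used in the test hunk further down

logger = logging.getLogger("ModelManagerService")

ram_cache = ModelCache(
    max_cache_size=6.0,        # RAM budget for cached models, in GB (app_config.ram)
    max_vram_cache_size=0.25,  # VRAM kept resident per GPU, in GB (app_config.vram)
    logger=logger,
)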
@@ -19,8 +19,10 @@ context. Use like this:
 """
 
 import gc
+import math
 import sys
 import threading
+import time
 from contextlib import contextmanager, suppress
 from logging import Logger
 from threading import BoundedSemaphore
@@ -40,6 +42,7 @@ from .model_locker import ModelLocker
 # Maximum size of the cache, in gigs
 # Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
 DEFAULT_MAX_CACHE_SIZE = 6.0
+DEFAULT_MAX_VRAM_CACHE_SIZE = 0.25
 
 # actual size of a gig
 GIG = 1073741824
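Both defaults are expressed in gigabytes and only become byte counts when multiplied by GIG, which is how offload_unlocked_models later compares them against torch.cuda.memory_allocated(). A small worked example of that arithmetic follows; the helper name is illustrative and not part of the module.

GIG = 1073741824  # bytes per binary gigabyte (2**30)
DEFAULT_MAX_VRAM_CACHE_SIZE = 0.25  # GB of VRAM reserved for resident models

def vram_over_budget(bytes_allocated: int, bytes_needed: int,
                     reserve_gb: float = DEFAULT_MAX_VRAM_CACHE_SIZE) -> bool:
    """Would loading `bytes_needed` more bytes push usage past the reserve?"""
    return bytes_allocated + bytes_needed > reserve_gb * GIG

# 200 MiB already allocated plus a 100 MiB model exceeds the 256 MiB (0.25 GB) reserve.
print(vram_over_budget(200 * 1024**2, 100 * 1024**2))  # True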
@@ -54,6 +57,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
     def __init__(
         self,
         max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
+        max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
         storage_device: torch.device = torch.device("cpu"),
         execution_devices: Optional[Set[torch.device]] = None,
         precision: torch.dtype = torch.float16,
@@ -76,6 +80,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         """
         self._precision: torch.dtype = precision
         self._max_cache_size: float = max_cache_size
+        self._max_vram_cache_size: float = max_vram_cache_size
         self._storage_device: torch.device = storage_device
         self._ram_lock = threading.Lock()
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
@@ -281,14 +286,17 @@ class ModelCache(ModelCacheBase[AnyModel]):
 
     def offload_unlocked_models(self, size_required: int) -> None:
         """Move any unused models from VRAM."""
+        device = self.get_execution_device()
         reserved = self._max_vram_cache_size * GIG
-        vram_in_use = torch.cuda.memory_allocated() + size_required
+        vram_in_use = torch.cuda.memory_allocated(device) + size_required
         self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
         for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
             if vram_in_use <= reserved:
                 break
             if not cache_entry.loaded:
                 continue
+            if cache_entry.device is not device:
+                continue
             if not cache_entry.locked:
                 self.move_model_to_device(cache_entry, self.storage_device)
                 cache_entry.loaded = False
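The added device check restricts eviction to entries that actually live on the execution device being reclaimed, which matters once several GPUs share one cache. The same policy can be sketched on toy data; CacheEntry and offload_unlocked below are illustrative stand-ins for the real types, and the sketch approximates usage by subtracting entry sizes instead of querying torch.cuda.memory_allocated() as the method above does.

from dataclasses import dataclass

GIG = 1073741824

@dataclass
class CacheEntry:          # stand-in for the real cache record
    key: str
    size: int              # bytes occupied by the weights
    device: str            # where the weights currently live
    loaded: bool = True
    locked: bool = False

def offload_unlocked(entries, vram_in_use: int, size_required: int,
                     reserve_gb: float, device: str) -> None:
    """Evict unlocked entries from `device`, smallest first, until usage fits the reserve."""
    needed = vram_in_use + size_required
    for entry in sorted(entries, key=lambda e: e.size):
        if needed <= reserve_gb * GIG:
            break
        if not entry.loaded or entry.device != device or entry.locked:
            continue
        entry.device, entry.loaded = "cpu", False  # "move" back to the storage device
        needed -= entry.size                       # simplification; see note above

entries = [CacheEntry("unet", 2 * GIG, "cuda:0"), CacheEntry("vae", GIG // 4, "cuda:0")]
offload_unlocked(entries, vram_in_use=2 * GIG + GIG // 4, size_required=GIG // 2,
                 reserve_gb=0.25, device="cuda:0")
print([(e.key, e.device) for e in entries])  # both entries end up on "cpu"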
@@ -39,11 +39,11 @@ class ModelLocker(ModelLockerBase):
         """Move the model into the execution device (GPU) and lock it."""
         self._cache_entry.lock()
         try:
-            if self._cache.lazy_offloading:
+            device = self._cache.get_execution_device()
             self._cache.offload_unlocked_models(self._cache_entry.size)
-            self._cache.move_model_to_device(self._cache_entry, self._cache.get_execution_device())
+            self._cache.move_model_to_device(self._cache_entry, device)
             self._cache_entry.loaded = True
-            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
+            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {device}")
             self._cache.print_cuda_stats()
         except torch.cuda.OutOfMemoryError:
             self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
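Resolving the device once through get_execution_device() and reusing it for both the move and the log message keeps the lock consistent when different threads are assigned different GPUs. A minimal usage sketch for a locker that follows the lock()/unlock() pair seen here; locker and run_inference are placeholders, and the assumption that lock() returns the now-resident model is not shown in this diff.

def run_with_locked_model(locker, run_inference):
    model = locker.lock()      # assumption: lock() moves the weights to the execution device and returns the model
    try:
        return run_inference(model)
    finally:
        locker.unlock()        # release the pin so offload_unlocked_models() may evict it later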
|
@ -14,6 +14,7 @@ def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Pat
|
|||||||
matches = store.search_by_attr(model_name="test_embedding")
|
matches = store.search_by_attr(model_name="test_embedding")
|
||||||
assert len(matches) == 0
|
assert len(matches) == 0
|
||||||
key = mm2_model_manager.install.register_path(embedding_file)
|
key = mm2_model_manager.install.register_path(embedding_file)
|
||||||
|
with mm2_model_manager.load.ram_cache.reserve_execution_device():
|
||||||
loaded_model = mm2_model_manager.load.load_model(store.get_model(key))
|
loaded_model = mm2_model_manager.load.load_model(store.get_model(key))
|
||||||
assert loaded_model is not None
|
assert loaded_model is not None
|
||||||
assert loaded_model.config.key == key
|
assert loaded_model.config.key == key
|
||||||
|
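reserve_execution_device() is new to the test and checks a GPU out of the cache for the duration of the with block. Its implementation is not part of this diff; the following is only a sketch of the general acquire/release pattern such a context manager implies, using BoundedSemaphore as in the imports above. DevicePool and its fields are invented for the example.

import threading
from contextlib import contextmanager

class DevicePool:
    """Toy pool that hands out one device per caller at a time."""

    def __init__(self, devices):
        self._devices = list(devices)
        self._lock = threading.Lock()
        self._free = threading.BoundedSemaphore(len(self._devices))

    @contextmanager
    def reserve_execution_device(self):
        self._free.acquire()                 # block until a device is available
        with self._lock:
            device = self._devices.pop()
        try:
            yield device
        finally:
            with self._lock:
                self._devices.append(device)
            self._free.release()

pool = DevicePool(["cuda:0", "cuda:1"])
with pool.reserve_execution_device() as device:
    print(f"model loading would run on {device}")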
@@ -10,6 +10,7 @@ import torch
 from invokeai.app.services.config import get_config
 from invokeai.backend.model_manager.load import ModelCache
 from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype
+from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403
 
 devices = ["cpu", "cuda:0", "cuda:1", "mps"]
 device_types_cpu = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float32)]
@@ -21,6 +22,7 @@ device_types_mps = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", t
 def test_device_choice(device_name):
     config = get_config()
     config.device = device_name
+    TorchDevice.set_model_cache(None)  # disable dynamic selection of GPU device
     torch_device = TorchDevice.choose_torch_device()
     assert torch_device == torch.device(device_name)
 