@@ -1,8 +1,5 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
# TODO: Add Stalker's proper name to copyright

import gc
import math
import logging
import time
from logging import Logger
from typing import Dict, List, Optional
@@ -10,9 +7,15 @@ from typing import Dict, List, Optional
import torch

from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
    CachedModelOnlyFullLoad,
)
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
    CachedModelWithPartialLoad,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
    apply_custom_layers_to_model,
)
@@ -29,6 +32,7 @@ MB = 2**20

# TODO(ryand): Where should this go? The ModelCache shouldn't be concerned with submodels.
def get_model_cache_key(model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
    """Get the cache key for a model based on the optional submodel type."""
    if submodel_type:
        return f"{model_key}:{submodel_type.value}"
    else:
@@ -70,34 +74,35 @@ class ModelCache:

    def __init__(
        self,
        max_cache_size: float,
        max_vram_cache_size: float,
        execution_device: torch.device = torch.device("cuda"),
        storage_device: torch.device = torch.device("cpu"),
        lazy_offloading: bool = True,
        max_ram_cache_size_gb: float,
        max_vram_cache_size_gb: float,
        execution_device: torch.device | str = "cuda",
        storage_device: torch.device | str = "cpu",
        log_memory_usage: bool = False,
        logger: Optional[Logger] = None,
    ):
        """
        Initialize the model RAM cache.
        """Initialize the model RAM cache.

        :param max_cache_size: Maximum size of the storage_device cache in GBs.
        :param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
        :param max_ram_cache_size_gb: The maximum amount of CPU RAM to use for model caching in GB. This parameter is
            kept to maintain compatibility with previous versions of the model cache, but should be deprecated in the
            future. If set, this parameter overrides the default cache size logic.
        :param max_vram_cache_size_gb: The amount of VRAM to use for model caching in GB. This parameter is kept to
            maintain compatibility with previous versions of the model cache, but should be deprecated in the future.
            If set, this parameter overrides the default cache size logic.
        :param execution_device: Torch device to load active model into [torch.device('cuda')]
        :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
        :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
            operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
            snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
            behaviour.
        :param logger: InvokeAILogger to use (otherwise creates one)
        """
        # allow lazy offloading only when vram cache enabled
        self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
        self._max_cache_size: float = max_cache_size
        self._max_vram_cache_size: float = max_vram_cache_size
        self._execution_device: torch.device = execution_device
        self._storage_device: torch.device = storage_device
        self._execution_device: torch.device = torch.device(execution_device)
        self._storage_device: torch.device = torch.device(storage_device)

        self._max_ram_cache_size_gb = max_ram_cache_size_gb
        self._max_vram_cache_size_gb = max_vram_cache_size_gb

        self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
        self._log_memory_usage = log_memory_usage
        self._stats: Optional[CacheStats] = None
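
# Editor's illustration (not part of the diff): constructing the cache with the new
# GB-denominated constructor parameters shown above. The helper name and the numeric
# values are hypothetical; only the keyword arguments come from the new signature.
def _example_build_cache() -> ModelCache:
    return ModelCache(
        max_ram_cache_size_gb=8.0,   # CPU-side budget for cached models, in GB
        max_vram_cache_size_gb=4.0,  # execution-device budget, in GB
        execution_device="cuda",     # str or torch.device; normalized via torch.device(...)
        storage_device="cpu",
    )
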
@@ -105,26 +110,6 @@ class ModelCache:
        self._cached_models: Dict[str, CacheRecord] = {}
        self._cache_stack: List[str] = []

    @property
    def max_cache_size(self) -> float:
        """Return the cap on cache size."""
        return self._max_cache_size

    @max_cache_size.setter
    def max_cache_size(self, value: float) -> None:
        """Set the cap on cache size."""
        self._max_cache_size = value

    @property
    def max_vram_cache_size(self) -> float:
        """Return the cap on vram cache size."""
        return self._max_vram_cache_size

    @max_vram_cache_size.setter
    def max_vram_cache_size(self, value: float) -> None:
        """Set the cap on vram cache size."""
        self._max_vram_cache_size = value

    @property
    def stats(self) -> Optional[CacheStats]:
        """Return collected CacheStats object."""
@@ -132,17 +117,17 @@ class ModelCache:

    @stats.setter
    def stats(self, stats: CacheStats) -> None:
        """Set the CacheStats object for collectin cache statistics."""
        """Set the CacheStats object for collecting cache statistics."""
        self._stats = stats

    def put(
        self,
        key: str,
        model: AnyModel,
    ) -> None:
        """Insert model into the cache."""
    def put(self, key: str, model: AnyModel) -> None:
        """Add a model to the cache."""
        if key in self._cached_models:
            self._logger.debug(
                f"Attempted to add model {key} ({model.__class__.__name__}), but it already exists in the cache. No action necessary."
            )
            return

        size = calc_model_size_by_data(self._logger, model)
        self.make_room(size)
@@ -150,17 +135,26 @@ class ModelCache:
        if isinstance(model, torch.nn.Module):
            apply_custom_layers_to_model(model)

        running_on_cpu = self._execution_device == torch.device("cpu")
        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
        cache_record = CacheRecord(key=key, model=model, device=self._storage_device, state_dict=state_dict, size=size)
        # Partial loading only makes sense on CUDA.
        # - When running on CPU, there is no 'loading' to do.
        # - When running on MPS, memory is shared with the CPU, so the default OS memory management already handles this
        # well.
        running_with_cuda = self._execution_device.type == "cuda"

        # Wrap model.
        if isinstance(model, torch.nn.Module) and running_with_cuda:
            wrapped_model = CachedModelWithPartialLoad(model, self._execution_device)
        else:
            wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size)

        cache_record = CacheRecord(key=key, cached_model=wrapped_model)
        self._cached_models[key] = cache_record
        self._cache_stack.append(key)
        self._logger.debug(
            f"Added model {key} (Type: {model.__class__.__name__}, Wrap mode: {wrapped_model.__class__.__name__}, Model size: {size/MB:.2f}MB)"
        )
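
# Editor's sketch (not part of the diff): the wrap-mode decision used by put() above, as a
# standalone helper. The `_example_` name is hypothetical; torch is imported at the top of
# this module. Torch modules targeting a CUDA execution device get the partial-load wrapper;
# everything else (CPU, MPS, non-torch models) is wrapped for full loads only.
def _example_pick_wrap_mode(model: object, execution_device: torch.device) -> str:
    if isinstance(model, torch.nn.Module) and execution_device.type == "cuda":
        return "CachedModelWithPartialLoad"
    return "CachedModelOnlyFullLoad"


assert _example_pick_wrap_mode(torch.nn.Linear(4, 4), torch.device("cuda")) == "CachedModelWithPartialLoad"
assert _example_pick_wrap_mode("not a torch module", torch.device("cuda")) == "CachedModelOnlyFullLoad"
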
    def get(
        self,
        key: str,
        stats_name: Optional[str] = None,
    ) -> CacheRecord:
    def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
        """Retrieve a model from the cache.

        :param key: Model key
@@ -174,6 +168,7 @@ class ModelCache:
        else:
            if self.stats:
                self.stats.misses += 1
            self._logger.debug(f"Cache miss: {key}")
            raise IndexError(f"The model with key {key} is not in the cache.")

        cache_entry = self._cached_models[key]
@@ -181,37 +176,44 @@ class ModelCache:
        # more stats
        if self.stats:
            stats_name = stats_name or key
            self.stats.cache_size = int(self._max_cache_size * GB)
            self.stats.high_watermark = max(self.stats.high_watermark, self._get_cache_size())
            self.stats.high_watermark = max(self.stats.high_watermark, self._get_ram_in_use())
            self.stats.in_cache = len(self._cached_models)
            self.stats.loaded_model_sizes[stats_name] = max(
                self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
                self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.cached_model.total_bytes()
            )

        # this moves the entry to the top (right end) of the stack
        # This moves the entry to the top (right end) of the stack.
        self._cache_stack = [k for k in self._cache_stack if k != key]
        self._cache_stack.append(key)

        self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
        return cache_entry

    def lock(self, cache_entry: CacheRecord) -> None:
        """Lock a model for use and move it into VRAM."""
        if cache_entry.key not in self._cached_models:
            self._logger.info(
                f"Locking model cache entry {cache_entry.key} ({cache_entry.model.__class__.__name__}), but it has "
                "already been dropped from the RAM cache. This is a sign that the model loading order is non-optimal "
                "in the invocation code (See https://github.com/invoke-ai/InvokeAI/issues/7513)."
                f"Locking model cache entry {cache_entry.key} "
                f"(Type: {cache_entry.cached_model.model.__class__.__name__}), but it has already been dropped from "
                "the RAM cache. This is a sign that the model loading order is non-optimal in the invocation code "
                "(See https://github.com/invoke-ai/InvokeAI/issues/7513)."
            )
        # cache_entry = self._cached_models[key]
        cache_entry.lock()

        self._logger.debug(
            f"Locking model {cache_entry.key} (Type: {cache_entry.cached_model.model.__class__.__name__})"
        )

        if self._execution_device.type == "cpu":
            # Models don't need to be loaded into VRAM if we're running on CPU.
            return

        try:
            if self._lazy_offloading:
                self._offload_unlocked_models(cache_entry.size)
            self._move_model_to_device(cache_entry, self._execution_device)
            cache_entry.loaded = True
            self._logger.debug(f"Locking {cache_entry.key} in {self._execution_device}")
            self._print_cuda_stats()
            self._load_locked_model(cache_entry)
            self._logger.debug(
                f"Finished locking model {cache_entry.key} (Type: {cache_entry.cached_model.model.__class__.__name__})"
            )
        except torch.cuda.OutOfMemoryError:
            self._logger.warning("Insufficient GPU memory to load model. Aborting")
            cache_entry.unlock()
@@ -220,201 +222,258 @@ class ModelCache:
            cache_entry.unlock()
            raise

        self._log_cache_state()

    def unlock(self, cache_entry: CacheRecord) -> None:
        """Unlock a model."""
        if cache_entry.key not in self._cached_models:
            self._logger.info(
                f"Unlocking model cache entry {cache_entry.key} ({cache_entry.model.__class__.__name__}), but it has "
                "already been dropped from the RAM cache. This is a sign that the model loading order is non-optimal "
                "in the invocation code (See https://github.com/invoke-ai/InvokeAI/issues/7513)."
                f"Unlocking model cache entry {cache_entry.key} "
                f"(Type: {cache_entry.cached_model.model.__class__.__name__}), but it has already been dropped from "
                "the RAM cache. This is a sign that the model loading order is non-optimal in the invocation code "
                "(See https://github.com/invoke-ai/InvokeAI/issues/7513)."
            )
        # cache_entry = self._cached_models[key]
        cache_entry.unlock()
        if not self._lazy_offloading:
            self._offload_unlocked_models(0)
            self._print_cuda_stats()
        self._logger.debug(
            f"Unlocked model {cache_entry.key} (Type: {cache_entry.cached_model.model.__class__.__name__})"
        )
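
# Editor's illustration (not part of the diff): the intended put/get/lock/unlock lifecycle,
# based on the methods above. The function name, key value, and `loaded_model` are
# hypothetical; the cache calls themselves come from this module.
def _example_run_inference(cache: ModelCache, loaded_model: AnyModel) -> None:
    key = get_model_cache_key("my-model-key")  # hypothetical model key
    cache.put(key, loaded_model)
    cache_entry = cache.get(key)
    cache.lock(cache_entry)  # moves the model toward VRAM and pins it
    try:
        pass  # run inference with cache_entry.cached_model.model here
    finally:
        cache.unlock(cache_entry)  # allows the model to be offloaded again
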
    def _get_cache_size(self) -> int:
        """Get the total size of the models currently cached."""
        total = 0
        for cache_record in self._cached_models.values():
            total += cache_record.size
        return total
    def _load_locked_model(self, cache_entry: CacheRecord) -> None:
        """Helper function for self.lock(). Loads a locked model into VRAM."""
        start_time = time.time()
        vram_available = self._get_vram_available()

        # Calculate model_vram_needed, the amount of additional VRAM that will be used if we fully load the model into
        # VRAM.
        model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
        model_total_bytes = cache_entry.cached_model.total_bytes()
        model_vram_needed = model_total_bytes - model_cur_vram_bytes

        # The amount of VRAM that must be freed to make room for model_vram_needed.
        vram_bytes_to_free = max(0, model_vram_needed - vram_available)

        self._logger.debug(
            f"Before unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
        )

        # Make room for the model in VRAM.
        # 1. If the model can fit entirely in VRAM, then make enough room for it to be loaded fully.
        # 2. If the model can't fit fully into VRAM, then unload all other models and load as much of the model as
        # possible.
        vram_bytes_freed = self._offload_unlocked_models(vram_bytes_to_free)
        self._logger.debug(f"Unloaded models (if necessary): vram_bytes_freed={(vram_bytes_freed/MB):.2f}MB")

        # Check the updated vram_available after offloading.
        vram_available = self._get_vram_available()
        self._logger.debug(
            f"After unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
        )

        # Move as much of the model as possible into VRAM.
        # For testing, only allow 10% of the model to be loaded into VRAM.
        # vram_available = int(model_vram_needed * 0.1)
        model_bytes_loaded = self._move_model_to_vram(cache_entry, vram_available)

        model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
        vram_available = self._get_vram_available()
        loaded_percent = model_cur_vram_bytes / model_total_bytes if model_total_bytes > 0 else 0
        self._logger.info(
            f"Loaded model '{cache_entry.key}' ({cache_entry.cached_model.model.__class__.__name__}) onto "
            f"{self._execution_device.type} device in {(time.time() - start_time):.2f}s. "
            f"Total model size: {model_total_bytes/MB:.2f}MB, "
            f"VRAM: {model_cur_vram_bytes/MB:.2f}MB ({loaded_percent:.1%})"
        )
        self._logger.debug(f"Loaded model onto execution device: model_bytes_loaded={(model_bytes_loaded/MB):.2f}MB, ")
        self._logger.debug(
            f"After loading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
        )
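
# Editor's worked example (not part of the diff) of the VRAM budgeting above, with assumed
# numbers: a 6 GB model with 1 GB already resident in VRAM and 3 GB of budget free still
# needs 5 GB, so 2 GB must be freed by offloading other models first.
_example_model_total_bytes = 6 * GB
_example_model_cur_vram_bytes = 1 * GB
_example_vram_available = 3 * GB
_example_vram_needed = _example_model_total_bytes - _example_model_cur_vram_bytes
assert max(0, _example_vram_needed - _example_vram_available) == 2 * GB
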
    def _move_model_to_vram(self, cache_entry: CacheRecord, vram_available: int) -> int:
        try:
            if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad):
                return cache_entry.cached_model.partial_load_to_vram(vram_available)
            elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad):  # type: ignore
                # Partial load is not supported, so we have no choice but to try and fit it all into VRAM.
                return cache_entry.cached_model.full_load_to_vram()
            else:
                raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}")
        except Exception as e:
            if isinstance(e, torch.cuda.OutOfMemoryError):
                self._logger.warning("Insufficient GPU memory to load model. Aborting")
            # If an exception occurs, the model could be left in a bad state, so we delete it from the cache entirely.
            self._delete_cache_entry(cache_entry)
            raise

    def _move_model_to_ram(self, cache_entry: CacheRecord, vram_bytes_to_free: int) -> int:
        try:
            if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad):
                return cache_entry.cached_model.partial_unload_from_vram(vram_bytes_to_free)
            elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad):  # type: ignore
                return cache_entry.cached_model.full_unload_from_vram()
            else:
                raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}")
        except Exception:
            # If an exception occurs, the model could be left in a bad state, so we delete it from the cache entirely.
            self._delete_cache_entry(cache_entry)
            raise

    def _get_vram_available(self) -> int:
        """Calculate the amount of additional VRAM available for the cache to use."""
        vram_total_available_to_cache = int(self._max_vram_cache_size_gb * GB)
        return vram_total_available_to_cache - self._get_vram_in_use()

    def _get_vram_in_use(self) -> int:
        """Get the amount of VRAM currently in use by the cache."""
        return sum(ce.cached_model.cur_vram_bytes() for ce in self._cached_models.values())

    def _get_ram_available(self) -> int:
        """Get the amount of RAM available for the cache to use, while keeping memory pressure under control."""

        ram_total_available_to_cache = int(self._max_ram_cache_size_gb * GB)
        return ram_total_available_to_cache - self._get_ram_in_use()

    def _get_ram_in_use(self) -> int:
        """Get the amount of RAM currently in use."""
        return sum(ce.cached_model.total_bytes() for ce in self._cached_models.values())
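
# Editor's worked example (not part of the diff) of the budget arithmetic above, with
# assumed numbers: an 8 GB RAM limit with 5 GB of models resident leaves 3 GB available.
_example_max_ram_cache_size_gb = 8.0
_example_ram_in_use = 5 * GB
assert int(_example_max_ram_cache_size_gb * GB) - _example_ram_in_use == 3 * GB
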
    def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
        if self._log_memory_usage:
            return MemorySnapshot.capture()
        return None

    def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
        if submodel_type:
            return f"{model_key}:{submodel_type.value}"
        else:
            return model_key

    def _offload_unlocked_models(self, size_required: int) -> None:
        """Offload models from the execution_device to make room for size_required.

        :param size_required: The amount of space to clear in the execution_device cache, in bytes.
        """
        reserved = self._max_vram_cache_size * GB
        vram_in_use = torch.cuda.memory_allocated() + size_required
        self._logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
        for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
            if vram_in_use <= reserved:
                break
            if not cache_entry.loaded:
                continue
            if not cache_entry.locked:
                self._move_model_to_device(cache_entry, self._storage_device)
                cache_entry.loaded = False
                vram_in_use = torch.cuda.memory_allocated() + size_required
                self._logger.debug(
                    f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
                )

        TorchDevice.empty_cache()

    def _move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None:
        """Move model into the indicated device.

        :param cache_entry: The CacheRecord for the model
        :param target_device: The torch.device to move the model into

        May raise a torch.cuda.OutOfMemoryError
        """
        self._logger.debug(f"Called to move {cache_entry.key} to {target_device}")
        source_device = cache_entry.device

        # Note: We compare device types only so that 'cuda' == 'cuda:0'.
        # This would need to be revised to support multi-GPU.
        if torch.device(source_device).type == torch.device(target_device).type:
            return

        # Some models don't have a `to` method, in which case they run in RAM/CPU.
        if not hasattr(cache_entry.model, "to"):
            return

        # This roundabout method for moving the model around is done to avoid
        # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
        # When moving to VRAM, we copy (not move) each element of the state dict from
        # RAM to a new state dict in VRAM, and then inject it into the model.
        # This operation is slightly faster than running `to()` on the whole model.
        #
        # When the model needs to be removed from VRAM we simply delete the copy
        # of the state dict in VRAM, and reinject the state dict that is cached
        # in RAM into the model. So this operation is very fast.
        start_model_to_time = time.time()
        snapshot_before = self._capture_memory_snapshot()

        try:
            if cache_entry.state_dict is not None:
                assert hasattr(cache_entry.model, "load_state_dict")
                if target_device == self._storage_device:
                    cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
                else:
                    new_dict: Dict[str, torch.Tensor] = {}
                    for k, v in cache_entry.state_dict.items():
                        new_dict[k] = v.to(target_device, copy=True)
                    cache_entry.model.load_state_dict(new_dict, assign=True)
            cache_entry.model.to(target_device)
            cache_entry.device = target_device
        except Exception as e:  # blow away cache entry
            self._delete_cache_entry(cache_entry)
            raise e

        snapshot_after = self._capture_memory_snapshot()
        end_model_to_time = time.time()
        self._logger.debug(
            f"Moved model '{cache_entry.key}' from {source_device} to"
            f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
            f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
            f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
    def _get_vram_state_str(self, model_cur_vram_bytes: int, model_total_bytes: int, vram_available: int) -> str:
        """Helper function for preparing a VRAM state log string."""
        model_cur_vram_bytes_percent = model_cur_vram_bytes / model_total_bytes if model_total_bytes > 0 else 0
        return (
            f"model_total={model_total_bytes/MB:.0f} MB, "
            + f"model_vram={model_cur_vram_bytes/MB:.0f} MB ({model_cur_vram_bytes_percent:.1%} %), "
            # + f"vram_total={int(self._max_vram_cache_size * GB)/MB:.0f} MB, "
            + f"vram_available={(vram_available/MB):.0f} MB, "
        )

        if (
            snapshot_before is not None
            and snapshot_after is not None
            and snapshot_before.vram is not None
            and snapshot_after.vram is not None
        ):
            vram_change = abs(snapshot_before.vram - snapshot_after.vram)
    def _offload_unlocked_models(self, vram_bytes_to_free: int) -> int:
        """Offload models from the execution_device until vram_bytes_to_free bytes are freed, or all models are
        offloaded. Of course, locked models are not offloaded.

            # If the estimated model size does not match the change in VRAM, log a warning.
            if not math.isclose(
                vram_change,
                cache_entry.size,
                rel_tol=0.1,
                abs_tol=10 * MB,
            ):
        Returns:
            int: The number of bytes freed.
        """
        self._logger.debug(f"Offloading unlocked models with goal of freeing {vram_bytes_to_free/MB:.2f}MB of VRAM.")
        vram_bytes_freed = 0
        # TODO(ryand): Give more thought to the offloading policy used here.
        cache_entries_increasing_size = sorted(self._cached_models.values(), key=lambda x: x.cached_model.total_bytes())
        for cache_entry in cache_entries_increasing_size:
            if vram_bytes_freed >= vram_bytes_to_free:
                break
            if cache_entry.is_locked:
                continue

            cache_entry_bytes_freed = self._move_model_to_ram(cache_entry, vram_bytes_to_free - vram_bytes_freed)
            if cache_entry_bytes_freed > 0:
                self._logger.debug(
                    f"Moving model '{cache_entry.key}' from {source_device} to"
                    f" {target_device} caused an unexpected change in VRAM usage. The model's"
                    " estimated size may be incorrect. Estimated model size:"
                    f" {(cache_entry.size/GB):.3f} GB.\n"
                    f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
                    f"Unloaded {cache_entry.key} from VRAM to free {(cache_entry_bytes_freed/MB):.0f} MB."
                )
            vram_bytes_freed += cache_entry_bytes_freed

        TorchDevice.empty_cache()
        return vram_bytes_freed
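
# Editor's sketch (not part of the diff) of the offload policy above, as a standalone,
# simplified model: walk entries from smallest to largest, skip locked ones, and stop once
# enough has been freed. The real method frees bytes via _move_model_to_ram(); this sketch
# just assumes each offloaded entry frees its full size. All names here are hypothetical.
def _example_plan_offload(entries: List[tuple], vram_bytes_to_free: int) -> List[str]:
    """entries: (key, total_bytes, is_locked) triples; returns the keys that would be offloaded."""
    freed = 0
    plan: List[str] = []
    for key, total_bytes, is_locked in sorted(entries, key=lambda e: e[1]):
        if freed >= vram_bytes_to_free:
            break
        if is_locked:
            continue
        plan.append(key)
        freed += total_bytes
    return plan


assert _example_plan_offload([("a", 100, False), ("b", 50, True), ("c", 10, False)], 60) == ["c", "a"]
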
    def _log_cache_state(self, title: str = "Model cache state:", include_entry_details: bool = True):
        if self._logger.getEffectiveLevel() > logging.DEBUG:
            # Short circuit if the logger is not set to debug. Some of the data lookups could take a non-negligible
            # amount of time.
            return

        log = f"{title}\n"

        log_format = "  {:<30} Limit: {:>7.1f} MB, Used: {:>7.1f} MB ({:>5.1%}), Available: {:>7.1f} MB ({:>5.1%})\n"

        ram_in_use_bytes = self._get_ram_in_use()
        ram_available_bytes = self._get_ram_available()
        ram_size_bytes = ram_in_use_bytes + ram_available_bytes
        ram_in_use_bytes_percent = ram_in_use_bytes / ram_size_bytes if ram_size_bytes > 0 else 0
        ram_available_bytes_percent = ram_available_bytes / ram_size_bytes if ram_size_bytes > 0 else 0
        log += log_format.format(
            f"Storage Device ({self._storage_device.type})",
            ram_size_bytes / MB,
            ram_in_use_bytes / MB,
            ram_in_use_bytes_percent,
            ram_available_bytes / MB,
            ram_available_bytes_percent,
        )

        if self._execution_device.type != "cpu":
            vram_in_use_bytes = self._get_vram_in_use()
            vram_available_bytes = self._get_vram_available()
            vram_size_bytes = vram_in_use_bytes + vram_available_bytes
            vram_in_use_bytes_percent = vram_in_use_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
            vram_available_bytes_percent = vram_available_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
            log += log_format.format(
                f"Compute Device ({self._execution_device.type})",
                vram_size_bytes / MB,
                vram_in_use_bytes / MB,
                vram_in_use_bytes_percent,
                vram_available_bytes / MB,
                vram_available_bytes_percent,
            )

        if torch.cuda.is_available():
            log += "  {:<30} {:.1f} MB\n".format("CUDA Memory Allocated:", torch.cuda.memory_allocated() / MB)
        log += "  {:<30} {}\n".format("Total models:", len(self._cached_models))

        if include_entry_details and len(self._cached_models) > 0:
            log += "  Models:\n"
            log_format = (
                "    {:<80} total={:>7.1f} MB, vram={:>7.1f} MB ({:>5.1%}), ram={:>7.1f} MB ({:>5.1%}), locked={}\n"
            )
            for cache_record in self._cached_models.values():
                total_bytes = cache_record.cached_model.total_bytes()
                cur_vram_bytes = cache_record.cached_model.cur_vram_bytes()
                cur_vram_bytes_percent = cur_vram_bytes / total_bytes if total_bytes > 0 else 0
                cur_ram_bytes = total_bytes - cur_vram_bytes
                cur_ram_bytes_percent = cur_ram_bytes / total_bytes if total_bytes > 0 else 0

                log += log_format.format(
                    f"{cache_record.key} ({cache_record.cached_model.model.__class__.__name__}):",
                    total_bytes / MB,
                    cur_vram_bytes / MB,
                    cur_vram_bytes_percent,
                    cur_ram_bytes / MB,
                    cur_ram_bytes_percent,
                    cache_record.is_locked,
                )

    def _print_cuda_stats(self) -> None:
        """Log CUDA diagnostics."""
        vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
        ram = "%4.2fG" % (self._get_cache_size() / GB)
        self._logger.debug(log)

        in_ram_models = 0
        in_vram_models = 0
        locked_in_vram_models = 0
        for cache_record in self._cached_models.values():
            if hasattr(cache_record.model, "device"):
                if cache_record.model.device == self._storage_device:
                    in_ram_models += 1
                else:
                    in_vram_models += 1
                if cache_record.locked:
                    locked_in_vram_models += 1

        self._logger.debug(
            f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
            f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
        )

    def make_room(self, size: int) -> None:
    def make_room(self, bytes_needed: int) -> None:
        """Make enough room in the cache to accommodate a new model of indicated size.

        Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
        external references to the model, there's nothing that the cache can do about it, and those models will not be
        garbage-collected.
        """
        bytes_needed = size
        maximum_size = self._max_cache_size * GB  # stored in GB, convert to bytes
        current_size = self._get_cache_size()
        self._logger.debug(f"Making room for {bytes_needed/MB:.2f}MB of RAM.")
        self._log_cache_state(title="Before dropping models:")

        if current_size + bytes_needed > maximum_size:
            self._logger.debug(
                f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
                f" {(bytes_needed/GB):.2f} GB"
            )

        self._logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
        ram_bytes_available = self._get_ram_available()
        ram_bytes_to_free = max(0, bytes_needed - ram_bytes_available)

        ram_bytes_freed = 0
        pos = 0
        models_cleared = 0
        while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
        while ram_bytes_freed < ram_bytes_to_free and pos < len(self._cache_stack):
            model_key = self._cache_stack[pos]
            cache_entry = self._cached_models[model_key]
            device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
            self._logger.debug(
                f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
            )

            if not cache_entry.locked:
            if not cache_entry.is_locked:
                ram_bytes_freed += cache_entry.cached_model.total_bytes()
                self._logger.debug(
                    f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
                    f"Dropping {model_key} from RAM cache to free {(cache_entry.cached_model.total_bytes()/MB):.2f}MB."
                )
                current_size -= cache_entry.size
                models_cleared += 1
                self._delete_cache_entry(cache_entry)
                del cache_entry

                models_cleared += 1
            else:
                pos += 1
@@ -435,8 +494,10 @@ class ModelCache:
        gc.collect()

        TorchDevice.empty_cache()
        self._logger.debug(f"After making room: cached_models={len(self._cached_models)}")
        self._logger.debug(f"Dropped {models_cleared} models to free {ram_bytes_freed/MB:.2f}MB of RAM.")
        self._log_cache_state(title="After dropping models:")

    def _delete_cache_entry(self, cache_entry: CacheRecord) -> None:
        self._cache_stack.remove(cache_entry.key)
        del self._cached_models[cache_entry.key]
        """Delete cache_entry from the cache if it exists. No exception is thrown if it doesn't exist."""
        self._cache_stack = [key for key in self._cache_stack if key != cache_entry.key]
        self._cached_models.pop(cache_entry.key, None)
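
# Editor's worked example (not part of the diff) of the make_room() accounting above, with
# assumed numbers: needing 2 GB for a new model when only 0.5 GB of the RAM budget is free
# means at least 1.5 GB of unlocked models must be dropped, starting from the bottom
# (least recently used end) of the cache stack.
_example_bytes_needed = 2 * GB
_example_ram_bytes_available = GB // 2
assert max(0, _example_bytes_needed - _example_ram_bytes_available) == 3 * GB // 2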