mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
552 lines
21 KiB
Python
552 lines
21 KiB
Python
"""
|
|
Manage a RAM cache of diffusion/transformer models for fast switching.
|
|
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
|
|
grows larger than a preset maximum, then the least recently used
|
|
model will be cleared and (re)loaded from disk when next needed.
|
|
|
|
The cache returns context manager generators designed to load the
|
|
model into the GPU within the context, and unload outside the
|
|
context. Use like this:
|
|
|
|
cache = ModelCache(max_cache_size=7.5)
|
|
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
|
|
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
|
|
do_something_in_GPU(SD1,SD2)
|
|
|
|
|
|
"""
|
|
|
|
import gc
|
|
import hashlib
|
|
import math
|
|
import os
|
|
import sys
|
|
import time
|
|
from contextlib import suppress
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Optional, Type, Union, types
|
|
|
|
import torch
|
|
|
|
import invokeai.backend.util.logging as logger
|
|
from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
|
from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init
|
|
|
|
from ..util.devices import choose_torch_device
|
|
from .models import BaseModelType, ModelBase, ModelType, SubModelType
|
|
|
|
if choose_torch_device() == torch.device("mps"):
|
|
from torch import mps
|
|
|
|
# Maximum size of the cache, in gigs
|
|
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
|
|
DEFAULT_MAX_CACHE_SIZE = 6.0
|
|
|
|
# amount of GPU memory to hold in reserve for use by generations (GB)
|
|
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
|
|
|
|
# actual size of a gig
|
|
GIG = 1073741824
|
|
# Size of a MB in bytes.
|
|
MB = 2**20
|
|
|
|
|
|
@dataclass
|
|
class CacheStats(object):
|
|
hits: int = 0 # cache hits
|
|
misses: int = 0 # cache misses
|
|
high_watermark: int = 0 # amount of cache used
|
|
in_cache: int = 0 # number of models in cache
|
|
cleared: int = 0 # number of models cleared to make space
|
|
cache_size: int = 0 # total size of cache
|
|
# {submodel_key => size}
|
|
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
|
|
|
|
|
|
class ModelLocker(object):
|
|
"Forward declaration"
|
|
pass
|
|
|
|
|
|
class ModelCache(object):
|
|
"Forward declaration"
|
|
pass
|
|
|
|
|
|
class _CacheRecord:
|
|
size: int
|
|
model: Any
|
|
cache: ModelCache
|
|
_locks: int
|
|
|
|
def __init__(self, cache, model: Any, size: int):
|
|
self.size = size
|
|
self.model = model
|
|
self.cache = cache
|
|
self._locks = 0
|
|
|
|
def lock(self):
|
|
self._locks += 1
|
|
|
|
def unlock(self):
|
|
self._locks -= 1
|
|
assert self._locks >= 0
|
|
|
|
@property
|
|
def locked(self):
|
|
return self._locks > 0
|
|
|
|
@property
|
|
def loaded(self):
|
|
if self.model is not None and hasattr(self.model, "device"):
|
|
return self.model.device != self.cache.storage_device
|
|
else:
|
|
return False
|
|
|
|
|
|
class ModelCache(object):
|
|
def __init__(
|
|
self,
|
|
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
|
|
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
|
|
execution_device: torch.device = torch.device("cuda"),
|
|
storage_device: torch.device = torch.device("cpu"),
|
|
precision: torch.dtype = torch.float16,
|
|
sequential_offload: bool = False,
|
|
lazy_offloading: bool = True,
|
|
sha_chunksize: int = 16777216,
|
|
logger: types.ModuleType = logger,
|
|
log_memory_usage: bool = False,
|
|
):
|
|
"""
|
|
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
|
|
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
|
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
|
:param precision: Precision for loaded models [torch.float16]
|
|
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
|
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
|
|
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
|
|
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
|
|
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
|
|
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
|
behaviour.
|
|
"""
|
|
self.model_infos: Dict[str, ModelBase] = dict()
|
|
# allow lazy offloading only when vram cache enabled
|
|
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
|
self.precision: torch.dtype = precision
|
|
self.max_cache_size: float = max_cache_size
|
|
self.max_vram_cache_size: float = max_vram_cache_size
|
|
self.execution_device: torch.device = execution_device
|
|
self.storage_device: torch.device = storage_device
|
|
self.sha_chunksize = sha_chunksize
|
|
self.logger = logger
|
|
self._log_memory_usage = log_memory_usage
|
|
|
|
# used for stats collection
|
|
self.stats = None
|
|
|
|
self._cached_models = dict()
|
|
self._cache_stack = list()
|
|
|
|
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
|
|
if self._log_memory_usage:
|
|
return MemorySnapshot.capture()
|
|
return None
|
|
|
|
def get_key(
|
|
self,
|
|
model_path: str,
|
|
base_model: BaseModelType,
|
|
model_type: ModelType,
|
|
submodel_type: Optional[SubModelType] = None,
|
|
):
|
|
key = f"{model_path}:{base_model}:{model_type}"
|
|
if submodel_type:
|
|
key += f":{submodel_type}"
|
|
return key
|
|
|
|
def _get_model_info(
|
|
self,
|
|
model_path: str,
|
|
model_class: Type[ModelBase],
|
|
base_model: BaseModelType,
|
|
model_type: ModelType,
|
|
):
|
|
model_info_key = self.get_key(
|
|
model_path=model_path,
|
|
base_model=base_model,
|
|
model_type=model_type,
|
|
submodel_type=None,
|
|
)
|
|
|
|
if model_info_key not in self.model_infos:
|
|
self.model_infos[model_info_key] = model_class(
|
|
model_path,
|
|
base_model,
|
|
model_type,
|
|
)
|
|
|
|
return self.model_infos[model_info_key]
|
|
|
|
# TODO: args
|
|
def get_model(
|
|
self,
|
|
model_path: Union[str, Path],
|
|
model_class: Type[ModelBase],
|
|
base_model: BaseModelType,
|
|
model_type: ModelType,
|
|
submodel: Optional[SubModelType] = None,
|
|
gpu_load: bool = True,
|
|
) -> Any:
|
|
if not isinstance(model_path, Path):
|
|
model_path = Path(model_path)
|
|
|
|
if not os.path.exists(model_path):
|
|
raise Exception(f"Model not found: {model_path}")
|
|
|
|
model_info = self._get_model_info(
|
|
model_path=model_path,
|
|
model_class=model_class,
|
|
base_model=base_model,
|
|
model_type=model_type,
|
|
)
|
|
key = self.get_key(
|
|
model_path=model_path,
|
|
base_model=base_model,
|
|
model_type=model_type,
|
|
submodel_type=submodel,
|
|
)
|
|
# TODO: lock for no copies on simultaneous calls?
|
|
cache_entry = self._cached_models.get(key, None)
|
|
if cache_entry is None:
|
|
self.logger.info(
|
|
f"Loading model {model_path}, type"
|
|
f" {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
|
|
)
|
|
if self.stats:
|
|
self.stats.misses += 1
|
|
|
|
self_reported_model_size_before_load = model_info.get_size(submodel)
|
|
# Remove old models from the cache to make room for the new model.
|
|
self._make_cache_room(self_reported_model_size_before_load)
|
|
|
|
# Load the model from disk and capture a memory snapshot before/after.
|
|
start_load_time = time.time()
|
|
snapshot_before = self._capture_memory_snapshot()
|
|
with skip_torch_weight_init():
|
|
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
|
|
snapshot_after = self._capture_memory_snapshot()
|
|
end_load_time = time.time()
|
|
|
|
self_reported_model_size_after_load = model_info.get_size(submodel)
|
|
|
|
self.logger.debug(
|
|
f"Moved model '{key}' from disk to cpu in {(end_load_time-start_load_time):.2f}s.\n"
|
|
f"Self-reported size before/after load: {(self_reported_model_size_before_load/GIG):.3f}GB /"
|
|
f" {(self_reported_model_size_after_load/GIG):.3f}GB.\n"
|
|
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
|
)
|
|
|
|
if abs(self_reported_model_size_after_load - self_reported_model_size_before_load) > 10 * MB:
|
|
self.logger.debug(
|
|
f"Model '{key}' mis-reported its size before load. Self-reported size before/after load:"
|
|
f" {(self_reported_model_size_before_load/GIG):.2f}GB /"
|
|
f" {(self_reported_model_size_after_load/GIG):.2f}GB."
|
|
)
|
|
|
|
cache_entry = _CacheRecord(self, model, self_reported_model_size_after_load)
|
|
self._cached_models[key] = cache_entry
|
|
else:
|
|
if self.stats:
|
|
self.stats.hits += 1
|
|
|
|
if self.stats:
|
|
self.stats.cache_size = self.max_cache_size * GIG
|
|
self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size())
|
|
self.stats.in_cache = len(self._cached_models)
|
|
self.stats.loaded_model_sizes[key] = max(
|
|
self.stats.loaded_model_sizes.get(key, 0), model_info.get_size(submodel)
|
|
)
|
|
|
|
with suppress(Exception):
|
|
self._cache_stack.remove(key)
|
|
self._cache_stack.append(key)
|
|
|
|
return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size)
|
|
|
|
def _move_model_to_device(self, key: str, target_device: torch.device):
|
|
cache_entry = self._cached_models[key]
|
|
|
|
source_device = cache_entry.model.device
|
|
# Note: We compare device types only so that 'cuda' == 'cuda:0'. This would need to be revised to support
|
|
# multi-GPU.
|
|
if torch.device(source_device).type == torch.device(target_device).type:
|
|
return
|
|
|
|
start_model_to_time = time.time()
|
|
snapshot_before = self._capture_memory_snapshot()
|
|
cache_entry.model.to(target_device)
|
|
snapshot_after = self._capture_memory_snapshot()
|
|
end_model_to_time = time.time()
|
|
self.logger.debug(
|
|
f"Moved model '{key}' from {source_device} to"
|
|
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s.\n"
|
|
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB.\n"
|
|
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
|
)
|
|
|
|
if (
|
|
snapshot_before is not None
|
|
and snapshot_after is not None
|
|
and snapshot_before.vram is not None
|
|
and snapshot_after.vram is not None
|
|
):
|
|
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
|
|
|
|
# If the estimated model size does not match the change in VRAM, log a warning.
|
|
if not math.isclose(
|
|
vram_change,
|
|
cache_entry.size,
|
|
rel_tol=0.1,
|
|
abs_tol=10 * MB,
|
|
):
|
|
self.logger.debug(
|
|
f"Moving model '{key}' from {source_device} to"
|
|
f" {target_device} caused an unexpected change in VRAM usage. The model's"
|
|
" estimated size may be incorrect. Estimated model size:"
|
|
f" {(cache_entry.size/GIG):.3f} GB.\n"
|
|
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
|
)
|
|
|
|
class ModelLocker(object):
|
|
def __init__(self, cache, key, model, gpu_load, size_needed):
|
|
"""
|
|
:param cache: The model_cache object
|
|
:param key: The key of the model to lock in GPU
|
|
:param model: The model to lock
|
|
:param gpu_load: True if load into gpu
|
|
:param size_needed: Size of the model to load
|
|
"""
|
|
self.gpu_load = gpu_load
|
|
self.cache = cache
|
|
self.key = key
|
|
self.model = model
|
|
self.size_needed = size_needed
|
|
self.cache_entry = self.cache._cached_models[self.key]
|
|
|
|
def __enter__(self) -> Any:
|
|
if not hasattr(self.model, "to"):
|
|
return self.model
|
|
|
|
# NOTE that the model has to have the to() method in order for this
|
|
# code to move it into GPU!
|
|
if self.gpu_load:
|
|
self.cache_entry.lock()
|
|
|
|
try:
|
|
if self.cache.lazy_offloading:
|
|
self.cache._offload_unlocked_models(self.size_needed)
|
|
|
|
self.cache._move_model_to_device(self.key, self.cache.execution_device)
|
|
|
|
self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
|
|
self.cache._print_cuda_stats()
|
|
|
|
except Exception:
|
|
self.cache_entry.unlock()
|
|
raise
|
|
|
|
# TODO: not fully understand
|
|
# in the event that the caller wants the model in RAM, we
|
|
# move it into CPU if it is in GPU and not locked
|
|
elif self.cache_entry.loaded and not self.cache_entry.locked:
|
|
self.cache._move_model_to_device(self.key, self.cache.storage_device)
|
|
|
|
return self.model
|
|
|
|
def __exit__(self, type, value, traceback):
|
|
if not hasattr(self.model, "to"):
|
|
return
|
|
|
|
self.cache_entry.unlock()
|
|
if not self.cache.lazy_offloading:
|
|
self.cache._offload_unlocked_models()
|
|
self.cache._print_cuda_stats()
|
|
|
|
# TODO: should it be called untrack_model?
|
|
def uncache_model(self, cache_id: str):
|
|
with suppress(ValueError):
|
|
self._cache_stack.remove(cache_id)
|
|
self._cached_models.pop(cache_id, None)
|
|
|
|
def model_hash(
|
|
self,
|
|
model_path: Union[str, Path],
|
|
) -> str:
|
|
"""
|
|
Given the HF repo id or path to a model on disk, returns a unique
|
|
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
|
|
|
|
:param model_path: Path to model file/directory on disk.
|
|
"""
|
|
return self._local_model_hash(model_path)
|
|
|
|
def cache_size(self) -> float:
|
|
"""Return the current size of the cache, in GB."""
|
|
return self._cache_size() / GIG
|
|
|
|
def _has_cuda(self) -> bool:
|
|
return self.execution_device.type == "cuda"
|
|
|
|
def _print_cuda_stats(self):
|
|
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
|
|
ram = "%4.2fG" % self.cache_size()
|
|
|
|
cached_models = 0
|
|
loaded_models = 0
|
|
locked_models = 0
|
|
for model_info in self._cached_models.values():
|
|
cached_models += 1
|
|
if model_info.loaded:
|
|
loaded_models += 1
|
|
if model_info.locked:
|
|
locked_models += 1
|
|
|
|
self.logger.debug(
|
|
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ ="
|
|
f" {cached_models}/{loaded_models}/{locked_models}"
|
|
)
|
|
|
|
def _cache_size(self) -> int:
|
|
return sum([m.size for m in self._cached_models.values()])
|
|
|
|
def _make_cache_room(self, model_size):
|
|
# calculate how much memory this model will require
|
|
# multiplier = 2 if self.precision==torch.float32 else 1
|
|
bytes_needed = model_size
|
|
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
|
|
current_size = self._cache_size()
|
|
|
|
if current_size + bytes_needed > maximum_size:
|
|
self.logger.debug(
|
|
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
|
|
f" {(bytes_needed/GIG):.2f} GB"
|
|
)
|
|
|
|
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
|
|
|
|
pos = 0
|
|
models_cleared = 0
|
|
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
|
|
model_key = self._cache_stack[pos]
|
|
cache_entry = self._cached_models[model_key]
|
|
|
|
refs = sys.getrefcount(cache_entry.model)
|
|
|
|
# HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
|
|
# going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
|
|
# https://docs.python.org/3/library/gc.html#gc.get_referrers
|
|
|
|
# manualy clear local variable references of just finished function calls
|
|
# for some reason python don't want to collect it even by gc.collect() immidiately
|
|
if refs > 2:
|
|
while True:
|
|
cleared = False
|
|
for referrer in gc.get_referrers(cache_entry.model):
|
|
if type(referrer).__name__ == "frame":
|
|
# RuntimeError: cannot clear an executing frame
|
|
with suppress(RuntimeError):
|
|
referrer.clear()
|
|
cleared = True
|
|
# break
|
|
|
|
# repeat if referrers changes(due to frame clear), else exit loop
|
|
if cleared:
|
|
gc.collect()
|
|
else:
|
|
break
|
|
|
|
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
|
|
self.logger.debug(
|
|
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
|
|
f" refs: {refs}"
|
|
)
|
|
|
|
# 2 refs:
|
|
# 1 from cache_entry
|
|
# 1 from getrefcount function
|
|
# 1 from onnx runtime object
|
|
if not cache_entry.locked and refs <= 3 if "onnx" in model_key else 2:
|
|
self.logger.debug(
|
|
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
|
)
|
|
current_size -= cache_entry.size
|
|
models_cleared += 1
|
|
if self.stats:
|
|
self.stats.cleared += 1
|
|
del self._cache_stack[pos]
|
|
del self._cached_models[model_key]
|
|
del cache_entry
|
|
|
|
else:
|
|
pos += 1
|
|
|
|
if models_cleared < 0:
|
|
# There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
|
|
# there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
|
|
# is high even if no garbage gets collected.)
|
|
#
|
|
# Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
|
|
# - If models had to be cleared, it's a signal that we are close to our memory limit.
|
|
# - If models were cleared, there's a good chance that there's a significant amount of garbage to be
|
|
# collected.
|
|
#
|
|
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
|
|
# immediately when their reference count hits 0.
|
|
gc.collect()
|
|
|
|
torch.cuda.empty_cache()
|
|
if choose_torch_device() == torch.device("mps"):
|
|
mps.empty_cache()
|
|
|
|
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
|
|
|
|
def _offload_unlocked_models(self, size_needed: int = 0):
|
|
reserved = self.max_vram_cache_size * GIG
|
|
vram_in_use = torch.cuda.memory_allocated()
|
|
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
|
|
for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
|
if vram_in_use <= reserved:
|
|
break
|
|
if not cache_entry.locked and cache_entry.loaded:
|
|
self._move_model_to_device(model_key, self.storage_device)
|
|
|
|
vram_in_use = torch.cuda.memory_allocated()
|
|
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
|
|
|
|
torch.cuda.empty_cache()
|
|
if choose_torch_device() == torch.device("mps"):
|
|
mps.empty_cache()
|
|
|
|
def _local_model_hash(self, model_path: Union[str, Path]) -> str:
|
|
sha = hashlib.sha256()
|
|
path = Path(model_path)
|
|
|
|
hashpath = path / "checksum.sha256"
|
|
if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
|
|
with open(hashpath) as f:
|
|
hash = f.read()
|
|
return hash
|
|
|
|
self.logger.debug(f"computing hash of model {path.name}")
|
|
for file in list(path.rglob("*.ckpt")) + list(path.rglob("*.safetensors")) + list(path.rglob("*.pth")):
|
|
with open(file, "rb") as f:
|
|
while chunk := f.read(self.sha_chunksize):
|
|
sha.update(chunk)
|
|
hash = sha.hexdigest()
|
|
with open(hashpath, "w") as f:
|
|
f.write(hash)
|
|
return hash
|