InvokeAI/invokeai/backend/model_management/model_cache.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

573 lines
21 KiB
Python
Raw Normal View History

"""
Manage a RAM cache of diffusion/transformer models for fast switching.
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.
The cache returns context manager generators designed to load the
model into the GPU within the context, and unload outside the
context. Use like this:
cache = ModelCache(max_cache_size=7.5)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
do_something_in_GPU(SD1,SD2)
"""
import gc
2023-08-18 15:18:46 +00:00
import hashlib
import math
2023-05-18 00:56:52 +00:00
import os
import sys
import time
from contextlib import suppress
2023-08-16 01:56:19 +00:00
from dataclasses import dataclass, field
from pathlib import Path
2023-08-18 15:18:46 +00:00
from typing import Any, Dict, Optional, Type, Union, types
import psutil
import torch
import invokeai.backend.util.logging as logger
2023-08-18 15:18:46 +00:00
from ..util.devices import choose_torch_device
2023-08-18 15:18:46 +00:00
from .models import BaseModelType, ModelBase, ModelType, SubModelType
2023-05-18 00:56:52 +00:00
if choose_torch_device() == torch.device("mps"):
from torch import mps
2023-05-18 00:56:52 +00:00
# Default maximum size of the RAM model cache, in GB.
# NOTE(review): intended to be roughly enough for three fp16 diffusers models at once — unverified.
DEFAULT_MAX_CACHE_SIZE = 6.0
# Default VRAM budget, in GB, that loaded models may occupy before ModelCache starts
# offloading unlocked models back to the storage device (see ModelCache._offload_unlocked_models).
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75

# Number of bytes in one gigabyte (GiB).
GIG = 2**30
# Number of bytes in one megabyte (MiB).
MB = 1024 * 1024
2023-05-18 00:56:52 +00:00
2023-07-28 13:46:44 +00:00
2023-08-16 01:00:30 +00:00
@dataclass
class CacheStats:
    """Counters describing model-cache activity; ModelCache updates them when its `stats` is set."""

    hits: int = 0  # lookups answered from the cache
    misses: int = 0  # lookups that had to load the model from disk
    high_watermark: int = 0  # largest total cache size observed, in bytes
    in_cache: int = 0  # number of models currently held in the cache
    cleared: int = 0  # number of models evicted to make space
    cache_size: int = 0  # configured cache capacity, in bytes
    # maps submodel key -> largest self-reported size seen for it, in bytes
    loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
2023-08-16 01:00:30 +00:00
class ModelLocker:
    """Placeholder for the ModelLocker name; the working implementation is nested in ModelCache."""
2023-07-28 13:46:44 +00:00
2023-05-18 00:56:52 +00:00
class ModelCache:
    """Placeholder for the ModelCache name; the full class definition later in this module replaces it."""
2023-07-28 13:46:44 +00:00
2023-05-18 00:56:52 +00:00
class _CacheRecord:
size: int
model: Any
2023-05-23 00:48:22 +00:00
cache: ModelCache
2023-05-18 00:56:52 +00:00
_locks: int
def __init__(self, cache, model: Any, size: int):
2023-05-18 00:56:52 +00:00
self.size = size
self.model = model
2023-05-23 00:48:22 +00:00
self.cache = cache
2023-05-18 00:56:52 +00:00
self._locks = 0
def lock(self):
self._locks += 1
def unlock(self):
self._locks -= 1
assert self._locks >= 0
@property
def locked(self):
return self._locks > 0
@property
def loaded(self):
if self.model is not None and hasattr(self.model, "device"):
return self.model.device != self.cache.storage_device
2023-05-18 00:56:52 +00:00
else:
return False
2023-07-28 13:46:44 +00:00
2023-05-18 00:56:52 +00:00
class ModelCache(object):
    """RAM cache of loaded models that moves them between the storage and execution devices."""

    def __init__(
        self,
        max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
        max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
        execution_device: torch.device = torch.device("cuda"),
        storage_device: torch.device = torch.device("cpu"),
        precision: torch.dtype = torch.float16,
        sequential_offload: bool = False,
        lazy_offloading: bool = True,
        sha_chunksize: int = 16777216,
        logger: types.ModuleType = logger,
    ):
        """
        :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
        :param max_vram_cache_size: VRAM (GB) that loaded models may occupy before unlocked ones are
            offloaded to the storage device; a value <= 0 disables lazy offloading [2.75 GB]
        :param execution_device: Torch device to load active model into [torch.device('cuda')]
        :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
        :param precision: Precision for loaded models [torch.float16]
        :param sequential_offload: Accepted for API compatibility; currently ignored by this class
        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
        :param sha_chunksize: Chunksize to use when calculating sha256 model hash
        :param logger: Logger module used for diagnostics
        """
        self.model_infos: Dict[str, ModelBase] = dict()
        # allow lazy offloading only when vram cache enabled
        self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
        self.precision: torch.dtype = precision
        self.max_cache_size: float = max_cache_size
        self.max_vram_cache_size: float = max_vram_cache_size
        self.execution_device: torch.device = execution_device
        self.storage_device: torch.device = storage_device
        self.sha_chunksize = sha_chunksize
        self.logger = logger

        # used for stats collection; callers may assign a CacheStats instance
        self.stats = None

        # cache key -> _CacheRecord
        self._cached_models = dict()
        # cache keys ordered from least to most recently used
        self._cache_stack = list()

    def get_key(
        self,
        model_path: str,
        base_model: BaseModelType,
        model_type: ModelType,
        submodel_type: Optional[SubModelType] = None,
    ):
        """Build the cache key "path:base:type[:submodel]"."""
        key = f"{model_path}:{base_model}:{model_type}"
        if submodel_type:
            key += f":{submodel_type}"
        return key

    def _get_model_info(
        self,
        model_path: str,
        model_class: Type[ModelBase],
        base_model: BaseModelType,
        model_type: ModelType,
    ):
        """Return the memoized `model_class` instance for (path, base, type), creating it on first use."""
        model_info_key = self.get_key(
            model_path=model_path,
            base_model=base_model,
            model_type=model_type,
            submodel_type=None,
        )
        if model_info_key not in self.model_infos:
            self.model_infos[model_info_key] = model_class(
                model_path,
                base_model,
                model_type,
            )

        return self.model_infos[model_info_key]

    # TODO: args
    def get_model(
        self,
        model_path: Union[str, Path],
        model_class: Type[ModelBase],
        base_model: BaseModelType,
        model_type: ModelType,
        submodel: Optional[SubModelType] = None,
        gpu_load: bool = True,
    ) -> Any:
        """Return a ModelLocker for the requested model, loading it from disk on a cache miss.

        :param model_path: Path to the model on disk; must exist.
        :param model_class: ModelBase subclass used to load and size the model.
        :param base_model: Base model type (part of the cache key).
        :param model_type: Model type (part of the cache key).
        :param submodel: Optional submodel to load (part of the cache key).
        :param gpu_load: If True, entering the returned locker moves the model onto the
            execution device and locks it there until the `with` block exits.
        :raises Exception: if `model_path` does not exist.
        """
        if not isinstance(model_path, Path):
            model_path = Path(model_path)

        if not os.path.exists(model_path):
            raise Exception(f"Model not found: {model_path}")

        model_info = self._get_model_info(
            model_path=model_path,
            model_class=model_class,
            base_model=base_model,
            model_type=model_type,
        )
        key = self.get_key(
            model_path=model_path,
            base_model=base_model,
            model_type=model_type,
            submodel_type=submodel,
        )

        # TODO: lock for no copies on simultaneous calls?
        cache_entry = self._cached_models.get(key, None)
        if cache_entry is None:
            self.logger.info(
                f"Loading model {model_path}, type"
                f" {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
            )
            if self.stats is not None:
                self.stats.misses += 1

            self_reported_model_size_before_load = model_info.get_size(submodel)
            # Remove old models from the cache to make room for the new model.
            self._make_cache_room(self_reported_model_size_before_load)

            # Load the model from disk and capture a memory snapshot before/after.
            start_load_time = time.time()
            snapshot_before = MemorySnapshot.capture()
            model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
            snapshot_after = MemorySnapshot.capture()
            end_load_time = time.time()

            self_reported_model_size_after_load = model_info.get_size(submodel)

            self.logger.debug(
                f"Moved model '{key}' from disk to cpu in {(end_load_time-start_load_time):.2f}s. Self-reported size"
                f" before/after load: {(self_reported_model_size_before_load/GIG):.2f}GB /"
                f" {(self_reported_model_size_after_load/GIG):.2f}GB."
                f" {get_pretty_snapshot_diff(snapshot_before, snapshot_after)}."
            )

            if not math.isclose(
                self_reported_model_size_before_load, self_reported_model_size_after_load, abs_tol=10 * MB
            ):
                self.logger.warning(
                    f"Model '{key}' mis-reported its size before load. Self-reported size before/after load:"
                    f" {(self_reported_model_size_before_load/GIG):.2f}GB /"
                    f" {(self_reported_model_size_after_load/GIG):.2f}GB."
                )

            cache_entry = _CacheRecord(self, model, self_reported_model_size_after_load)
            self._cached_models[key] = cache_entry
        elif self.stats is not None:
            self.stats.hits += 1

        if self.stats is not None:
            self.stats.cache_size = self.max_cache_size * GIG
            self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size())
            self.stats.in_cache = len(self._cached_models)
            self.stats.loaded_model_sizes[key] = max(
                self.stats.loaded_model_sizes.get(key, 0), model_info.get_size(submodel)
            )

        # Move the key to the most-recently-used end of the LRU stack.
        with suppress(ValueError):
            self._cache_stack.remove(key)
        self._cache_stack.append(key)

        return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size)

    class ModelLocker(object):
        """Context manager that places a cached model on the execution device for a `with` block."""

        def __init__(self, cache, key, model, gpu_load, size_needed):
            """
            :param cache: The model_cache object
            :param key: The key of the model to lock in GPU
            :param model: The model to lock
            :param gpu_load: True if load into gpu
            :param size_needed: Size of the model to load
            """
            self.gpu_load = gpu_load
            self.cache = cache
            self.key = key
            self.model = model
            self.size_needed = size_needed
            self.cache_entry = self.cache._cached_models[self.key]

        def __enter__(self) -> Any:
            if not hasattr(self.model, "to"):
                return self.model

            # NOTE that the model has to have the to() method in order for this
            # code to move it into GPU!
            if self.gpu_load:
                self.cache_entry.lock()

                try:
                    needs_move = self.model.device != self.cache.execution_device
                    if self.cache.lazy_offloading:
                        # Reserve VRAM for this model only if it still has to be moved; a model that is
                        # already on the execution device is counted by torch.cuda.memory_allocated().
                        self.cache._offload_unlocked_models(self.size_needed if needs_move else 0)

                    if needs_move:
                        start_model_to_time = time.time()
                        snapshot_before = MemorySnapshot.capture()
                        self.model.to(self.cache.execution_device)  # move into GPU
                        snapshot_after = MemorySnapshot.capture()
                        end_model_to_time = time.time()
                        self.cache.logger.debug(
                            f"Moved model '{self.key}' from {self.cache.storage_device} to"
                            f" {self.cache.execution_device} in {(end_model_to_time-start_model_to_time):.2f}s."
                            f" Estimated model size: {(self.cache_entry.size/GIG):.2f} GB."
                            f" {get_pretty_snapshot_diff(snapshot_before, snapshot_after)}."
                        )

                        if not math.isclose(
                            abs((snapshot_before.vram or 0) - (snapshot_after.vram or 0)),
                            self.cache_entry.size,
                            rel_tol=0.1,
                            abs_tol=10 * MB,
                        ):
                            self.cache.logger.warning(
                                f"Moving '{self.key}' from {self.cache.storage_device} to"
                                f" {self.cache.execution_device} caused an unexpected change in VRAM usage. The model's"
                                " estimated size may be incorrect. Estimated model size:"
                                f" {(self.cache_entry.size/GIG):.2f} GB."
                                f" {get_pretty_snapshot_diff(snapshot_before, snapshot_after)}."
                            )

                    self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
                    self.cache._print_cuda_stats()

                except Exception:
                    # __exit__ is not called when __enter__ raises, so release the lock here.
                    self.cache_entry.unlock()
                    raise

            # TODO: not fully understand
            # in the event that the caller wants the model in RAM, we
            # move it into CPU if it is in GPU and not locked
            elif self.cache_entry.loaded and not self.cache_entry.locked:
                self.model.to(self.cache.storage_device)

            return self.model

        def __exit__(self, exc_type, exc_value, exc_tb):
            if not hasattr(self.model, "to"):
                return

            # Release only the lock that __enter__ took; with gpu_load=False no lock was acquired.
            if self.gpu_load:
                self.cache_entry.unlock()
            if not self.cache.lazy_offloading:
                self.cache._offload_unlocked_models()
                self.cache._print_cuda_stats()

    # TODO: should it be called untrack_model?
    def uncache_model(self, cache_id: str):
        """Remove `cache_id` from the cache bookkeeping; unknown keys are ignored."""
        with suppress(ValueError):
            self._cache_stack.remove(cache_id)
        self._cached_models.pop(cache_id, None)

    def model_hash(
        self,
        model_path: Union[str, Path],
    ) -> str:
        """
        Given the path to a model on disk, returns a sha256 hash of its weights.
        Works for single checkpoint files and for model directories.

        :param model_path: Path to model file/directory on disk.
        """
        return self._local_model_hash(model_path)

    def cache_size(self) -> float:
        """Return the current size of the cache, in GB."""
        return self._cache_size() / GIG

    def _has_cuda(self) -> bool:
        return self.execution_device.type == "cuda"

    def _print_cuda_stats(self):
        """Log current VRAM/RAM usage and the number of cached, loaded and locked models."""
        vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
        ram = "%4.2fG" % self.cache_size()

        cached_models = len(self._cached_models)
        loaded_models = sum(1 for entry in self._cached_models.values() if entry.loaded)
        locked_models = sum(1 for entry in self._cached_models.values() if entry.locked)

        self.logger.debug(
            f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ ="
            f" {cached_models}/{loaded_models}/{locked_models}"
        )

    def _cache_size(self) -> int:
        """Total self-reported size of all cached models, in bytes."""
        return sum(entry.size for entry in self._cached_models.values())

    def _make_cache_room(self, model_size):
        """Evict least-recently-used models until `model_size` more bytes fit in the cache.

        Only entries that are unlocked and whose model is referenced by nothing but the
        cache are evicted; other entries are skipped, so the cache may still exceed
        `max_cache_size` afterwards.

        :param model_size: Size in bytes of the model about to be added.
        """
        # calculate how much memory this model will require
        # multiplier = 2 if self.precision==torch.float32 else 1
        bytes_needed = model_size
        maximum_size = self.max_cache_size * GIG  # stored in GB, convert to bytes
        current_size = self._cache_size()

        if current_size + bytes_needed > maximum_size:
            self.logger.debug(
                f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
                f" {(bytes_needed/GIG):.2f} GB"
            )

        self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")

        pos = 0
        while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
            model_key = self._cache_stack[pos]
            cache_entry = self._cached_models[model_key]

            refs = sys.getrefcount(cache_entry.model)

            # manually clear local variable references of just finished function calls
            # for some reason python doesn't collect them even by gc.collect() immediately
            if refs > 2:
                while True:
                    cleared = False
                    for referrer in gc.get_referrers(cache_entry.model):
                        if type(referrer).__name__ == "frame":
                            # RuntimeError: cannot clear an executing frame
                            with suppress(RuntimeError):
                                referrer.clear()
                                cleared = True

                    # repeat if referrers changed (due to frame clear), else exit loop
                    if cleared:
                        gc.collect()
                    else:
                        break

                # Re-count after clearing frames; otherwise the clearing could never make
                # the model evictable.
                refs = sys.getrefcount(cache_entry.model)

            device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
            self.logger.debug(
                f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
                f" refs: {refs}"
            )

            # Expected refs for a model nobody else is using:
            # 1 from cache_entry
            # 1 from getrefcount function
            # 1 from onnx runtime object (onnx models only)
            # The parentheses matter: without them the conditional expression swallows the
            # whole test and non-onnx models are evicted even when locked.
            max_refs = 3 if "onnx" in model_key else 2
            if not cache_entry.locked and refs <= max_refs:
                self.logger.debug(
                    f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
                )
                current_size -= cache_entry.size
                if self.stats is not None:
                    self.stats.cleared += 1
                del self._cache_stack[pos]
                del self._cached_models[model_key]
                del cache_entry

            else:
                pos += 1

        gc.collect()
        torch.cuda.empty_cache()
        if choose_torch_device() == torch.device("mps"):
            mps.empty_cache()

        self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")

    def _offload_unlocked_models(self, size_needed: int = 0):
        """Move unlocked models to the storage device until VRAM use fits `max_vram_cache_size`.

        Models are considered smallest first.

        :param size_needed: Bytes of VRAM about to be claimed by a model moving onto the
            execution device; counted as already in use so that room is made for it.
        """
        reserved = self.max_vram_cache_size * GIG
        vram_in_use = torch.cuda.memory_allocated() + size_needed
        self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
        for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
            if vram_in_use <= reserved:
                break
            if not cache_entry.locked and cache_entry.loaded:
                start_model_to_time = time.time()
                snapshot_before = MemorySnapshot.capture()
                cache_entry.model.to(self.storage_device)
                snapshot_after = MemorySnapshot.capture()
                end_model_to_time = time.time()
                self.logger.debug(
                    f"Moved model '{model_key}' from {self.execution_device} to {self.storage_device} in"
                    f" {(end_model_to_time-start_model_to_time):.2f}s. Estimated model size:"
                    f" {(cache_entry.size/GIG):.2f} GB."
                    f" {get_pretty_snapshot_diff(snapshot_before, snapshot_after)}."
                )
                vram_in_use = torch.cuda.memory_allocated() + size_needed
                self.logger.debug(
                    f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB"
                )

        gc.collect()
        torch.cuda.empty_cache()
        if choose_torch_device() == torch.device("mps"):
            mps.empty_cache()

    def _local_model_hash(self, model_path: Union[str, Path]) -> str:
        """Return the sha256 hex digest of the model weights at `model_path`.

        For a directory, every ``*.ckpt``, ``*.safetensors`` and ``*.pth`` file beneath it is
        hashed in sorted path order (so the digest is deterministic), and the digest is cached
        in ``checksum.sha256`` inside the directory; the cached value is reused while its mtime
        is not older than the directory's. A single file is hashed directly and no checksum
        file is written, since there is no directory to put it in.
        """
        path = Path(model_path)

        if path.is_file():
            self.logger.debug(f"computing hash of model {path.name}")
            return self._sha256_of_files([path])

        hashpath = path / "checksum.sha256"
        if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
            with open(hashpath) as f:
                return f.read().strip()

        self.logger.debug(f"computing hash of model {path.name}")
        weight_files = sorted(
            file for pattern in ("*.ckpt", "*.safetensors", "*.pth") for file in path.rglob(pattern)
        )
        digest = self._sha256_of_files(weight_files)
        with open(hashpath, "w") as f:
            f.write(digest)
        return digest

    def _sha256_of_files(self, files) -> str:
        """Hash the contents of `files`, in the given order, reading `sha_chunksize` bytes at a time."""
        sha = hashlib.sha256()
        for file in files:
            with open(file, "rb") as f:
                while chunk := f.read(self.sha_chunksize):
                    sha.update(chunk)
        return sha.hexdigest()
2023-07-28 13:46:44 +00:00
class MemorySnapshot:
    """Point-in-time record of process RAM and torch VRAM usage, both in bytes."""

    def __init__(self, process_ram: int, vram: Optional[int]):
        """Store the two measurements.

        Normally built through `MemorySnapshot.capture()` rather than directly.

        Args:
            process_ram (int): Resident CPU memory of the current process.
            vram (Optional[int]): Memory allocated by torch on CUDA, or None when it was not measured.
        """
        self.process_ram, self.vram = process_ram, vram

    @classmethod
    def capture(cls, run_garbage_collector: bool = True):
        """Measure current memory usage and return it as a MemorySnapshot.

        This is relatively expensive, especially when `run_garbage_collector` is True.

        Args:
            run_garbage_collector (bool, optional): Run gc.collect() before reading the process RAM
                usage. Defaults to True.

        Returns:
            MemorySnapshot
        """
        if run_garbage_collector:
            gc.collect()

        # psutil documents `rss` as available on every platform
        # (https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info).
        rss = psutil.Process().memory_info().rss

        # Only CUDA allocations are tracked. mps.current_allocated_memory() could be supported too,
        # but that has not been tested yet.
        vram = torch.cuda.memory_allocated() if choose_torch_device() == torch.device("cuda") else None

        return cls(rss, vram)
def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnapshot) -> str:
    """Describe the change from `snapshot_1` to `snapshot_2` as a human-readable string.

    RAM is always reported; VRAM is appended only when both snapshots measured it.
    """
    ram_before, ram_after = snapshot_1.process_ram, snapshot_2.process_ram
    parts = [f"RAM ({((ram_after - ram_before)/GIG):+.2f}): {(ram_before/GIG):.2f}GB -> {(ram_after/GIG):.2f}GB"]

    vram_before, vram_after = snapshot_1.vram, snapshot_2.vram
    if vram_before is not None and vram_after is not None:
        parts.append(
            f"VRAM ({((vram_after - vram_before)/GIG):+.2f}): {(vram_before/GIG):.2f}GB -> {(vram_after/GIG):.2f}GB"
        )

    return ", ".join(parts)