InvokeAI/invokeai/backend/model_management/model_cache.py


"""
Manage a RAM cache of diffusion/transformer models for fast switching.
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.
The cache returns context manager generators designed to load the
model into the GPU within the context, and unload outside the
context. Use like this:
cache = ModelCache(max_models_cached=6)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
do_something_in_GPU(SD1,SD2)
"""

import gc
import os
import sys
import types
import hashlib
from contextlib import suppress
from pathlib import Path
from typing import Dict, Union, Optional, Type, Any

import torch

import logging
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config

from .lora import LoRAModel, TextualInversionModel
from .models import BaseModelType, ModelType, SubModelType, ModelBase

# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0
# actual size of a gig
GIG = 1073741824

class ModelLocker(object):
    "Forward declaration"
    pass

class ModelCache(object):
    "Forward declaration"
    pass

class SDModelInfo(object):
    """Forward declaration"""
    pass

class _CacheRecord:
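    """Bookkeeping record for a single cached model: its size in bytes, the
    model object itself, a back-reference to the owning ModelCache, and a
    lock count used to keep the model pinned on the execution device while
    it is in use."""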
    size: int
    model: Any
    cache: ModelCache
    _locks: int

    def __init__(self, cache, model: Any, size: int):
        self.size = size
        self.model = model
        self.cache = cache
        self._locks = 0

    def lock(self):
        self._locks += 1

    def unlock(self):
        self._locks -= 1
        assert self._locks >= 0

    @property
    def locked(self):
        return self._locks > 0

    @property
    def loaded(self):
        if self.model is not None and hasattr(self.model, "device"):
            return self.model.device != self.cache.storage_device
        else:
            return False

class ModelCache(object):
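    """RAM/VRAM cache of loaded models with least-recently-used eviction.

    get_model() returns a ModelLocker context manager; entering the context
    moves the model onto the execution device (if it supports .to()) and
    exiting releases the lock so the model may later be offloaded or evicted.
    """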
    def __init__(
        self,
        max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
        execution_device: torch.device = torch.device('cuda'),
        storage_device: torch.device = torch.device('cpu'),
        precision: torch.dtype = torch.float16,
        sequential_offload: bool = False,
        lazy_offloading: bool = True,
        sha_chunksize: int = 16777216,
        logger: types.ModuleType = logger,
    ):
        '''
        :param max_cache_size: Maximum size of the RAM cache, in GB [6.0]
        :param execution_device: Torch device to load active model into [torch.device('cuda')]
        :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
        :param precision: Precision for loaded models [torch.float16]
        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
        :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
        :param sha_chunksize: Chunksize to use when calculating sha256 model hash
        '''
        #max_cache_size = 9999
        execution_device = torch.device('cuda')

        self.model_infos: Dict[str, SDModelInfo] = dict()
        self.lazy_offloading = lazy_offloading
        #self.sequential_offload: bool=sequential_offload
        self.precision: torch.dtype = precision
        self.max_cache_size: float = max_cache_size
        self.execution_device: torch.device = execution_device
        self.storage_device: torch.device = storage_device
        self.sha_chunksize = sha_chunksize
        self.logger = logger

        self._cached_models = dict()
        self._cache_stack = list()

    def get_key(
        self,
        model_path: str,
        base_model: BaseModelType,
        model_type: ModelType,
        submodel_type: Optional[SubModelType] = None,
    ):
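        """Compose the cache key for a model from its path, base model and
        model type, plus the submodel type when one is requested."""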
        key = f"{model_path}:{base_model}:{model_type}"
        if submodel_type:
            key += f":{submodel_type}"
        return key

    #def get_model(
    #    self,
    #    repo_id_or_path: Union[str, Path],
    #    model_type: ModelType = ModelType.Diffusers,
    #    subfolder: Path = None,
    #    submodel: ModelType = None,
    #    revision: str = None,
    #    attach_model_part: Tuple[ModelType, str] = (None, None),
    #    gpu_load: bool = True,
    #) -> ModelLocker:  # ?? what does it return

    def _get_model_info(
        self,
        model_path: str,
        model_class: Type[ModelBase],
        base_model: BaseModelType,
        model_type: ModelType,
    ):
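        """Return the model info object describing this model, constructing
        it via model_class and memoizing it in self.model_infos on first use."""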
        model_info_key = self.get_key(
            model_path=model_path,
            base_model=base_model,
            model_type=model_type,
            submodel_type=None,
        )

        if model_info_key not in self.model_infos:
            self.model_infos[model_info_key] = model_class(
                model_path,
                base_model,
                model_type,
            )

        return self.model_infos[model_info_key]

    # TODO: args
    def get_model(
        self,
        model_path: Union[str, Path],
        model_class: Type[ModelBase],
        base_model: BaseModelType,
        model_type: ModelType,
        submodel: Optional[SubModelType] = None,
        gpu_load: bool = True,
    ) -> Any:
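        """Load the requested model (or fetch it from the cache) and return a
        ModelLocker context manager for it. Entering the context moves the
        model to the execution device when gpu_load is True; exiting releases
        the lock so the model can later be offloaded or evicted."""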
        if not isinstance(model_path, Path):
            model_path = Path(model_path)

        if not os.path.exists(model_path):
            raise Exception(f"Model not found: {model_path}")

        model_info = self._get_model_info(
            model_path=model_path,
            model_class=model_class,
            base_model=base_model,
            model_type=model_type,
        )
        key = self.get_key(
            model_path=model_path,
            base_model=base_model,
            model_type=model_type,
            submodel_type=submodel,
        )

        # TODO: lock for no copies on simultaneous calls?
        cache_entry = self._cached_models.get(key, None)
        if cache_entry is None:
            self.logger.info(f'Loading model {model_path}, type {base_model}:{model_type}:{submodel}')

            # this will remove older cached models until
            # there is sufficient room to load the requested model
            self._make_cache_room(model_info.get_size(submodel))

            # clean memory to make MemoryUsage() more accurate
            gc.collect()
            model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
            if mem_used := model_info.get_size(submodel):
                self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')

            cache_entry = _CacheRecord(self, model, mem_used)
            self._cached_models[key] = cache_entry

        with suppress(Exception):
            self._cache_stack.remove(key)
        self._cache_stack.append(key)

        return self.ModelLocker(self, key, cache_entry.model, gpu_load)

    class ModelLocker(object):
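        """Context manager returned by ModelCache.get_model(). On __enter__ it
        optionally moves the wrapped model onto the execution device and locks
        it there; on __exit__ it releases the lock and, when lazy offloading is
        disabled, immediately offloads unlocked models back to storage."""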
        def __init__(self, cache, key, model, gpu_load):
            self.gpu_load = gpu_load
            self.cache = cache
            self.key = key
            self.model = model

        def __enter__(self) -> Any:
            if not hasattr(self.model, 'to'):
                return self.model

            cache_entry = self.cache._cached_models[self.key]

            # NOTE that the model has to have the to() method in order for this
            # code to move it into GPU!
            if self.gpu_load:
                cache_entry.lock()

                try:
                    if self.cache.lazy_offloading:
                        self.cache._offload_unlocked_models()

                    if self.model.device != self.cache.execution_device:
                        self.cache.logger.debug(f'Moving {self.key} into {self.cache.execution_device}')
                        with VRAMUsage() as mem:
                            self.model.to(self.cache.execution_device)  # move into GPU
                        self.cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB')

                    self.cache.logger.debug(f'Locking {self.key} in {self.cache.execution_device}')
                    self.cache._print_cuda_stats()

                except:
                    cache_entry.unlock()
                    raise

            # TODO: not fully understand
            # in the event that the caller wants the model in RAM, we
            # move it into CPU if it is in GPU and not locked
            elif cache_entry.loaded and not cache_entry.locked:
                self.model.to(self.cache.storage_device)

            return self.model

        def __exit__(self, type, value, traceback):
            if not hasattr(self.model, 'to'):
                return

            cache_entry = self.cache._cached_models[self.key]
            cache_entry.unlock()

            if not self.cache.lazy_offloading:
                self.cache._offload_unlocked_models()
                self.cache._print_cuda_stats()

    def model_hash(
        self,
        model_path: Union[str, Path],
    ) -> str:
        '''
        Given the path to a model on disk, return a unique hash computed
        from its weight files (legacy checkpoints, safetensors, pth).

        :param model_path: Path to model file/directory on disk.
        '''
        return self._local_model_hash(model_path)

    def cache_size(self) -> float:
        "Return the current size of the cache, in GB"
        current_cache_size = sum([m.size for m in self._cached_models.values()])
        return current_cache_size / GIG

    def _has_cuda(self) -> bool:
        return self.execution_device.type == 'cuda'

    def _print_cuda_stats(self):
        vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
        ram = "%4.2fG" % self.cache_size()

        cached_models = 0
        loaded_models = 0
        locked_models = 0
        for model_info in self._cached_models.values():
            cached_models += 1
            if model_info.loaded:
                loaded_models += 1
            if model_info.locked:
                locked_models += 1

        self.logger.debug(f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models = {cached_models}/{loaded_models}/{locked_models}")

    def _make_cache_room(self, model_size):
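        """Evict least-recently-used, unlocked and otherwise unreferenced
        models from the cache until there is room for a model of model_size
        bytes (or nothing evictable remains), then run the garbage collector
        and empty the CUDA cache."""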
        # calculate how much memory this model will require
        #multiplier = 2 if self.precision==torch.float32 else 1
        bytes_needed = model_size
        maximum_size = self.max_cache_size * GIG  # stored in GB, convert to bytes
        current_size = sum([m.size for m in self._cached_models.values()])

        if current_size + bytes_needed > maximum_size:
            self.logger.debug(f'Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB')

        self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")

        pos = 0
        while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
            model_key = self._cache_stack[pos]
            cache_entry = self._cached_models[model_key]

            refs = sys.getrefcount(cache_entry.model)

            device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
            self.logger.debug(f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}")

            # 2 refs:
            # 1 from cache_entry
            # 1 from getrefcount function
            if not cache_entry.locked and refs <= 2:
                self.logger.debug(f'Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)')
                current_size -= cache_entry.size
                del self._cache_stack[pos]
                del self._cached_models[model_key]
                del cache_entry

            else:
                pos += 1

        gc.collect()
        torch.cuda.empty_cache()

        self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")

    def _offload_unlocked_models(self):
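        """Move every cached model that is loaded on the execution device but
        not currently locked back to the storage device."""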
        for model_key, cache_entry in self._cached_models.items():
            if not cache_entry.locked and cache_entry.loaded:
                self.logger.debug(f'Offloading {model_key} from {self.execution_device} into {self.storage_device}')
                cache_entry.model.to(self.storage_device)

    def _local_model_hash(self, model_path: Union[str, Path]) -> str:
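        """Compute a SHA-256 hash over all .ckpt/.safetensors/.pth files under
        model_path, caching the result in a checksum.sha256 file that is
        reused as long as the model directory has not been modified since."""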
        sha = hashlib.sha256()
        path = Path(model_path)

        hashpath = path / "checksum.sha256"
        if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
            with open(hashpath) as f:
                hash = f.read()
            return hash

        self.logger.debug(f'computing hash of model {path.name}')
        for file in list(path.rglob("*.ckpt")) \
                + list(path.rglob("*.safetensors")) \
                + list(path.rglob("*.pth")):
            with open(file, "rb") as f:
                while chunk := f.read(self.sha_chunksize):
                    sha.update(chunk)
        hash = sha.hexdigest()
        with open(hashpath, "w") as f:
            f.write(hash)
        return hash

class VRAMUsage(object):
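    """Context manager that records the increase in allocated CUDA memory
    between entry and exit in its vram_used attribute (in bytes)."""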
    def __init__(self):
        self.vram = None
        self.vram_used = 0

    def __enter__(self):
        self.vram = torch.cuda.memory_allocated()
        return self

    def __exit__(self, *args):
        self.vram_used = torch.cuda.memory_allocated() - self.vram