# InvokeAI/invokeai/backend/model_management/model_cache.py
"""
Manage a RAM cache of diffusion/transformer models for fast switching.
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.
The cache returns context manager generators designed to load the
model into the GPU within the context, and unload outside the
context. Use like this:
cache = ModelCache(max_models_cached=6)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
do_something_in_GPU(SD1,SD2)
"""
import contextlib
import weakref
import gc
import os
import sys
import hashlib
import warnings
from contextlib import suppress
from enum import Enum
from pathlib import Path
import json
import types
from typing import Dict, Sequence, Union, Optional, List, Type, Any
import torch
from diffusers import DiffusionPipeline, SchedulerMixin, ConfigMixin
from diffusers import logging as diffusers_logging
from huggingface_hub import HfApi, scan_cache_dir
import transformers
from transformers import logging as transformers_logging
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
def get_model_path(repo_id_or_path: str):
globals = get_invokeai_config()
if os.path.exists(repo_id_or_path):
return repo_id_or_path
cache = scan_cache_dir(globals.cache_dir)
for repo in cache.repos:
if repo.repo_id != repo_id_or_path:
continue
for rev in repo.revisions:
if "main" in rev.refs:
return rev.snapshot_path
raise Exception(f"{repo_id_or_path} - not found")
def calc_model_size_by_fs(
repo_id_or_path: str,
subfolder: Optional[str] = None,
variant: Optional[str] = None
):
model_path = get_model_path(repo_id_or_path)
if subfolder is not None:
model_path = os.path.join(model_path, subfolder)
all_files = os.listdir(model_path)
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f])
bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f])
other_files = set(all_files) - fp16_files - bit8_files
if variant is None:
files = other_files
elif variant == "fp16":
files = fp16_files
elif variant == "8bit":
files = bit8_files
else:
raise NotImplementedError(f"Unknown variant: {variant}")
# try read from index if exists
index_postfix = ".index.json"
if variant is not None:
index_postfix = f".index.{variant}.json"
for file in files:
if not file.endswith(index_postfix):
continue
try:
            with open(os.path.join(model_path, file), "r") as f:
index_data = json.loads(f.read())
return int(index_data["metadata"]["total_size"])
except:
pass
# calculate files size if there is no index file
formats = [
(".safetensors",), # safetensors
(".bin",), # torch
(".onnx", ".pb"), # onnx
(".msgpack",), # flax
(".ckpt",), # tf
(".h5",), # tf2
]
for file_format in formats:
model_files = [f for f in files if f.endswith(file_format)]
if len(model_files) == 0:
continue
model_size = 0
for model_file in model_files:
file_stats = os.stat(os.path.join(model_path, model_file))
model_size += file_stats.st_size
return model_size
#raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
    return 0  # scheduler/feature_extractor/tokenizer - components with no weights to load onto the GPU
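# For reference: a sharded-weights index file (e.g. "diffusion_pytorch_model.safetensors.index.json")
# records the total size under "metadata", which is what the shortcut above reads before
# falling back to summing individual file sizes:
#   { "metadata": { "total_size": <bytes> }, "weight_map": { ... } }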
def calc_model_size_by_data(model) -> int:
if isinstance(model, DiffusionPipeline):
return _calc_pipeline_by_data(model)
elif isinstance(model, torch.nn.Module):
return _calc_model_by_data(model)
else:
return 0
def _calc_pipeline_by_data(pipeline) -> int:
res = 0
for submodel_key in pipeline.components.keys():
submodel = getattr(pipeline, submodel_key)
if submodel is not None and isinstance(submodel, torch.nn.Module):
res += _calc_model_by_data(submodel)
return res
def _calc_model_by_data(model) -> int:
mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
mem = mem_params + mem_bufs # in bytes
return mem
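# Worked example (sketch) of the formula above: a torch.nn.Linear(1024, 1024) holds a
# 1024x1024 weight plus a 1024-element bias; in fp16 (2 bytes per element) that is
# (1024*1024 + 1024) * 2 bytes, roughly 2.0 MiB, which is what _calc_model_by_data returns:
#
#   layer = torch.nn.Linear(1024, 1024).to(torch.float16)
#   assert _calc_model_by_data(layer) == (1024 * 1024 + 1024) * 2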
class SDModelType(str, Enum):
Diffusers = "diffusers"
Classifier = "classifier"
UNet = "unet"
TextEncoder = "text_encoder"
Tokenizer = "tokenizer"
Vae = "vae"
    Scheduler = "scheduler"
    FeatureExtractor = "feature_extractor"
class ModelInfoBase:
#model_path: str
#model_type: SDModelType
def __init__(self, repo_id_or_path: str, model_type: SDModelType):
        self.repo_id_or_path = repo_id_or_path # TODO: or always use the resolved path?
self.model_path = get_model_path(repo_id_or_path)
self.model_type = model_type
def _definition_to_type(self, subtypes: List[str]) -> Type:
if len(subtypes) < 2:
raise Exception("Invalid subfolder definition!")
if subtypes[0] in ["diffusers", "transformers"]:
res_type = sys.modules[subtypes[0]]
subtypes = subtypes[1:]
else:
res_type = sys.modules["diffusers"]
res_type = getattr(res_type, "pipelines")
for subtype in subtypes:
res_type = getattr(res_type, subtype)
return res_type
class DiffusersModelInfo(ModelInfoBase):
#child_types: Dict[str, Type]
#child_sizes: Dict[str, int]
def __init__(self, repo_id_or_path: str, model_type: SDModelType):
assert model_type == SDModelType.Diffusers
super().__init__(repo_id_or_path, model_type)
self.child_types: Dict[str, Type] = dict()
self.child_sizes: Dict[str, int] = dict()
try:
config_data = DiffusionPipeline.load_config(repo_id_or_path)
#config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
except:
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
config_data.pop("_ignore_files", None)
# retrieve all folder_names that contain relevant files
child_components = [k for k, v in config_data.items() if isinstance(v, list)]
for child_name in child_components:
child_type = self._definition_to_type(config_data[child_name])
self.child_types[child_name] = child_type
self.child_sizes[child_name] = calc_model_size_by_fs(repo_id_or_path, subfolder=child_name)
def get_size(self, child_type: Optional[SDModelType] = None):
if child_type is None:
return sum(self.child_sizes.values())
else:
return self.child_sizes[child_type]
def get_model(
self,
child_type: Optional[SDModelType] = None,
torch_dtype: Optional[torch.dtype] = None,
):
# return pipeline in different function to pass more arguments
if child_type is None:
raise Exception("Child model type can't be null on diffusers model")
if child_type not in self.child_types:
return None # TODO: or raise
# TODO:
for variant in ["fp16", "main", None]:
try:
model = self.child_types[child_type].from_pretrained(
self.repo_id_or_path,
subfolder=child_type.value,
cache_dir=get_invokeai_config().cache_dir,
torch_dtype=torch_dtype,
variant=variant,
)
                break
            except Exception as e:
                print("====ERR LOAD====")
                print(f"{variant}: {e}")
        else:
            raise Exception(f"Failed to load {self.repo_id_or_path}:{child_type.value} with any known variant")
        # calc a more accurate size now that the weights are loaded
        self.child_sizes[child_type] = calc_model_size_by_data(model)
        return model
def get_pipeline(self, **kwargs):
return DiffusionPipeline.from_pretrained(
self.repo_id_or_path,
**kwargs,
)
class EmptyConfigLoader(ConfigMixin):
@classmethod
def load_config(cls, *args, **kwargs):
cls.config_name = kwargs.pop("config_name")
return super().load_config(*args, **kwargs)
class ClassifierModelInfo(ModelInfoBase):
#child_types: Dict[str, Type]
#child_sizes: Dict[str, int]
def __init__(self, repo_id_or_path: str, model_type: SDModelType):
assert model_type == SDModelType.Classifier
super().__init__(repo_id_or_path, model_type)
self.child_types: Dict[str, Type] = dict()
self.child_sizes: Dict[str, int] = dict()
try:
main_config = EmptyConfigLoader.load_config(repo_id_or_path, config_name="config.json")
#main_config = json.loads(os.path.join(self.model_path, "config.json"))
except:
raise Exception("Invalid classifier model! (config.json not found or invalid)")
self._load_tokenizer(main_config)
self._load_text_encoder(main_config)
self._load_feature_extractor(main_config)
def _load_tokenizer(self, main_config: dict):
try:
            tokenizer_config = EmptyConfigLoader.load_config(self.repo_id_or_path, config_name="tokenizer_config.json")
#tokenizer_config = json.loads(os.path.join(self.model_path, "tokenizer_config.json"))
except:
raise Exception("Invalid classifier model! (Failed to load tokenizer_config.json)")
if "tokenizer_class" in tokenizer_config:
tokenizer_class_name = tokenizer_config["tokenizer_class"]
elif "model_type" in main_config:
tokenizer_class_name = transformers.models.auto.tokenization_auto.TOKENIZER_MAPPING_NAMES[main_config["model_type"]]
else:
raise Exception("Invalid classifier model! (Failed to detect tokenizer type)")
self.child_types[SDModelType.Tokenizer] = self._definition_to_type(["transformers", tokenizer_class_name])
self.child_sizes[SDModelType.Tokenizer] = 0
def _load_text_encoder(self, main_config: dict):
if "architectures" in main_config and len(main_config["architectures"]) > 0:
text_encoder_class_name = main_config["architectures"][0]
elif "model_type" in main_config:
text_encoder_class_name = transformers.models.auto.modeling_auto.MODEL_FOR_PRETRAINING_MAPPING_NAMES[main_config["model_type"]]
else:
raise Exception("Invalid classifier model! (Failed to detect text_encoder type)")
self.child_types[SDModelType.TextEncoder] = self._definition_to_type(["transformers", text_encoder_class_name])
        self.child_sizes[SDModelType.TextEncoder] = calc_model_size_by_fs(self.repo_id_or_path)
def _load_feature_extractor(self, main_config: dict):
self.child_sizes[SDModelType.FeatureExtractor] = 0
try:
            feature_extractor_config = EmptyConfigLoader.load_config(self.repo_id_or_path, config_name="preprocessor_config.json")
except:
return # feature extractor not passed with t5
try:
feature_extractor_class_name = feature_extractor_config["feature_extractor_type"]
self.child_types[SDModelType.FeatureExtractor] = self._definition_to_type(["transformers", feature_extractor_class_name])
except:
raise Exception("Invalid classifier model! (Unknown feature_extrator type)")
def get_size(self, child_type: Optional[SDModelType] = None):
if child_type is None:
return sum(self.child_sizes.values())
else:
return self.child_sizes[child_type]
def get_model(
self,
child_type: Optional[SDModelType] = None,
torch_dtype: Optional[torch.dtype] = None,
):
if child_type is None:
raise Exception("Child model type can't be null on classififer model")
if child_type not in self.child_types:
return None # TODO: or raise
model = self.child_types[child_type].from_pretrained(
self.repo_id_or_path,
subfolder=child_type.value,
cache_dir=get_invokeai_config().cache_dir,
torch_dtype=torch_dtype,
)
# calc more accurate size
self.child_sizes[child_type] = calc_model_size_by_data(model)
return model
class VaeModelInfo(ModelInfoBase):
#vae_class: Type
#model_size: int
def __init__(self, repo_id_or_path: str, model_type: SDModelType):
assert model_type == SDModelType.Vae
super().__init__(repo_id_or_path, model_type)
try:
config = EmptyConfigLoader.load_config(repo_id_or_path, config_name="config.json")
#config = json.loads(os.path.join(self.model_path, "config.json"))
except:
raise Exception("Invalid vae model! (config.json not found or invalid)")
try:
vae_class_name = config.get("_class_name", "AutoencoderKL")
self.vae_class = self._definition_to_type(["diffusers", vae_class_name])
self.model_size = calc_model_size_by_fs(repo_id_or_path)
except:
raise Exception("Invalid vae model! (Unkown vae type)")
def get_size(self, child_type: Optional[SDModelType] = None):
if child_type is not None:
raise Exception("There is no child models in vae model")
return self.model_size
def get_model(
self,
child_type: Optional[SDModelType] = None,
torch_dtype: Optional[torch.dtype] = None,
):
if child_type is not None:
raise Exception("There is no child models in vae model")
model = self.vae_type.from_pretrained(
self.repo_id_or_path,
cache_dir=get_invokeai_config().cache_dir,
torch_dtype=torch_dtype,
)
# calc more accurate size
self.model_size = calc_model_size_by_data(model)
return model
MODEL_TYPES = {
SDModelType.Diffusers: DiffusersModelInfo,
SDModelType.Classifier: ClassifierModelInfo,
SDModelType.Vae: VaeModelInfo,
}
# Maximum size of the cache, in gigabytes.
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously.
DEFAULT_MAX_CACHE_SIZE = 6.0
# Number of bytes in a gigabyte (2**30)
GIG = 1073741824
# TODO:
class EmptyScheduler(SchedulerMixin, ConfigMixin):
pass
class ModelLocker(object):
"Forward declaration"
pass
class ModelCache(object):
"Forward declaration"
pass
class _CacheRecord:
key: str
size: int
cache: ModelCache
_locks: int
def __init__(self, cache, key: Any, size: int):
self.key = key
self.size = size
self.cache = cache
self._locks = 0
def lock(self):
self._locks += 1
def unlock(self):
self._locks -= 1
assert self._locks >= 0
@property
def locked(self):
return self._locks > 0
@property
def loaded(self):
model = self.cache._cached_models.get(self.key, None)
if model is not None and hasattr(model, "device"):
return model.device != self.cache.storage_device
else:
return False
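# Note on _CacheRecord semantics (descriptive, based on the code above): `loaded` is
# inferred purely from device placement, so an entry whose model has no `device`
# attribute (e.g. a tokenizer) always reports loaded == False. lock()/unlock() count
# active GPU users so that locked models are neither offloaded nor evicted.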
class ModelCache(object):
def __init__(
self,
max_cache_size: float=DEFAULT_MAX_CACHE_SIZE,
execution_device: torch.device=torch.device('cuda'),
storage_device: torch.device=torch.device('cpu'),
precision: torch.dtype=torch.float16,
sequential_offload: bool=False,
lazy_offloading: bool=True,
sha_chunksize: int = 16777216,
logger: types.ModuleType = logger
):
'''
        :param max_cache_size: Maximum size of the CPU RAM cache, in gigabytes [6.0]
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
'''
self.model_infos: Dict[str, ModelInfoBase] = dict()
self.lazy_offloading = lazy_offloading
#self.sequential_offload: bool=sequential_offload
self.precision: torch.dtype=precision
self.max_cache_size: int=max_cache_size
self.execution_device: torch.device=execution_device
self.storage_device: torch.device=storage_device
self.sha_chunksize=sha_chunksize
self.logger = logger
self._cached_models = weakref.WeakValueDictionary()
self._cached_infos = weakref.WeakKeyDictionary()
self._cache_stack = list()
def get_key(
self,
model_path: str,
model_type: SDModelType,
revision: Optional[str] = None,
submodel_type: Optional[SDModelType] = None,
):
revision = revision or "main"
key = f"{model_path}:{model_type}:{revision}"
if submodel_type:
key += f":{submodel_type}"
return key
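    # Key layout (sketch): '<model_path>:<model_type>:<revision>' plus an optional
    # ':<submodel_type>' suffix, so the same repo cached at two revisions, or the UNet
    # vs. VAE of one pipeline, occupy distinct cache slots. The enum members are rendered
    # by the f-string, so their exact textual form depends on the Python version.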
#def get_model(
# self,
# repo_id_or_path: Union[str, Path],
# model_type: SDModelType = SDModelType.Diffusers,
# subfolder: Path = None,
# submodel: SDModelType = None,
# revision: str = None,
# attach_model_part: Tuple[SDModelType, str] = (None, None),
# gpu_load: bool = True,
#) -> ModelLocker: # ?? what does it return
def _get_model_info(
self,
model_path: str,
model_type: SDModelType,
revision: str,
):
model_info_key = self.get_key(
model_path=model_path,
model_type=model_type,
revision=revision,
submodel_type=None,
)
if model_info_key not in self.model_infos:
if model_type not in MODEL_TYPES:
raise Exception(f"Unknown/unsupported model type: {model_type}")
self.model_infos[model_info_key] = MODEL_TYPES[model_type](
model_path,
model_type,
)
return self.model_infos[model_info_key]
def get_model(
self,
repo_id_or_path: Union[str, Path],
model_type: SDModelType = SDModelType.Diffusers,
submodel: Optional[SDModelType] = None,
revision: Optional[str] = None,
variant: Optional[str] = None,
gpu_load: bool = True,
) -> Any:
model_path = get_model_path(repo_id_or_path)
model_info = self._get_model_info(
model_path=model_path,
model_type=model_type,
revision=revision,
)
key = self.get_key(
model_path=model_path,
model_type=model_type,
revision=revision,
submodel_type=submodel,
)
# TODO: lock for no copies on simultaneous calls?
model = self._cached_models.get(key, None)
if model is None:
self.logger.info(f'Loading model {repo_id_or_path}, type {model_type}:{submodel}')
# this will remove older cached models until
# there is sufficient room to load the requested model
self._make_cache_room(model_info.get_size(submodel))
# clean memory to make MemoryUsage() more accurate
gc.collect()
model = model_info.get_model(submodel, torch_dtype=self.precision)
if mem_used := model_info.get_size(submodel):
self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
self._cached_models[key] = model
self._cached_infos[model] = _CacheRecord(self, key, mem_used)
with suppress(Exception):
self._cache_stack.remove(model)
self._cache_stack.append(model)
return self.ModelLocker(self, key, model, gpu_load)
class ModelLocker(object):
def __init__(self, cache, key, model, gpu_load):
self.gpu_load = gpu_load
self.cache = cache
self.key = key
self.model = model
def __enter__(self) -> Any:
if not hasattr(self.model, 'to'):
return self.model
cache_entry = self.cache._cached_infos[self.model]
# NOTE that the model has to have the to() method in order for this
# code to move it into GPU!
if self.gpu_load:
cache_entry.lock()
try:
if self.cache.lazy_offloading:
self.cache._offload_unlocked_models()
if self.model.device != self.cache.execution_device:
self.cache.logger.debug(f'Moving {self.key} into {self.cache.execution_device}')
with VRAMUsage() as mem:
self.model.to(self.cache.execution_device) # move into GPU
self.cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB')
self.cache.logger.debug(f'Locking {self.key} in {self.cache.execution_device}')
self.cache._print_cuda_stats()
except:
cache_entry.unlock()
raise
            # TODO: the intent of this branch is not fully understood
# in the event that the caller wants the model in RAM, we
# move it into CPU if it is in GPU and not locked
elif cache_entry.loaded and not cache_entry.locked:
self.model.to(self.cache.storage_device)
return self.model
def __exit__(self, type, value, traceback):
if not hasattr(self.model, 'to'):
return
cache_entry = self.cache._cached_infos[self.model]
cache_entry.unlock()
if not self.cache.lazy_offloading:
self.cache._offload_unlocked_models()
self.cache._print_cuda_stats()
def model_hash(
self,
repo_id_or_path: Union[str, Path],
revision: str = "main",
) -> str:
'''
Given the HF repo id or path to a model on disk, returns a unique
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
:param repo_id_or_path: repo_id string or Path to model file/directory on disk.
:param revision: optional revision string (if fetching a HF repo_id)
'''
revision = revision or "main"
if Path(repo_id_or_path).is_dir():
return self._local_model_hash(repo_id_or_path)
else:
return self._hf_commit_hash(repo_id_or_path,revision)
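    # Example (sketch, hypothetical paths/ids):
    #   cache.model_hash('/models/stable-diffusion-1.5')     # sha256 digest over local weight files
    #   cache.model_hash('runwayml/stable-diffusion-v1-5')   # commit hash of the HF 'main' branch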
def cache_size(self) -> float:
"Return the current size of the cache, in GB"
current_cache_size = sum([m.size for m in self._cached_infos.values()])
return current_cache_size / GIG
def _has_cuda(self) -> bool:
return self.execution_device.type == 'cuda'
def _print_cuda_stats(self):
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
ram = "%4.2fG" % self.cache_size()
cached_models = 0
loaded_models = 0
locked_models = 0
for model_info in self._cached_infos.values():
cached_models += 1
if model_info.loaded:
loaded_models += 1
if model_info.locked:
locked_models += 1
self.logger.debug(f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}")
def _make_cache_room(self, model_size):
# calculate how much memory this model will require
#multiplier = 2 if self.precision==torch.float32 else 1
bytes_needed = model_size
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
current_size = sum([m.size for m in self._cached_infos.values()])
if current_size + bytes_needed > maximum_size:
self.logger.debug(f'Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB')
self.logger.debug(f"Before unloading: cached_models={len(self._cached_infos)}")
pos = 0
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
model = self._cache_stack[pos]
model_info = self._cached_infos[model]
refs = sys.getrefcount(model)
device = model.device if hasattr(model, "device") else None
self.logger.debug(f"Model: {model_info.key}, locks: {model_info._locks}, device: {device}, loaded: {model_info.loaded}, refs: {refs}")
# 3 refs = 1 from _cache_stack, 1 from local variable, 1 from getrefcount function
if not model_info.locked and refs <= 3:
                self.logger.debug(f'Unloading model {model_info.key} to make room for {(model_size/GIG):.2f} GB (frees {(model_info.size/GIG):.2f} GB)')
current_size -= model_info.size
del self._cache_stack[pos]
del model
del model_info
else:
pos += 1
gc.collect()
torch.cuda.empty_cache()
self.logger.debug(f"After unloading: cached_models={len(self._cached_infos)}")
def _offload_unlocked_models(self):
for model, model_info in self._cached_infos.items():
if not model_info.locked and model_info.loaded:
self.logger.debug(f'Offloading {model_info.key} from {self.execution_device} into {self.storage_device}')
model.to(self.storage_device)
def _local_model_hash(self, model_path: Union[str, Path]) -> str:
sha = hashlib.sha256()
path = Path(model_path)
hashpath = path / "checksum.sha256"
if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
with open(hashpath) as f:
hash = f.read()
return hash
self.logger.debug(f'computing hash of model {path.name}')
for file in list(path.rglob("*.ckpt")) \
+ list(path.rglob("*.safetensors")) \
+ list(path.rglob("*.pth")):
with open(file, "rb") as f:
while chunk := f.read(self.sha_chunksize):
sha.update(chunk)
hash = sha.hexdigest()
with open(hashpath, "w") as f:
f.write(hash)
return hash
def _hf_commit_hash(self, repo_id: str, revision: str='main') -> str:
api = HfApi()
info = api.list_repo_refs(
repo_id=repo_id,
repo_type='model',
)
desired_revisions = [branch for branch in info.branches if branch.name==revision]
if not desired_revisions:
raise KeyError(f"Revision '{revision}' not found in {repo_id}")
return desired_revisions[0].target_commit
class SilenceWarnings(object):
def __init__(self):
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
def __enter__(self):
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter('ignore')
def __exit__(self,type,value,traceback):
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter('default')
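# Example (sketch): silence load-time warnings from transformers/diffusers around a
# from_pretrained call ('some/repo' below is a hypothetical repo id):
#   with SilenceWarnings():
#       pipe = DiffusionPipeline.from_pretrained('some/repo')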
class VRAMUsage(object):
def __init__(self):
self.vram = None
self.vram_used = 0
def __enter__(self):
self.vram = torch.cuda.memory_allocated()
return self
def __exit__(self, *args):
self.vram_used = torch.cuda.memory_allocated() - self.vram
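# Example (sketch): measure the incremental VRAM cost of a transfer, as ModelLocker does above.
#   with VRAMUsage() as mem:
#       model.to(torch.device('cuda'))  # `model` is a hypothetical module with a .to() method
#   print(f"VRAM used for load: {mem.vram_used/GIG:.2f} GB")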