2023-04-28 04:41:52 +00:00
|
|
|
"""
|
2023-05-03 16:38:18 +00:00
|
|
|
Manage a RAM cache of diffusion/transformer models for fast switching.
|
|
|
|
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
|
2023-04-28 04:41:52 +00:00
|
|
|
grows larger than a preset maximum, then the least recently used
|
|
|
|
model will be cleared and (re)loaded from disk when next needed.
|
2023-05-03 16:38:18 +00:00
|
|
|
|
|
|
|
The cache returns context manager generators designed to load the
|
|
|
|
model into the GPU within the context, and unload outside the
|
|
|
|
context. Use like this:
|
|
|
|
|
|
|
|
cache = ModelCache(max_models=6)
|
|
|
|
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
|
|
|
|
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
|
|
|
|
do_something_in_GPU(SD1,SD2)
|
|
|
|
|
|
|
|
|
2023-04-28 04:41:52 +00:00
|
|
|
"""
|
|
|
|
|
|
|
|
import contextlib
|
|
|
|
import gc
|
2023-05-03 16:38:18 +00:00
|
|
|
import hashlib
|
2023-04-28 04:41:52 +00:00
|
|
|
import warnings
|
2023-05-03 16:38:18 +00:00
|
|
|
from collections.abc import Generator
|
2023-05-02 02:57:30 +00:00
|
|
|
from enum import Enum
|
2023-04-28 04:41:52 +00:00
|
|
|
from pathlib import Path
|
|
|
|
from typing import Sequence, Union
|
|
|
|
|
2023-05-03 16:38:18 +00:00
|
|
|
import torch
|
|
|
|
from diffusers import AutoencoderKL, SchedulerMixin, UNet2DConditionModel
|
|
|
|
from diffusers import logging as diffusers_logging
|
|
|
|
from diffusers.pipelines.stable_diffusion.safety_checker import \
|
|
|
|
StableDiffusionSafetyChecker
|
|
|
|
from huggingface_hub import HfApi
|
|
|
|
from picklescan.scanner import scan_file_path
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
|
|
|
from transformers import logging as transformers_logging
|
|
|
|
|
|
|
|
from ..globals import global_cache_dir
|
|
|
|
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
2023-05-02 20:52:27 +00:00
|
|
|
from . import load_pipeline_from_original_stable_diffusion_ckpt
|
2023-04-28 04:41:52 +00:00
|
|
|
|
2023-05-03 16:38:18 +00:00
|
|
|
# Default upper bound on how many models ModelCache keeps in storage (CPU) RAM.
MAX_MODELS: int = 4
|
2023-05-02 02:57:30 +00:00
|
|
|
|
|
|
|
# Each member name matches the attribute under which the part is found on a
# stable diffusion pipeline; each member value is the class used to load it.
class SDModelType(Enum):
    '''
    Kinds of model that ModelCache.get_model() can load or extract.

    `diffusion_pipeline` is the complete pipeline; every other member names
    one of its component parts.
    '''
    # the whole pipeline
    diffusion_pipeline = StableDiffusionGeneratorPipeline
    # individual component parts
    vae = AutoencoderKL
    text_encoder = CLIPTextModel
    tokenizer = CLIPTokenizer
    unet = UNet2DConditionModel
    scheduler = SchedulerMixin
    safety_checker = StableDiffusionSafetyChecker
    feature_extractor = CLIPFeatureExtractor
|
2023-04-28 04:41:52 +00:00
|
|
|
|
2023-05-02 20:52:27 +00:00
|
|
|
# Type alias covering every model class this module knows how to fetch;
# used purely for type hints.
ModelClass = Union[tuple(member.value for member in SDModelType)]
|
|
|
|
|
2023-05-02 20:52:27 +00:00
|
|
|
# Legacy information needed to load a legacy checkpoint file
class LegacyInfo(BaseModel):
    '''
    Extra information required to convert a legacy (.ckpt / .safetensors)
    checkpoint into a pipeline.

    :param config_file: path to the original (legacy) model config file
    :param vae_file: optional path to an alternate VAE weights file, or None
                     when no alternate VAE is required
    '''
    config_file: Path
    # Declared as optional so that an explicit `vae_file=None` validates too,
    # not only the omitted-field default.
    vae_file: Union[Path, None] = None
|
|
|
|
|
|
|
|
class UnsafeModelException(Exception):
    '''Signals that picklescan flagged a legacy model file as potentially malicious.'''
|
|
|
|
|
|
|
|
class UnscannableModelException(Exception):
    '''Signals that picklescan could not scan a legacy model file.'''
|
2023-05-02 20:52:27 +00:00
|
|
|
|
2023-04-28 04:41:52 +00:00
|
|
|
class ModelCache(object):
    '''
    LRU cache of diffusers/transformers models.

    Models at rest are kept on `storage_device` (CPU by default) and are
    moved onto `execution_device` (CUDA by default) only for the duration of
    a get_model() context.
    '''

    def __init__(
            self,
            max_models: int=MAX_MODELS,
            execution_device: torch.device=torch.device('cuda'),
            storage_device: torch.device=torch.device('cpu'),
            precision: torch.dtype=torch.float16,
            sequential_offload: bool=False,
            sha_chunksize: int = 16777216,
            ):
        '''
        :param max_models: Maximum number of models to cache in CPU RAM [4]
        :param execution_device: Torch device to load active model into [torch.device('cuda')]
        :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
        :param precision: Precision for loaded models [torch.float16]; in this class it is
                          only passed on when converting legacy checkpoints
        :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
        :param sha_chunksize: Chunksize to use when calculating sha256 model hash
        '''
        self.models: dict = dict()   # cache key -> loaded model
        self.stack: list = list()    # cache keys, least recently used first
        self.sequential_offload: bool=sequential_offload
        self.precision: torch.dtype=precision
        self.max_models: int=max_models
        self.execution_device: torch.device=execution_device
        self.storage_device: torch.device=storage_device
        self.sha_chunksize=sha_chunksize

    @contextlib.contextmanager
    def get_model(
            self,
            repo_id_or_path: Union[str,Path],
            model_type: SDModelType=SDModelType.diffusion_pipeline,
            subfolder: Path=None,
            submodel: SDModelType=None,
            revision: str=None,
            legacy_info: LegacyInfo=None,
            gpu_load: bool=True,
            )->Generator[ModelClass, None, None]:
        '''
        Load and return a HuggingFace model wrapped in a context manager generator, with RAM caching.
        Use like this:

            cache = ModelCache()
            with cache.get_model('stabilityai/stable-diffusion-2') as SD2:
                do_something_with_the_model(SD2)

        When gpu_load is True, the model is moved to the execution device (GPU)
        for the duration of the context and back to the storage device afterwards.

        :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
        :param model_type: an SDModelType enum giving the class of model to load [SDModelType.diffusion_pipeline]
        :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
        :param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.vae
        :param revision: model revision
        :param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
        :param gpu_load: load the model into GPU [default True]
        '''
        key = self._model_key(  # internal unique identifier for the model
            repo_id_or_path,
            model_type.value,
            revision,
            subfolder
        )

        if key in self.models:  # cached - mark as most recently used (end of stack)
            with contextlib.suppress(ValueError):
                self.stack.remove(key)
            self.stack.append(key)
            model = self.models[key]
        else:  # not cached - load
            self._make_cache_room()
            model = self._load_model_from_storage(
                repo_id_or_path=repo_id_or_path,
                model_class=model_type.value,
                subfolder=subfolder,
                revision=revision,
                legacy_info=legacy_info,
            )
            self.stack.append(key)  # add to LRU cache
            self.models[key] = model  # keep copy of model in dict

        if submodel:
            model = getattr(model, submodel.name)
            debugging_name = f'{submodel.name} submodel of {repo_id_or_path}'
        else:
            debugging_name = repo_id_or_path

        try:
            if gpu_load and hasattr(model,'to'):
                print(f' | Loading {debugging_name} into GPU')
                model.to(self.execution_device)  # move into GPU
                self._print_cuda_stats()
            yield model
        finally:
            if gpu_load and hasattr(model,'to'):
                print(f' | Unloading {debugging_name} from GPU')
                model.to(self.storage_device)
                self._print_cuda_stats()

    def model_hash(self,
                   repo_id_or_path: Union[str,Path],
                   revision: str=None)->str:
        '''
        Given the HF repo id or path to a model on disk, returns a unique
        hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs.

        :param repo_id_or_path: repo_id string or Path to model file/directory on disk.
        :param revision: optional revision (branch name) if fetching a HF repo_id; defaults to 'main'
        '''
        if self.is_legacy_ckpt(repo_id_or_path):
            return self._legacy_model_hash(repo_id_or_path)
        elif Path(repo_id_or_path).is_dir():
            return self._local_model_hash(repo_id_or_path)
        else:
            return self._hf_commit_hash(repo_id_or_path, revision)

    def cache_size(self)->int:
        "Return the current number of models cached."
        return len(self.models)

    @staticmethod
    def _model_key(path,model_class,revision,subfolder)->str:
        # e.g. "repo/name:StableDiffusionGeneratorPipeline:main:vae"
        return ':'.join([str(path),model_class.__name__,str(revision or ''),str(subfolder or '')])

    def _has_cuda(self)->bool:
        return self.execution_device.type == 'cuda'

    def _print_cuda_stats(self):
        # VRAM statistics are only meaningful when executing on a CUDA device.
        if not self._has_cuda():
            return
        print(
            " | Current VRAM usage:",
            "%4.2fG" % (torch.cuda.memory_allocated(self.execution_device) / 1e9),
        )

    def _make_cache_room(self):
        '''
        Evict least recently used models until there is room for one more.
        '''
        models_in_ram = len(self.models)
        # Also stop when the stack is exhausted, so that an out-of-sync stack or
        # max_models <= 0 cannot raise IndexError from pop(0).
        while models_in_ram >= self.max_models and self.stack:
            least_recently_used_key = self.stack.pop(0)
            print(f' | Maximum cache size reached: cache_size={models_in_ram}; unloading model {least_recently_used_key}')
            self.models.pop(least_recently_used_key, None)
            models_in_ram = len(self.models)
        gc.collect()

    @property
    def current_model(self)->ModelClass:
        '''
        Returns the most recently used model.
        '''
        return self.models[self._current_model_key]

    @property
    def _current_model_key(self)->str:
        '''
        Returns key of the most recently used model (end of the LRU stack).
        '''
        return self.stack[-1]

    def _load_model_from_storage(
            self,
            repo_id_or_path: Union[str,Path],
            subfolder: Path=None,
            revision: str=None,
            model_class: ModelClass=StableDiffusionGeneratorPipeline,
            legacy_info: LegacyInfo=None,
    )->ModelClass:
        '''
        Load and return a HuggingFace model.

        :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
        :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
        :param revision: model revision
        :param model_class: class of model to return, defaults to StableDiffusionGeneratorPipeline
        :param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
        '''
        # silence transformer and diffuser warnings
        with SilenceWarnings():
            if self.is_legacy_ckpt(repo_id_or_path):
                model = self._load_ckpt_from_storage(repo_id_or_path, legacy_info)
            else:
                model = self._load_diffusers_from_storage(
                    repo_id_or_path,
                    subfolder,
                    revision,
                    model_class,
                )
        if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
            model.enable_offload_submodels(self.execution_device)
        elif hasattr(model,'to'):
            # Cached models rest on the storage device; get_model() moves them to
            # the execution device only while a gpu_load context is active.
            model.to(self.storage_device)
        return model

    def _load_diffusers_from_storage(
            self,
            repo_id_or_path: Union[str,Path],
            subfolder: Path=None,
            revision: str=None,
            model_class: ModelClass=StableDiffusionGeneratorPipeline,
    )->ModelClass:
        '''
        Load and return a HuggingFace model using from_pretrained().

        :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
        :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
        :param revision: model revision
        :param model_class: class of model to return, defaults to StableDiffusionGeneratorPipeline
        '''
        return model_class.from_pretrained(
            repo_id_or_path,
            revision=revision,
            subfolder=subfolder or '.',
            cache_dir=global_cache_dir('hub'),
        )

    @classmethod
    def is_legacy_ckpt(cls, repo_id_or_path: Union[str,Path])->bool:
        '''
        Return true if the indicated path is a legacy checkpoint file
        (.ckpt or .safetensors, case-insensitive).

        :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
        '''
        path = Path(repo_id_or_path)
        return path.is_file() and path.suffix.lower() in [".ckpt",".safetensors"]

    @classmethod
    def scan_model(cls, model_name, checkpoint):
        """
        Apply picklescan to the indicated checkpoint file.

        :param model_name: name shown in the progress message
        :param checkpoint: path to the checkpoint file to scan
        :raises UnsafeModelException: if picklescan reports any infected file
        :raises UnscannableModelException: if picklescan reports a scan error
        """
        print(f" | Scanning Model: {model_name}")
        scan_result = scan_file_path(checkpoint)
        if scan_result.infected_files > 0:
            raise UnsafeModelException("The legacy model you are trying to load may contain malware. Aborting.")
        # getattr guards against picklescan versions whose ScanResult lacks scan_err
        if getattr(scan_result, 'scan_err', False):
            raise UnscannableModelException("InvokeAI was unable to scan the legacy model you requested. Aborting")
        print(" | Model scanned ok")

    def _load_ckpt_from_storage(self,
                                ckpt_path: Union[str,Path],
                                legacy_info:LegacyInfo)->StableDiffusionGeneratorPipeline:
        '''
        Load a legacy checkpoint, convert it, and return a StableDiffusionGeneratorPipeline.

        :param ckpt_path: string or Path pointing to the weights file (.ckpt or .safetensors)
        :param legacy_info: LegacyInfo object containing paths to legacy config file and alternate vae if required
        :raises ValueError: if legacy_info is None
        '''
        if legacy_info is None:
            raise ValueError(f'legacy_info is required to load legacy checkpoint {ckpt_path}')
        # NOTE(review): scan_model() is not called here, so a pickle-based .ckpt is
        # converted without a picklescan check unless the caller scanned it first.
        pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
            checkpoint_path=ckpt_path,
            original_config_file=legacy_info.config_file,
            vae_path=legacy_info.vae_file,
            return_generator_pipeline=True,
            precision=self.precision,
        )
        return pipeline

    def _legacy_model_hash(self, checkpoint_path: Union[str,Path])->str:
        '''
        SHA256 of a legacy checkpoint file, cached in a "<name>.sha256" file
        next to it and reused while it is at least as new as the checkpoint.
        '''
        sha = hashlib.sha256()
        path = Path(checkpoint_path)
        assert path.is_file()

        hashpath = path.parent / f"{path.name}.sha256"
        if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
            with open(hashpath) as f:
                return f.read().strip()

        print(f' | computing hash of model {path.name}')
        with open(path, "rb") as f:
            while chunk := f.read(self.sha_chunksize):
                sha.update(chunk)
        digest = sha.hexdigest()

        # Caching the digest is best-effort: a read-only model directory must not
        # prevent returning the freshly computed hash.
        with contextlib.suppress(OSError):
            with open(hashpath, "w") as f:
                f.write(digest)
        return digest

    def _local_model_hash(self, model_path: Union[str,Path])->str:
        '''
        SHA256 over all .ckpt, .safetensors and .pth files beneath a model
        directory, cached in "checksum.sha256" inside that directory.
        '''
        sha = hashlib.sha256()
        path = Path(model_path)

        hashpath = path / "checksum.sha256"
        if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
            with open(hashpath) as f:
                return f.read().strip()

        print(f' | computing hash of model {path.name}')
        # Sort within each group so the digest does not depend on the
        # filesystem's directory-listing order.
        for file in sorted(path.rglob("*.ckpt")) \
                + sorted(path.rglob("*.safetensors")) \
                + sorted(path.rglob("*.pth")):
            with open(file, "rb") as f:
                while chunk := f.read(self.sha_chunksize):
                    sha.update(chunk)
        digest = sha.hexdigest()

        # best-effort cache write (see _legacy_model_hash)
        with contextlib.suppress(OSError):
            with open(hashpath, "w") as f:
                f.write(digest)
        return digest

    def _hf_commit_hash(self, repo_id: str, revision: str='main')->str:
        '''
        Return the commit sha that branch `revision` points to in a HF model repo.

        :raises KeyError: if the repo has no branch named `revision`
        '''
        # model_hash() forwards revision=None by default; treat that as 'main'.
        revision = revision or 'main'
        api = HfApi()
        info = api.list_repo_refs(
            repo_id=str(repo_id),
            repo_type='model',
        )
        desired_revisions = [branch for branch in info.branches if branch.name==revision]
        if not desired_revisions:
            raise KeyError(f"Revision '{revision}' not found in {repo_id}")
        return desired_revisions[0].target_commit
|
|
|
|
|
2023-04-28 04:41:52 +00:00
|
|
|
class SilenceWarnings(object):
    '''
    Context manager that temporarily limits transformers and diffusers logging
    to errors and ignores Python warnings.

    On exit, the previous logging verbosity and warning filters are restored.
    '''
    def __init__(self):
        self.transformers_verbosity = transformers_logging.get_verbosity()
        self.diffusers_verbosity = diffusers_logging.get_verbosity()
        self._warnings_context = None

    def __enter__(self):
        # Re-read the levels on entry so a reused instance restores current state.
        self.transformers_verbosity = transformers_logging.get_verbosity()
        self.diffusers_verbosity = diffusers_logging.get_verbosity()
        transformers_logging.set_verbosity_error()
        diffusers_logging.set_verbosity_error()
        # catch_warnings snapshots the filter list so __exit__ can restore it,
        # instead of resetting the global filters to 'default'.
        self._warnings_context = warnings.catch_warnings()
        self._warnings_context.__enter__()
        warnings.simplefilter('ignore')

    def __exit__(self, exc_type, exc_value, traceback):
        transformers_logging.set_verbosity(self.transformers_verbosity)
        diffusers_logging.set_verbosity(self.diffusers_verbosity)
        if self._warnings_context is not None:
            self._warnings_context.__exit__(exc_type, exc_value, traceback)
            self._warnings_context = None
|