2023-04-28 04:41:52 +00:00
|
|
|
"""
Manage a RAM cache of diffusion/transformer models for fast switching.
Models are moved between GPU VRAM and CPU RAM as necessary. If the total
size of the cache grows larger than a preset maximum (in GB), the least
recently used models are evicted and will be (re)loaded from disk when
next needed.

The cache returns context managers designed to load the model into the
GPU within the context. Outside the context the model becomes eligible to
be moved back to CPU RAM (immediately, or only when another model needs
the GPU if lazy offloading is enabled). Use like this:

   cache = ModelCache(max_cache_size=6.0)
   with cache.get_model('runwayml/stable-diffusion-v1-5') as SD1:
       with cache.get_model('stabilityai/stable-diffusion-2') as SD2:
           do_something_in_GPU(SD1, SD2)
"""
|
|
|
|
|
|
|
|
import contextlib
|
|
|
|
import gc
|
2023-05-03 16:38:18 +00:00
|
|
|
import hashlib
|
2023-04-28 04:41:52 +00:00
|
|
|
import warnings
|
2023-05-05 03:15:32 +00:00
|
|
|
from collections import Counter
|
2023-05-14 22:09:38 +00:00
|
|
|
from contextlib import suppress
|
2023-05-09 01:47:03 +00:00
|
|
|
from enum import Enum
|
2023-04-28 04:41:52 +00:00
|
|
|
from pathlib import Path
|
2023-05-14 22:09:38 +00:00
|
|
|
from typing import Dict, Sequence, Union, Set, Tuple, types, Optional
|
2023-04-28 04:41:52 +00:00
|
|
|
|
2023-05-03 16:38:18 +00:00
|
|
|
import torch
|
2023-05-06 19:58:44 +00:00
|
|
|
import safetensors.torch
|
2023-05-14 22:09:38 +00:00
|
|
|
|
2023-05-14 00:06:26 +00:00
|
|
|
from diffusers import DiffusionPipeline, StableDiffusionPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel, ConfigMixin
|
2023-05-03 16:38:18 +00:00
|
|
|
from diffusers import logging as diffusers_logging
|
|
|
|
from diffusers.pipelines.stable_diffusion.safety_checker import \
|
|
|
|
StableDiffusionSafetyChecker
|
|
|
|
from huggingface_hub import HfApi
|
|
|
|
from picklescan.scanner import scan_file_path
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
|
|
|
from transformers import logging as transformers_logging
|
|
|
|
|
2023-05-05 03:15:32 +00:00
|
|
|
import invokeai.backend.util.logging as logger
|
2023-05-03 16:38:18 +00:00
|
|
|
from ..globals import global_cache_dir
|
|
|
|
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
2023-04-28 04:41:52 +00:00
|
|
|
|
2023-05-07 22:07:28 +00:00
|
|
|
# Upper bound on the RAM cache, expressed in gigabytes.
# The default is roughly enough to hold three fp16 diffusers models at once.
DEFAULT_MAX_CACHE_SIZE = 6.0

# Number of bytes in one (binary) gigabyte, i.e. 2**30.
GIG = 2 ** 30
|
2023-05-02 02:57:30 +00:00
|
|
|
|
|
|
|
# Marker types: LoRA and textual-inversion weights are both plain dicts of
# tensors, so distinct dict subclasses let callers tell them apart by class.
class LoraType(dict):
    '''A dict of LoRA tensors, distinguished from other dicts by its class.'''
|
|
|
|
class TIType(dict):
    '''A dict of textual-inversion tensors, distinguished from other dicts by its class.'''
|
|
|
|
|
2023-05-14 00:06:26 +00:00
|
|
|
class SDModelType(str, Enum):
    '''
    Kinds of model the cache can load. Members are also plain strings, so
    their values double as attribute names on a diffusers pipeline.
    '''
    # a whole diffusers pipeline
    Diffusers = "diffusers"

    # individual diffusers parts
    Vae = "vae"
    TextEncoder = "text_encoder"
    Tokenizer = "tokenizer"
    UNet = "unet"
    Scheduler = "scheduler"
    SafetyChecker = "safety_checker"
    FeatureExtractor = "feature_extractor"

    # both of these load as dicts of tensors; they are told apart by class
    Lora = "lora"
    TextualInversion = "textual_inversion"
|
|
|
|
|
|
|
|
# TODO: placeholder used as the MODEL_CLASSES entry for SDModelType.Scheduler;
# it adds nothing beyond what its SchedulerMixin/ConfigMixin bases provide.
class EmptyScheduler(SchedulerMixin, ConfigMixin):
    '''Placeholder scheduler type with no behavior of its own.'''
|
|
|
|
|
|
|
|
# Maps each SDModelType to the class used to load/instantiate it.
MODEL_CLASSES = {
    # whole pipeline
    SDModelType.Diffusers: StableDiffusionGeneratorPipeline,
    # diffusers parts
    SDModelType.Vae: AutoencoderKL,
    SDModelType.TextEncoder: CLIPTextModel,       # TODO: t5
    SDModelType.Tokenizer: CLIPTokenizer,         # TODO: t5
    SDModelType.UNet: UNet2DConditionModel,
    SDModelType.Scheduler: EmptyScheduler,
    SDModelType.SafetyChecker: StableDiffusionSafetyChecker,
    SDModelType.FeatureExtractor: CLIPFeatureExtractor,
    # tensor dictionaries, distinguished by marker class
    SDModelType.Lora: LoraType,
    SDModelType.TextualInversion: TIType,
}
|
|
|
|
|
2023-05-14 22:09:38 +00:00
|
|
|
# Model types that are components (attributes) of a diffusers pipeline,
# as opposed to a whole pipeline or a LoRA/textual-inversion dict.
DIFFUSERS_PARTS = {
    SDModelType.Vae,
    SDModelType.TextEncoder,
    SDModelType.Tokenizer,
    SDModelType.UNet,
    SDModelType.Scheduler,
    SDModelType.SafetyChecker,
    SDModelType.FeatureExtractor,
}
|
|
|
|
|
2023-05-05 23:32:28 +00:00
|
|
|
class ModelStatus(Enum):
    '''Where a model currently lives, as reported by ModelCache.status().'''
    unknown = 'unknown'
    not_loaded = 'not loaded'   # not in the cache at all
    in_ram = 'cached'           # cached in storage (CPU) memory
    in_vram = 'in gpu'          # on the execution device but not locked
    active = 'locked in gpu'    # on the execution device and locked by a context
|
2023-05-07 22:07:28 +00:00
|
|
|
|
|
|
|
# Rough size (in GB, assuming float16; double it for float32) of each model
# type, used to make cache room before a model's real size is known.
# Once a model has been loaded its measured size is used instead.
SIZE_GUESSTIMATE = {
    SDModelType.Diffusers:        2.2,
    SDModelType.Vae:              0.35,
    SDModelType.TextEncoder:      0.5,
    SDModelType.Tokenizer:        0.001,
    SDModelType.UNet:             3.4,
    SDModelType.Scheduler:        0.001,
    SDModelType.SafetyChecker:    1.2,
    SDModelType.FeatureExtractor: 0.001,
    SDModelType.Lora:             0.1,
    SDModelType.TextualInversion: 0.001,
}
|
2023-04-28 04:41:52 +00:00
|
|
|
|
2023-05-02 20:52:27 +00:00
|
|
|
# Union of every class the cache knows how to fetch, for type annotations.
ModelClass = Union[tuple(MODEL_CLASSES.values())]

# Classes whose from_pretrained() is passed torch_dtype when loaded by the cache.
DiffusionClasses = (
    StableDiffusionGeneratorPipeline,
    AutoencoderKL,
    EmptyScheduler,
    UNet2DConditionModel,
    CLIPTextModel,
)
|
2023-05-02 02:57:30 +00:00
|
|
|
|
2023-05-03 16:38:18 +00:00
|
|
|
class UnsafeModelException(Exception):
    '''Raised when a legacy model file fails the picklescan test.'''
|
|
|
|
|
|
|
|
class UnscannableModelException(Exception):
    '''Raised when picklescan is unable to scan a legacy model file.'''
|
2023-05-02 20:52:27 +00:00
|
|
|
|
2023-05-06 04:44:12 +00:00
|
|
|
class ModelLocker(object):
    '''
    Forward declaration so ModelCache.get_model() can name its return type;
    the real implementation is the nested class ModelCache.ModelLocker.
    '''
|
|
|
|
|
2023-04-28 04:41:52 +00:00
|
|
|
class ModelCache(object):
    '''
    RAM cache of diffusers/transformers models with on-demand placement on the
    execution device.

    Cached models are stored in ``self.models`` under a string key built by
    ``_model_key()`` from (repo_id_or_path, revision, subfolder, model_type).
    ``self.stack`` records usage order, least recently used first. Before a new
    model is loaded, unlocked models are evicted oldest-first until the tracked
    cache size (``self.current_cache_size``, in bytes) plus the new model's
    expected size fits in ``max_cache_size`` GB. ``get_model()`` returns a
    ``ModelLocker`` context manager that moves the model to the execution device
    while the context is active.
    '''

    def __init__(
        self,
        max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
        execution_device: torch.device = torch.device('cuda'),
        storage_device: torch.device = torch.device('cpu'),
        precision: torch.dtype = torch.float16,
        sequential_offload: bool = False,
        lazy_offloading: bool = True,
        sha_chunksize: int = 16777216,
        logger: types.ModuleType = logger,
    ):
        '''
        :param max_cache_size: Maximum size of the RAM cache, in GB [DEFAULT_MAX_CACHE_SIZE]
        :param execution_device: Torch device to load active model into [torch.device('cuda')]
        :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
        :param precision: Precision for loaded models [torch.float16]
        :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
        :param sha_chunksize: Chunksize to use when calculating sha256 model hash
        :param logger: logger used for progress and debug messages
        '''
        self.models: dict = dict()                  # key -> cached model object
        self.stack: Sequence = list()               # keys in LRU order, least recently used first
        self.lazy_offloading = lazy_offloading
        self.sequential_offload: bool = sequential_offload
        self.precision: torch.dtype = precision
        self.current_cache_size: int = 0            # tracked size of cached models, in bytes
        self.max_cache_size: float = max_cache_size # in GB
        self.execution_device: torch.device = execution_device
        self.storage_device: torch.device = storage_device
        self.sha_chunksize = sha_chunksize
        self.logger = logger
        self.loaded_models: set = set()             # keys of models moved to the execution device
        self.locked_models: Counter = Counter()     # key -> number of active ModelLocker contexts
        self.model_sizes: Dict[str, int] = dict()   # key -> last measured size in bytes

    def get_model(
        self,
        repo_id_or_path: Union[str, Path],
        model_type: SDModelType = SDModelType.Diffusers,
        subfolder: Path = None,
        submodel: SDModelType = None,
        revision: str = None,
        attach_model_parts: Optional[Set[Tuple[SDModelType, str]]] = None,
        gpu_load: bool = True,
    ) -> ModelLocker:
        '''
        Load and return a HuggingFace model wrapped in a context manager (a
        ModelLocker), with RAM caching. Use like this:

           cache = ModelCache()
           with cache.get_model('stabilityai/stable-diffusion-2') as model:
               do_something_with_the_model(model)

        While in context, the model will be locked into the execution device (GPU).
        If you want to do something with the model while it is in RAM, use the
        context's `model` attribute:

           context = cache.get_model('stabilityai/stable-diffusion-2')
           context.model.device
           # device(type='cpu')

           with context as model:
               model.device
               # device(type='cuda')

        You can fetch an individual part of a diffusers model by passing the
        submodel argument:

           vae_context = cache.get_model(
               'stabilityai/stable-diffusion-2',
               submodel=SDModelType.Vae
           )

        When the parent pipeline is already cached, this is equivalent to the call
        below (if it is not cached, the call below loads the part on its own):

           vae_context = cache.get_model(
               'stabilityai/stable-diffusion-2',
               model_type=SDModelType.Vae,
               subfolder='vae'
           )

        Vice versa, you can load and attach external parts to a diffusers model
        before returning it by passing the attach_model_parts argument. This only
        works with diffusers models, and only takes effect when the model is first
        loaded into the cache:

           pipeline_context = cache.get_model(
               'runwayml/stable-diffusion-v1-5',
               attach_model_parts={
                   (SDModelType.Vae, 'stabilityai/sd-vae-ft-mse'),
                   (SDModelType.UNet, 'runwayml/stable-diffusion-v1-5', 'unet'),  # type, ID, subfolder
               }
           )

        The model will be locked into GPU VRAM for the duration of the context.

        :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
        :param model_type: An SDModelType enum indicating the type of the (parent) model
        :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
        :param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.Vae
        :param revision: model revision
        :param attach_model_parts: load and attach diffusers model components. Pass a set of
            tuples of the form (SDModelType, repo_id_or_path) or (SDModelType, repo_id_or_path, subfolder)
        :param gpu_load: load the model into GPU [default True]
        :return: a ModelLocker; entering it yields the model (or the requested submodel)
        '''
        key = self._model_key(  # internal unique identifier for the model
            repo_id_or_path,
            revision,
            subfolder,
            model_type,
        )

        # optimization: if caller is asking to load a submodel of a diffusers pipeline, then
        # check whether it is already cached in RAM and return it instead of loading from disk again
        if subfolder and not submodel:
            possible_parent_key = self._model_key(
                repo_id_or_path,
                revision,
                None,
                SDModelType.Diffusers,
            )
            if possible_parent_key in self.models:
                key = possible_parent_key
                submodel = model_type

        if key in self.models:  # cached - move to the end of the stack (most recently used)
            with contextlib.suppress(ValueError):
                self.stack.remove(key)
            self.stack.append(key)
            model = self.models[key]

        else:  # not cached - load
            self.logger.info(f'Loading model {repo_id_or_path}, type {model_type}')

            # this will remove older cached models until
            # there is sufficient room to load the requested model
            self._make_cache_room(key, model_type)

            # clean memory to make MemoryUsage() more accurate
            gc.collect()
            model = self._load_model_from_storage(
                repo_id_or_path=repo_id_or_path,
                model_type=model_type,
                subfolder=subfolder,
                revision=revision,
            )

            if mem_used := self.calc_model_size(model):
                self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
                self.model_sizes[key] = mem_used  # remember size of this model for cache cleansing
                self.current_cache_size += mem_used  # increment size of the cache

            # this is a bit of legacy work needed to support the old-style "load this diffuser with custom VAE"
            if model_type == SDModelType.Diffusers and attach_model_parts:
                for attach_model_part in attach_model_parts:
                    self.attach_part(model, *attach_model_part)

            self.stack.append(key)  # add to LRU cache
            self.models[key] = model  # keep copy of model in dict

        if submodel:
            model = getattr(model, submodel)

        return self.ModelLocker(self, key, model, gpu_load)

    def uncache_model(self, key: str):
        '''
        Remove the corresponding model from the cache and release its share of
        ``current_cache_size``. The measured size stays in ``model_sizes`` so a
        later reload can use it instead of a guesstimate.
        '''
        if key is None or key not in self.models:
            return
        self.models.pop(key, None)
        self.locked_models.pop(key, None)
        self.loaded_models.discard(key)
        with contextlib.suppress(ValueError):
            self.stack.remove(key)
        self.current_cache_size = max(0, self.current_cache_size - self.model_sizes.get(key, 0))

    class ModelLocker(object):
        '''
        Context manager returned by ModelCache.get_model(). On entry (when gpu_load
        is set and the model has a to() method) it locks the model and moves it to
        the cache's execution device; on exit it releases the lock.
        '''

        def __init__(self, cache, key, model, gpu_load):
            self.gpu_load = gpu_load
            self.cache = cache
            self.key = key
            # This will keep a copy of the model in RAM until the locker
            # is garbage collected. Needs testing!
            self.model = model

        def __enter__(self) -> ModelClass:
            cache = self.cache
            key = self.key
            model = self.model

            # NOTE that the model has to have the to() method in order for this
            # code to move it into GPU!
            if self.gpu_load and hasattr(model, 'to'):
                cache.loaded_models.add(key)
                cache.locked_models[key] += 1

                if cache.lazy_offloading:
                    cache._offload_unlocked_models()

                # Compare devices by type/index: torch.device('cuda') != torch.device('cuda:0'),
                # so a plain != would re-move (and re-measure) an already-resident model.
                already_there = self._same_device(getattr(model, 'device', None), cache.execution_device)
                offloaded_pipeline = cache.sequential_offload and isinstance(model, StableDiffusionGeneratorPipeline)
                if not already_there and not offloaded_pipeline:
                    cache.logger.debug(f'Moving {key} into {cache.execution_device}')
                    with VRAMUsage() as mem:
                        self._to(model, cache.execution_device)
                    cache.logger.debug(f'Locked {key} in {cache.execution_device}')
                    cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB')
                    self._record_vram_size(mem.vram_used)

                cache._print_cuda_stats()

            else:
                # in the event that the caller wants the model in RAM, we
                # move it into CPU if it is in GPU and not locked
                if hasattr(model, 'to') and (key in cache.loaded_models
                                             and cache.locked_models[key] <= 0):
                    self._to(model, cache.storage_device)
                    cache.loaded_models.discard(key)

            return model

        def __exit__(self, type, value, traceback):
            # Undo only what __enter__ did: a lock is taken solely when gpu_load
            # is set and the model can be moved.
            if not (self.gpu_load and hasattr(self.model, 'to')):
                return

            key = self.key
            cache = self.cache
            # the model may have been uncached while locked; never go negative
            if cache.locked_models[key] > 0:
                cache.locked_models[key] -= 1
            if not cache.lazy_offloading:
                cache._offload_unlocked_models()
                cache._print_cuda_stats()

        def _record_vram_size(self, vram_used: int):
            '''
            Replace the cached model's recorded size with the measured VRAM delta,
            keeping current_cache_size consistent. Skipped when the measurement is
            empty, or when the locked object is a submodel sharing its parent's key.
            '''
            cache = self.cache
            if vram_used <= 0 or cache.models.get(self.key) is not self.model:
                return
            previous = cache.model_sizes.get(self.key, 0)
            cache.model_sizes[self.key] = vram_used  # more accurate size
            cache.current_cache_size += vram_used - previous

        @staticmethod
        def _same_device(current, target) -> bool:
            '''True if `current` and `target` denote the same torch device (None never matches).'''
            if current is None:
                return False
            current, target = torch.device(current), torch.device(target)
            if current.type != target.type:
                return False
            if current.index == target.index:
                return True
            if current.type == 'cuda':
                # an unspecified CUDA index means "the current CUDA device"
                default_index = torch.cuda.current_device()
                current_index = default_index if current.index is None else current.index
                target_index = default_index if target.index is None else target.index
                return current_index == target_index
            return current.index is None or target.index is None

        def _to(self, model, device):
            '''Move `model` to `device`; for pipelines also move each known part.'''
            model.to(device)
            if isinstance(model, MODEL_CLASSES[SDModelType.Diffusers]):
                for part in DIFFUSERS_PARTS:
                    # best effort: some parts (tokenizer, scheduler) have no to()
                    with suppress(Exception):
                        getattr(model, part).to(device)

    def attach_part(
        self,
        diffusers_model: StableDiffusionPipeline,
        part_type: SDModelType,
        part_id: str,
        subfolder: Optional[str] = None,
    ):
        '''
        Attach a diffusers model part to a diffusers model. This can be
        used to replace the VAE, tokenizer, textencoder, unet, etc.
        :param diffusers_model: The diffusers model to attach the part to.
        :param part_type: An SDModelType indicating the part
        :param part_id: A HF repo_id (or local path) for the part
        :param subfolder: optional subfolder within part_id holding the part
        '''
        part = self._load_diffusers_from_storage(
            part_id,
            model_type=part_type,
            subfolder=subfolder,
        )
        if hasattr(part, 'to'):
            part.to(diffusers_model.device)
        setattr(diffusers_model, part_type, part)
        self.logger.debug(f'Attached {part_type} {part_id}')

    def status(
        self,
        repo_id_or_path: Union[str, Path],
        model_type: SDModelType = SDModelType.Diffusers,
        revision: str = None,
        subfolder: Path = None,
    ) -> ModelStatus:
        '''Report whether the identified model is uncached, cached in RAM, in VRAM, or locked in VRAM.'''
        key = self._model_key(
            repo_id_or_path,
            revision,
            subfolder,
            model_type,
        )
        if key not in self.models:
            return ModelStatus.not_loaded
        if key in self.loaded_models:
            if self.locked_models[key] > 0:
                return ModelStatus.active
            else:
                return ModelStatus.in_vram
        else:
            return ModelStatus.in_ram

    def model_hash(
        self,
        repo_id_or_path: Union[str, Path],
        revision: str = "main",
    ) -> str:
        '''
        Given the HF repo id or path to a model on disk, returns a unique
        hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
        :param repo_id_or_path: repo_id string or Path to model file/directory on disk.
        :param revision: optional revision string (if fetching a HF repo_id)
        '''
        revision = revision or "main"
        path = Path(repo_id_or_path)
        if path.is_dir():
            return self._local_model_hash(path)
        if path.is_file():
            return self._legacy_model_hash(path)
        return self._hf_commit_hash(repo_id_or_path, revision)

    def cache_size(self) -> float:
        "Return the current size of the cache, in GB"
        return self.current_cache_size / GIG

    @classmethod
    def scan_model(cls, model_name, checkpoint):
        """
        Apply picklescan to the indicated legacy checkpoint.
        :raises UnsafeModelException: if picklescan reports any infected file
        :raises UnscannableModelException: if picklescan could not parse the file
        """
        # scan model
        logger.debug(f"Scanning Model: {model_name}")
        scan_result = scan_file_path(checkpoint)
        # infected_files can exceed 1 for archives holding several pickles;
        # any non-zero count means dangerous content was found.
        if scan_result.infected_files != 0:
            raise UnsafeModelException("The legacy model you are trying to load may contain malware. Aborting.")
        if getattr(scan_result, 'scan_err', False):
            raise UnscannableModelException("InvokeAI was unable to scan the legacy model you requested. Aborting")
        logger.debug("Model scanned ok")

    @staticmethod
    def _model_key(path, revision, subfolder, model_class) -> str:
        '''Build the cache key "path:revision:subfolder:model_type" (None parts become "").'''
        return ':'.join([
            str(path),
            str(revision or ''),
            str(subfolder or ''),
            model_class,
        ])

    def _has_cuda(self) -> bool:
        return self.execution_device.type == 'cuda'

    def _print_cuda_stats(self):
        vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
        ram = "%4.2fG" % (self.current_cache_size / GIG)
        cached_models = len(self.models)
        loaded_models = len(self.loaded_models)
        locked_models = len([x for x in self.locked_models if self.locked_models[x] > 0])
        self.logger.debug(f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models = {cached_models}/{loaded_models}/{locked_models}")

    def _make_cache_room(self, key, model_type):
        '''
        Evict least recently used, unlocked models until the model `key` is
        expected to fit in max_cache_size. The expected size is the size measured
        on a previous load of `key` if known, otherwise SIZE_GUESSTIMATE for
        `model_type` (doubled for float32). Stops quietly when nothing evictable
        remains.
        '''
        # calculate how much memory this model will require
        multiplier = 2 if self.precision == torch.float32 else 1
        bytes_needed = int(self.model_sizes.get(key, 0) or SIZE_GUESSTIMATE.get(model_type, 0.5) * GIG * multiplier)
        maximum_size = self.max_cache_size * GIG  # stored in GB, convert to bytes

        adjective = 'guesstimated' if key not in self.model_sizes else 'known from previous load'
        self.logger.debug(f'{(bytes_needed/GIG):.2f} GB needed to load this model ({adjective})')

        # Snapshot LRU order (oldest first), skipping models locked by an active
        # context; uncache_model() mutates self.stack, so iterate the copy.
        candidates = [k for k in self.stack if self.locked_models[k] <= 0]
        for least_recently_used_key in candidates:
            if self.current_cache_size + bytes_needed <= maximum_size:
                break
            model_size = self.model_sizes.get(least_recently_used_key, 0)
            self.logger.debug(f'Max cache size exceeded: cache_size={(self.current_cache_size/GIG):.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB')
            self.logger.debug(f'Unloading model {least_recently_used_key} to free {(model_size/GIG):.2f} GB')
            self.uncache_model(least_recently_used_key)  # also decrements current_cache_size

        gc.collect()

    def _offload_unlocked_models(self):
        '''Move every loaded-but-unlocked model back to the storage device.'''
        to_offload = [key for key in self.loaded_models if self.locked_models[key] <= 0]
        for key in to_offload:
            model = self.models.get(key)
            # the key may refer to a model evicted after it was loaded
            if model is not None:
                self.logger.debug(f'Offloading {key} from {self.execution_device} into {self.storage_device}')
                model.to(self.storage_device)
            self.loaded_models.discard(key)

    def _load_model_from_storage(
        self,
        repo_id_or_path: Union[str, Path],
        subfolder: Optional[Path] = None,
        revision: Optional[str] = None,
        model_type: SDModelType = SDModelType.Diffusers,
    ) -> ModelClass:
        '''
        Load and return a HuggingFace model.
        :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
        :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
        :param revision: model revision
        :param model_type: type of model to return, defaults to SDModelType.Diffusers
        '''
        # silence transformer and diffuser warnings
        with SilenceWarnings():
            if model_type == SDModelType.Lora:
                model = self._load_lora_from_storage(repo_id_or_path)
            elif model_type == SDModelType.TextualInversion:
                model = self._load_ti_from_storage(repo_id_or_path)
            else:
                model = self._load_diffusers_from_storage(
                    repo_id_or_path,
                    subfolder,
                    revision,
                    model_type,
                )
                if self.sequential_offload and isinstance(model, StableDiffusionGeneratorPipeline):
                    model.enable_offload_submodels(self.execution_device)
        return model

    def _load_diffusers_from_storage(
        self,
        repo_id_or_path: Union[str, Path],
        subfolder: Optional[Path] = None,
        revision: Optional[str] = None,
        model_type: SDModelType = SDModelType.Diffusers,
    ) -> ModelClass:
        '''
        Load and return a HuggingFace model using from_pretrained().
        :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
        :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
        :param revision: model revision; if None, "fp16" then "main" are tried for float16, else "main"
        :param model_type: SDModelType to load (a class from MODEL_CLASSES is also accepted);
            defaults to SDModelType.Diffusers (StableDiffusionGeneratorPipeline)
        :raises OSError: if no candidate revision could be loaded
        '''
        model_class = model_type if isinstance(model_type, type) else MODEL_CLASSES[model_type]

        if revision is not None:
            revisions = [revision]
        elif self.precision == torch.float16:
            revisions = ['fp16', 'main']
        else:
            revisions = ['main']

        extra_args = dict()
        if model_class in DiffusionClasses:
            extra_args.update(
                torch_dtype=self.precision,
            )
        if model_class == StableDiffusionGeneratorPipeline:
            extra_args.update(
                safety_checker=None,
            )

        last_error = None
        for rev in revisions:
            try:
                model = model_class.from_pretrained(
                    repo_id_or_path,
                    revision=rev,
                    subfolder=subfolder or '.',
                    cache_dir=global_cache_dir('hub'),
                    **extra_args,
                )
            except OSError as e:  # e.g. this revision does not exist; try the next one
                last_error = e
                continue
            self.logger.debug(f'Found revision {rev}')
            return model

        raise OSError(f'Unable to load {model_type} model from {repo_id_or_path} (tried revisions {revisions})') from last_error

    def _load_lora_from_storage(self, lora_path: Path) -> LoraType:
        raise NotImplementedError("_load_lora_from_storage() is not yet implemented")

    def _load_ti_from_storage(self, lora_path: Path) -> TIType:
        raise NotImplementedError("_load_ti_from_storage() is not yet implemented")

    def _legacy_model_hash(self, checkpoint_path: Union[str, Path]) -> str:
        '''
        Return the sha256 hex digest of a single checkpoint file. The digest is
        cached in a sibling "<name>.sha256" file and reused while that file is
        at least as new as the checkpoint.
        '''
        sha = hashlib.sha256()
        path = Path(checkpoint_path)
        assert path.is_file(), f"File {checkpoint_path} not found"

        hashpath = path.parent / f"{path.name}.sha256"
        if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
            with open(hashpath) as f:
                return f.read().strip()

        self.logger.debug(f'computing hash of model {path.name}')
        with open(path, "rb") as f:
            while chunk := f.read(self.sha_chunksize):
                sha.update(chunk)
        digest = sha.hexdigest()

        with open(hashpath, "w") as f:
            f.write(digest)
        return digest

    def _local_model_hash(self, model_path: Union[str, Path]) -> str:
        '''
        Return a sha256 hex digest over the *.ckpt, *.safetensors and *.pth files
        under a model directory, cached in "<dir>/checksum.sha256".
        '''
        # NOTE(review): *.bin weights are not hashed, and rglob order is filesystem
        # dependent; left unchanged so existing hashes stay stable.
        sha = hashlib.sha256()
        path = Path(model_path)

        hashpath = path / "checksum.sha256"
        if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
            with open(hashpath) as f:
                return f.read().strip()

        self.logger.debug(f'computing hash of model {path.name}')
        for file in list(path.rglob("*.ckpt")) \
                + list(path.rglob("*.safetensors")) \
                + list(path.rglob("*.pth")):
            with open(file, "rb") as f:
                while chunk := f.read(self.sha_chunksize):
                    sha.update(chunk)
        digest = sha.hexdigest()
        with open(hashpath, "w") as f:
            f.write(digest)
        return digest

    def _hf_commit_hash(self, repo_id: str, revision: str = 'main') -> str:
        '''Return the commit id at the tip of branch `revision` of HF repo `repo_id`.'''
        api = HfApi()
        info = api.list_repo_refs(
            repo_id=repo_id,
            repo_type='model',
        )
        desired_revisions = [branch for branch in info.branches if branch.name == revision]
        if not desired_revisions:
            raise KeyError(f"Revision '{revision}' not found in {repo_id}")
        return desired_revisions[0].target_commit

    @staticmethod
    def calc_model_size(model) -> Optional[int]:
        '''Size in bytes of a pipeline's or module's parameters and buffers; None for other objects.'''
        if isinstance(model, DiffusionPipeline):
            return ModelCache._calc_pipeline(model)
        elif isinstance(model, torch.nn.Module):
            return ModelCache._calc_model(model)
        else:
            return None

    @staticmethod
    def _calc_pipeline(pipeline) -> int:
        '''Sum of _calc_model() over the pipeline components that are torch modules.'''
        res = 0
        for submodel_key in pipeline.components.keys():
            submodel = getattr(pipeline, submodel_key)
            if submodel is not None and isinstance(submodel, torch.nn.Module):
                res += ModelCache._calc_model(submodel)
        return res

    @staticmethod
    def _calc_model(model) -> int:
        '''Bytes occupied by a module's parameters plus buffers.'''
        mem_params = sum(param.nelement() * param.element_size() for param in model.parameters())
        mem_bufs = sum(buf.nelement() * buf.element_size() for buf in model.buffers())
        return mem_params + mem_bufs  # in bytes
|
2023-05-05 23:32:28 +00:00
|
|
|
|
2023-04-28 04:41:52 +00:00
|
|
|
class SilenceWarnings(object):
    '''
    Context manager that temporarily sets transformers and diffusers logging to
    error level and ignores Python warnings. On exit it restores the previous
    logging verbosities and the previous warning filters, rather than resetting
    the filters to 'default'.
    '''

    def __init__(self):
        self.transformers_verbosity = transformers_logging.get_verbosity()
        self.diffusers_verbosity = diffusers_logging.get_verbosity()
        self._warnings_context = None

    def __enter__(self):
        # re-read verbosities here in case they changed since construction
        self.transformers_verbosity = transformers_logging.get_verbosity()
        self.diffusers_verbosity = diffusers_logging.get_verbosity()
        transformers_logging.set_verbosity_error()
        diffusers_logging.set_verbosity_error()
        # catch_warnings saves the current filter list and restores it on exit
        self._warnings_context = warnings.catch_warnings()
        self._warnings_context.__enter__()
        warnings.simplefilter('ignore')

    def __exit__(self, type, value, traceback):
        transformers_logging.set_verbosity(self.transformers_verbosity)
        diffusers_logging.set_verbosity(self.diffusers_verbosity)
        if self._warnings_context is not None:
            self._warnings_context.__exit__(type, value, traceback)
            self._warnings_context = None
|
2023-05-07 22:07:28 +00:00
|
|
|
|
2023-05-09 01:47:03 +00:00
|
|
|
class VRAMUsage(object):
    '''
    Context manager measuring the change in torch.cuda.memory_allocated()
    across its body; the delta (bytes) is left in ``vram_used`` on exit.
    '''

    def __init__(self):
        self.vram = None      # bytes allocated when the context was entered
        self.vram_used = 0    # bytes allocated during the context

    def __enter__(self):
        self.vram = torch.cuda.memory_allocated()
        return self

    def __exit__(self, *args):
        allocated_now = torch.cuda.memory_allocated()
        self.vram_used = allocated_now - self.vram
|