add redesigned model cache for diffusers & transformers

This commit is contained in:
Lincoln Stein 2023-04-28 00:41:52 -04:00
parent 4a924c9b54
commit 956ad6bcf5
4 changed files with 211 additions and 10 deletions

View File

@ -0,0 +1,203 @@
"""
Manage a cache of Stable Diffusion model files for fast switching.
They are moved between GPU and CPU as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.
"""
import contextlib
import hashlib
import gc
import time
import os
import psutil
import safetensors
import safetensors.torch
import torch
import transformers
import warnings
from pathlib import Path
from diffusers import (
AutoencoderKL,
UNet2DConditionModel,
SchedulerMixin,
logging as diffusers_logging,
)
from transformers import(
CLIPTokenizer,
CLIPFeatureExtractor,
CLIPTextModel,
logging as transformers_logging,
)
from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from picklescan.scanner import scan_file_path
from typing import Sequence, Union
from invokeai.backend.globals import Globals, global_cache_dir
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from ..stable_diffusion import (
StableDiffusionGeneratorPipeline,
)
from ..stable_diffusion.offloading import ModelGroup, FullyLoadedModelGroup
from ..util import CUDA_DEVICE, ask_user, download_with_resume
MAX_MODELS_CACHED = 4
class ModelCache(object):
def __init__(
self,
max_models_cached: int=MAX_MODELS_CACHED,
execution_device: torch.device=torch.device('cuda'),
precision: torch.dtype=torch.float16,
sequential_offload: bool=False,
):
self.model_group: ModelGroup=FullyLoadedModelGroup(execution_device)
self.models: dict = dict()
self.stack: Sequence = list()
self.sequential_offload: bool=sequential_offload
self.precision: torch.dtype=precision
self.max_models_cached: int=max_models_cached
self.device: torch.device=execution_device
def get_model(
self,
repo_id_or_path: Union[str,Path],
model_class: type=StableDiffusionGeneratorPipeline,
subfolder: Path=None,
revision: str=None,
)->Union[
AutoencoderKL,
CLIPTokenizer,
CLIPFeatureExtractor,
CLIPTextModel,
UNet2DConditionModel,
StableDiffusionSafetyChecker,
StableDiffusionGeneratorPipeline,
]:
'''
Load and return a HuggingFace model, with RAM caching.
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
:param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
:param revision: model revision
:param model_class: class of model to return
'''
key = self._model_key(repo_id_or_path,model_class,revision,subfolder) # internal unique identifier for the model
if key in self.models: # cached - move to bottom of stack
previous_key = self._current_model_key
with contextlib.suppress(ValueError):
self.stack.remove(key)
self.stack.append(key)
if previous_key != key:
if hasattr(self.current_model,'to'):
print(f'DEBUG: loading {key} into GPU')
self.model_group.offload_current()
self.model_group.load(self.models[key])
else: # not cached -load
self._make_cache_room()
self.model_group.offload_current()
print(f'DEBUG: loading {key} from disk/net')
model = self._load_model_from_storage(
repo_id_or_path=repo_id_or_path,
subfolder=subfolder,
revision=revision,
model_class=model_class
)
if hasattr(model,'to'):
self.model_group.install(model) # register with the model group
self.stack.append(key) # add to LRU cache
self.models[key]=model # keep copy of model in dict
return self.models[key]
@staticmethod
def _model_key(path,model_class,revision,subfolder)->str:
return ':'.join([str(path),str(model_class),str(revision),str(subfolder)])
def _make_cache_room(self):
models_in_ram = len(self.models)
while models_in_ram >= self.max_models_cached:
if least_recently_used_key := self.stack.pop(0):
print(f'DEBUG: maximum cache size reached: cache_size={models_in_ram}; unloading model {least_recently_used_key}')
self.model_group.uninstall(self.models[least_recently_used_key])
del self.models[least_recently_used_key]
models_in_ram = len(self.models)
gc.collect()
@property
def current_model(self)->Union[
AutoencoderKL,
CLIPTokenizer,
CLIPFeatureExtractor,
CLIPTextModel,
UNet2DConditionModel,
StableDiffusionSafetyChecker,
StableDiffusionGeneratorPipeline,
]:
'''
Returns current model.
'''
return self.models[self._current_model_key]
@property
def _current_model_key(self)->str:
'''
Returns key of currently loaded model.
'''
return self.stack[-1]
def _load_model_from_storage(
self,
repo_id_or_path: Union[str,Path],
subfolder: Path=None,
revision: str=None,
model_class: type=StableDiffusionGeneratorPipeline,
)->Union[
AutoencoderKL,
CLIPTokenizer,
CLIPFeatureExtractor,
CLIPTextModel,
UNet2DConditionModel,
StableDiffusionSafetyChecker,
StableDiffusionGeneratorPipeline,
]:
'''
Load and return a HuggingFace model.
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
:param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
:param revision: model revision
:param model_class: class of model to return
'''
# silence transformer and diffuser warnings
with SilenceWarnings():
model = model_class.from_pretrained(
repo_id_or_path,
revision=revision,
subfolder=subfolder or '.',
cache_dir=global_cache_dir('hub'),
)
if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
model.enable_offload_submodels(self.device)
elif hasattr(model,'to'):
model.to(self.device)
return model
class SilenceWarnings(object):
def __init__(self):
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
def __enter__(self):
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter('ignore')
def __exit__(self,type,value,traceback):
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter('default')

View File

@ -1,4 +1,4 @@
"""enum
"""
Manage a cache of Stable Diffusion model files for fast switching.
They are moved between GPU and CPU as necessary. If CPU memory falls
below a preset minimum, the least recently used model will be
@ -1108,11 +1108,8 @@ class ModelManager(object):
>> invokeai/models/diffusers to invokeai/models/hub due to a change introduced by
>> diffusers version 0.14. InvokeAI will now move all models from the "diffusers" directory
>> into "hub" and then remove the diffusers directory. This is a quick, safe, one-time
>> operation. However if you have customized either of these directories and need to
>> make adjustments, please press ctrl-C now to abort and relaunch InvokeAI when you are ready.
>> Otherwise press <enter> to continue."""
>> operation."""
)
input("continue> ")
# transformer files get moved into the hub directory
if cls._is_huggingface_hub_directory_present():

View File

@ -157,7 +157,7 @@ class LazilyLoadedModelGroup(ModelGroup):
def offload_current(self):
module = self._current_model_ref()
if module is not NO_MODEL:
module.to(device=OFFLOAD_DEVICE)
module.to(OFFLOAD_DEVICE)
self.clear_current_model()
def _load(self, module: torch.nn.Module) -> torch.nn.Module:
@ -228,7 +228,7 @@ class FullyLoadedModelGroup(ModelGroup):
def install(self, *models: torch.nn.Module):
for model in models:
self._models.add(model)
model.to(device=self.execution_device)
model.to(self.execution_device)
def uninstall(self, *models: torch.nn.Module):
for model in models:
@ -238,11 +238,11 @@ class FullyLoadedModelGroup(ModelGroup):
self.uninstall(*self._models)
def load(self, model):
model.to(device=self.execution_device)
model.to(self.execution_device)
def offload_current(self):
for model in self._models:
model.to(device=OFFLOAD_DEVICE)
model.to(OFFLOAD_DEVICE)
def ready(self):
for model in self._models:
@ -252,7 +252,7 @@ class FullyLoadedModelGroup(ModelGroup):
self.execution_device = device
for model in self._models:
if model.device != OFFLOAD_DEVICE:
model.to(device=device)
model.to(device)
def device_for(self, model):
if model not in self:

View File

@ -61,6 +61,7 @@ dependencies = [
"picklescan",
"pillow",
"prompt-toolkit",
"pympler==1.0.1",
"pypatchmatch",
"pyreadline3",
"python-multipart==0.0.6",