caching of subparts working

Lincoln Stein
2023-05-01 22:57:30 -04:00
parent 956ad6bcf5
commit 2e2abf6ea6
2 changed files with 95 additions and 34 deletions
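A minimal usage sketch of the new submodel caching, assuming the ModelCache constructor accepts max_models_cached and execution_device keywords (only the attribute assignments are visible in this diff); the import path and repo id are illustrative:

    import torch
    # assumed import path; adjust to wherever ModelCache actually lives in the tree
    from invokeai.backend.model_management.model_cache import ModelCache, SDModelType

    cache = ModelCache(max_models_cached=4, execution_device=torch.device('cuda'))

    # Ask for just the VAE: the parent pipeline is loaded (or reused from the
    # RAM cache) and the submodel is pulled off it by attribute name.
    vae = cache.get_submodel('stabilityai/stable-diffusion-2-1', submodel=SDModelType.vae)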


@@ -18,6 +18,7 @@ import torch
import transformers
import warnings
from enum import Enum
from pathlib import Path
from diffusers import (
    AutoencoderKL,
@@ -32,8 +33,6 @@ from transformers import(
    logging as transformers_logging,
)
from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from picklescan.scanner import scan_file_path
from typing import Sequence, Union
@@ -48,7 +47,21 @@ from ..stable_diffusion.offloading import ModelGroup, FullyLoadedModelGroup
from ..util import CUDA_DEVICE, ask_user, download_with_resume

MAX_MODELS_CACHED = 4

# This is the mapping from the stable diffusion submodel dict key to the class
class SDModelType(Enum):
    diffusion_pipeline=StableDiffusionGeneratorPipeline # whole thing
    vae=AutoencoderKL # parts
    text_encoder=CLIPTextModel
    tokenizer=CLIPTokenizer
    unet=UNet2DConditionModel
    scheduler=SchedulerMixin
    safety_checker=StableDiffusionSafetyChecker
    feature_extractor=CLIPFeatureExtractor

# List the model classes we know how to fetch
ModelClass = Union[tuple([x.value for x in SDModelType])]

class ModelCache(object):
    def __init__(
        self,
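The SDModelType enum added above does double duty: each member's value is the class to load, and each member's name matches the attribute on StableDiffusionGeneratorPipeline that holds that part, which is what get_submodel's getattr call relies on. A quick illustration (cache as in the sketch above; the repo id is illustrative):

    from diffusers import AutoencoderKL

    pipeline = cache.get_model('runwayml/stable-diffusion-v1-5')  # whole pipeline
    vae = getattr(pipeline, SDModelType.vae.name)                 # same object as pipeline.vae
    assert SDModelType.vae.name == 'vae'
    assert SDModelType.vae.value is AutoencoderKL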
@@ -65,21 +78,27 @@ class ModelCache(object):
        self.max_models_cached: int=max_models_cached
        self.device: torch.device=execution_device

    def get_submodel(
        self,
        repo_id_or_path: Union[str,Path],
        submodel: SDModelType=SDModelType.vae,
        subfolder: Path=None,
        revision: str=None,
    )->ModelClass:
        parent_model = self.get_model(
            repo_id_or_path=repo_id_or_path,
            subfolder=subfolder,
            revision=revision,
        )
        return getattr(parent_model, submodel.name)

    def get_model(
        self,
        repo_id_or_path: Union[str,Path],
        model_class: type=StableDiffusionGeneratorPipeline,
        model_type: SDModelType=SDModelType.diffusion_pipeline,
        subfolder: Path=None,
        revision: str=None,
    )->Union[
        AutoencoderKL,
        CLIPTokenizer,
        CLIPFeatureExtractor,
        CLIPTextModel,
        UNet2DConditionModel,
        StableDiffusionSafetyChecker,
        StableDiffusionGeneratorPipeline,
    ]:
    )->ModelClass:
        '''
        Load and return a HuggingFace model, with RAM caching.
        :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
@@ -87,7 +106,12 @@ class ModelCache(object):
        :param revision: model revision
        :param model_class: class of model to return
        '''
        key = self._model_key(repo_id_or_path,model_class,revision,subfolder) # internal unique identifier for the model
        key = self._model_key(
            repo_id_or_path,
            model_type.value,
            revision,
            subfolder
        ) # internal unique identifier for the model
        if key in self.models: # cached - move to bottom of stack
            previous_key = self._current_model_key
            with contextlib.suppress(ValueError):
@@ -105,9 +129,9 @@
            print(f'DEBUG: loading {key} from disk/net')
            model = self._load_model_from_storage(
                repo_id_or_path=repo_id_or_path,
                model_class=model_type.value,
                subfolder=subfolder,
                revision=revision,
                model_class=model_class
            )
            if hasattr(model,'to'):
                self.model_group.install(model) # register with the model group
@@ -117,7 +141,7 @@
    @staticmethod
    def _model_key(path,model_class,revision,subfolder)->str:
        return ':'.join([str(path),str(model_class),str(revision),str(subfolder)])
        return ':'.join([str(path),model_class.__name__,str(revision or ''),str(subfolder or '')])

    def _make_cache_room(self):
        models_in_ram = len(self.models)
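With the reworked _model_key above, the cache key embeds the class name and collapses a None revision or subfolder to empty strings, so the same model requested the same way always maps to the same key. For example (repo id illustrative):

    from diffusers import AutoencoderKL

    key = ModelCache._model_key('stabilityai/sd-vae-ft-mse', AutoencoderKL, None, None)
    print(key)  # stabilityai/sd-vae-ft-mse:AutoencoderKL::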
@@ -130,15 +154,7 @@
        gc.collect()

    @property
    def current_model(self)->Union[
        AutoencoderKL,
        CLIPTokenizer,
        CLIPFeatureExtractor,
        CLIPTextModel,
        UNet2DConditionModel,
        StableDiffusionSafetyChecker,
        StableDiffusionGeneratorPipeline,
    ]:
    def current_model(self)->ModelClass:
        '''
        Returns current model.
        '''
@@ -156,16 +172,8 @@
        repo_id_or_path: Union[str,Path],
        subfolder: Path=None,
        revision: str=None,
        model_class: type=StableDiffusionGeneratorPipeline,
    )->Union[
        AutoencoderKL,
        CLIPTokenizer,
        CLIPFeatureExtractor,
        CLIPTextModel,
        UNet2DConditionModel,
        StableDiffusionSafetyChecker,
        StableDiffusionGeneratorPipeline,
    ]:
        model_class: ModelClass=StableDiffusionGeneratorPipeline,
    )->ModelClass:
        '''
        Load and return a HuggingFace model.
        :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model