caching of subparts working

Lincoln Stein 2023-05-01 22:57:30 -04:00
parent 956ad6bcf5
commit 2e2abf6ea6
2 changed files with 95 additions and 34 deletions

invokeai/backend/model_management/model_cache.py
@@ -18,6 +18,7 @@ import torch
 import transformers
 import warnings
+from enum import Enum
 from pathlib import Path
 from diffusers import (
     AutoencoderKL,
@@ -32,8 +33,6 @@ from transformers import(
     logging as transformers_logging,
 )
 from huggingface_hub import scan_cache_dir
-from omegaconf import OmegaConf
-from omegaconf.dictconfig import DictConfig
 from picklescan.scanner import scan_file_path
 from typing import Sequence, Union
@@ -48,7 +47,21 @@ from ..stable_diffusion.offloading import ModelGroup, FullyLoadedModelGroup
 from ..util import CUDA_DEVICE, ask_user, download_with_resume
 
 MAX_MODELS_CACHED = 4
 
+# This is the mapping from the stable diffusion submodel dict key to the class
+class SDModelType(Enum):
+    diffusion_pipeline=StableDiffusionGeneratorPipeline   # whole thing
+    vae=AutoencoderKL                                     # parts
+    text_encoder=CLIPTextModel
+    tokenizer=CLIPTokenizer
+    unet=UNet2DConditionModel
+    scheduler=SchedulerMixin
+    safety_checker=StableDiffusionSafetyChecker
+    feature_extractor=CLIPFeatureExtractor
+
+# List the model classes we know how to fetch
+ModelClass = Union[tuple([x.value for x in SDModelType])]
+
 class ModelCache(object):
     def __init__(
         self,
@@ -65,21 +78,27 @@ class ModelCache(object):
         self.max_models_cached: int=max_models_cached
         self.device: torch.device=execution_device
 
+    def get_submodel(
+        self,
+        repo_id_or_path: Union[str,Path],
+        submodel: SDModelType=SDModelType.vae,
+        subfolder: Path=None,
+        revision: str=None,
+    )->ModelClass:
+        parent_model = self.get_model(
+            repo_id_or_path=repo_id_or_path,
+            subfolder=subfolder,
+            revision=revision,
+        )
+        return getattr(parent_model, submodel.name)
+
     def get_model(
         self,
         repo_id_or_path: Union[str,Path],
-        model_class: type=StableDiffusionGeneratorPipeline,
+        model_type: SDModelType=SDModelType.diffusion_pipeline,
         subfolder: Path=None,
         revision: str=None,
-    )->Union[
-        AutoencoderKL,
-        CLIPTokenizer,
-        CLIPFeatureExtractor,
-        CLIPTextModel,
-        UNet2DConditionModel,
-        StableDiffusionSafetyChecker,
-        StableDiffusionGeneratorPipeline,
-    ]:
+    )->ModelClass:
         '''
         Load and return a HuggingFace model, with RAM caching.
 
         :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
@@ -87,7 +106,12 @@ class ModelCache(object):
         :param revision: model revision
         :param model_class: class of model to return
         '''
-        key = self._model_key(repo_id_or_path,model_class,revision,subfolder) # internal unique identifier for the model
+        key = self._model_key(
+            repo_id_or_path,
+            model_type.value,
+            revision,
+            subfolder
+        ) # internal unique identifier for the model
         if key in self.models: # cached - move to bottom of stack
             previous_key = self._current_model_key
             with contextlib.suppress(ValueError):
@@ -105,9 +129,9 @@
             print(f'DEBUG: loading {key} from disk/net')
             model = self._load_model_from_storage(
                 repo_id_or_path=repo_id_or_path,
+                model_class=model_type.value,
                 subfolder=subfolder,
                 revision=revision,
-                model_class=model_class
             )
             if hasattr(model,'to'):
                 self.model_group.install(model) # register with the model group
@@ -117,7 +141,7 @@
 
     @staticmethod
     def _model_key(path,model_class,revision,subfolder)->str:
-        return ':'.join([str(path),str(model_class),str(revision),str(subfolder)])
+        return ':'.join([str(path),model_class.__name__,str(revision or ''),str(subfolder or '')])
 
     def _make_cache_room(self):
         models_in_ram = len(self.models)
@@ -130,15 +154,7 @@
         gc.collect()
 
     @property
-    def current_model(self)->Union[
-        AutoencoderKL,
-        CLIPTokenizer,
-        CLIPFeatureExtractor,
-        CLIPTextModel,
-        UNet2DConditionModel,
-        StableDiffusionSafetyChecker,
-        StableDiffusionGeneratorPipeline,
-    ]:
+    def current_model(self)->ModelClass:
         '''
         Returns current model.
         '''
@@ -156,16 +172,8 @@ class ModelCache(object):
         repo_id_or_path: Union[str,Path],
         subfolder: Path=None,
         revision: str=None,
-        model_class: type=StableDiffusionGeneratorPipeline,
-    )->Union[
-        AutoencoderKL,
-        CLIPTokenizer,
-        CLIPFeatureExtractor,
-        CLIPTextModel,
-        UNet2DConditionModel,
-        StableDiffusionSafetyChecker,
-        StableDiffusionGeneratorPipeline,
-    ]:
+        model_class: ModelClass=StableDiffusionGeneratorPipeline,
+    )->ModelClass:
         '''
         Load and return a HuggingFace model.
 
         :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model

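A note on the ModelClass alias introduced above: typing.Union accepts a tuple of types when subscripted, so Union[tuple([x.value for x in SDModelType])] expands the enum's values into an ordinary union type. A minimal sketch of the same idiom, using stand-in types that are not part of the commit:

    from enum import Enum
    from typing import Union

    class Kind(Enum):   # stand-in enum for illustration only
        a = int
        b = str

    # Subscripting Union with a tuple is equivalent to listing the members:
    KindClass = Union[tuple(k.value for k in Kind)]
    assert KindClass == Union[int, str]

This keeps the alias in sync with SDModelType automatically: adding a member to the enum widens ModelClass without touching any annotation.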
tests/test_model_cache.py (new file, 53 additions)

@@ -0,0 +1,53 @@
+import pytest
+import torch
+
+from invokeai.backend.model_management.model_cache import ModelCache, SDModelType
+from invokeai.backend.stable_diffusion import StableDiffusionGeneratorPipeline
+
+from diffusers import (
+    AutoencoderKL,
+    UNet2DConditionModel,
+    SchedulerMixin,
+)
+from transformers import (
+    CLIPTokenizer,
+    CLIPFeatureExtractor,
+    CLIPTextModel,
+)
+
+cache = ModelCache()
+
+def test_pipeline_fetch():
+    model0 = cache.get_model('stabilityai/sd-vae-ft-mse',SDModelType.vae)
+    model1 = cache.get_model('stabilityai/stable-diffusion-2-1',SDModelType.diffusion_pipeline)
+    model1_2 = cache.get_model('stabilityai/stable-diffusion-2-1')
+    assert model1==model1_2
+    assert model1.device==torch.device('cuda')
+    model2 = cache.get_model('runwayml/stable-diffusion-v1-5')
+    assert model2.device==torch.device('cuda')
+    assert model1.device==torch.device('cpu')
+    model1 = cache.get_model('stabilityai/stable-diffusion-2-1')
+    assert model1.device==torch.device('cuda')
+
+def test_submodel_fetch():
+    model1_vae = cache.get_submodel('stabilityai/stable-diffusion-2-1',SDModelType.vae)
+    assert isinstance(model1_vae,AutoencoderKL)
+    model1 = cache.get_model('stabilityai/stable-diffusion-2-1',SDModelType.diffusion_pipeline)
+    assert model1_vae == model1.vae
+    model1_vae_2 = cache.get_submodel('stabilityai/stable-diffusion-2-1')
+    assert model1_vae == model1_vae_2
+
+def test_transformer_fetch():
+    model4 = cache.get_model('openai/clip-vit-large-patch14',SDModelType.tokenizer)
+    assert isinstance(model4,CLIPTokenizer)
+    model5 = cache.get_model('openai/clip-vit-large-patch14',SDModelType.text_encoder)
+    assert isinstance(model5,CLIPTextModel)
+
+def test_subfolder_fetch():
+    model6 = cache.get_model('stabilityai/stable-diffusion-2',SDModelType.tokenizer,subfolder="tokenizer")
+    assert isinstance(model6,CLIPTokenizer)
+    model7 = cache.get_model('stabilityai/stable-diffusion-2',SDModelType.text_encoder,subfolder="text_encoder")
+    assert isinstance(model7,CLIPTextModel)
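For reference, the _model_key change in the first file replaces str(model_class), which embeds the noisy "<class '...'>" repr in the cache key, with model_class.__name__, and normalizes a None revision or subfolder to the empty string. A hedged sketch of the resulting key format, reconstructed from the join expression in the diff rather than printed by the commit itself:

    # Stand-in for the real pipeline class, for illustration only.
    class StableDiffusionGeneratorPipeline: ...

    def model_key(path, model_class, revision, subfolder) -> str:
        # Same expression as ModelCache._model_key after this commit.
        return ':'.join([str(path), model_class.__name__,
                         str(revision or ''), str(subfolder or '')])

    print(model_key('runwayml/stable-diffusion-v1-5',
                    StableDiffusionGeneratorPipeline, None, None))
    # -> runwayml/stable-diffusion-v1-5:StableDiffusionGeneratorPipeline::

Note that the tests above assume network access to the Hugging Face Hub and a CUDA device: the device assertions check that the most recently requested pipeline sits on the GPU while an older one has been offloaded to the CPU by the model group.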