mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

commit 2e2abf6ea6 (parent 956ad6bcf5)

    caching of subparts working
@@ -18,6 +18,7 @@ import torch
 import transformers
 import warnings
 
+from enum import Enum
 from pathlib import Path
 from diffusers import (
     AutoencoderKL,
@@ -32,8 +33,6 @@ from transformers import(
     logging as transformers_logging,
 )
 from huggingface_hub import scan_cache_dir
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
 from picklescan.scanner import scan_file_path
 from typing import Sequence, Union
 
@@ -48,7 +47,21 @@ from ..stable_diffusion.offloading import ModelGroup, FullyLoadedModelGroup
 from ..util import CUDA_DEVICE, ask_user, download_with_resume
 
 MAX_MODELS_CACHED = 4
 
+# This is the mapping from the stable diffusion submodel dict key to the class
+class SDModelType(Enum):
+    diffusion_pipeline=StableDiffusionGeneratorPipeline   # whole thing
+    vae=AutoencoderKL                                     # parts
+    text_encoder=CLIPTextModel
+    tokenizer=CLIPTokenizer
+    unet=UNet2DConditionModel
+    scheduler=SchedulerMixin
+    safety_checker=StableDiffusionSafetyChecker
+    feature_extractor=CLIPFeatureExtractor
+
+# List the model classes we know how to fetch
+ModelClass = Union[tuple([x.value for x in SDModelType])]
+
 class ModelCache(object):
     def __init__(
         self,
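The ModelClass alias above is built dynamically: subscripting typing.Union with a tuple of classes expands to Union[A, B, ...], so the alias automatically tracks whatever the enum lists. A minimal self-contained demonstration of the trick (the Toy enum is illustrative only, not part of the commit):

    from enum import Enum
    from typing import Union

    class Toy(Enum):          # stand-in for SDModelType
        a = int
        b = str

    ToyClass = Union[tuple([x.value for x in Toy])]
    print(ToyClass)           # typing.Union[int, str]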
@@ -65,21 +78,27 @@ class ModelCache(object):
         self.max_models_cached: int=max_models_cached
         self.device: torch.device=execution_device
 
+    def get_submodel(
+        self,
+        repo_id_or_path: Union[str,Path],
+        submodel: SDModelType=SDModelType.vae,
+        subfolder: Path=None,
+        revision: str=None,
+    )->ModelClass:
+        parent_model = self.get_model(
+            repo_id_or_path=repo_id_or_path,
+            subfolder=subfolder,
+            revision=revision,
+        )
+        return getattr(parent_model, submodel.name)
+
     def get_model(
         self,
         repo_id_or_path: Union[str,Path],
-        model_class: type=StableDiffusionGeneratorPipeline,
+        model_type: SDModelType=SDModelType.diffusion_pipeline,
         subfolder: Path=None,
        revision: str=None,
-    )->Union[
-        AutoencoderKL,
-        CLIPTokenizer,
-        CLIPFeatureExtractor,
-        CLIPTextModel,
-        UNet2DConditionModel,
-        StableDiffusionSafetyChecker,
-        StableDiffusionGeneratorPipeline,
-    ]:
+    )->ModelClass:
        '''
         Load and return a HuggingFace model, with RAM caching.
         :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
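The new get_submodel leans on the fact that each SDModelType member's name matches the attribute name on the loaded diffusers pipeline, so a single getattr resolves the part. A hedged sketch of that lookup (PartType and DummyPipeline are invented stand-ins):

    from enum import Enum

    class PartType(Enum):         # stand-in for SDModelType
        vae = 'AutoencoderKL'
        tokenizer = 'CLIPTokenizer'

    class DummyPipeline:          # stand-in for a loaded pipeline
        def __init__(self):
            self.vae = 'the-vae'
            self.tokenizer = 'the-tokenizer'

    pipe = DummyPipeline()
    assert getattr(pipe, PartType.vae.name) == 'the-vae'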
@@ -87,7 +106,12 @@ class ModelCache(object):
         :param revision: model revision
         :param model_class: class of model to return
         '''
-        key = self._model_key(repo_id_or_path,model_class,revision,subfolder) # internal unique identifier for the model
+        key = self._model_key(
+            repo_id_or_path,
+            model_type.value,
+            revision,
+            subfolder
+        ) # internal unique identifier for the model
         if key in self.models: # cached - move to bottom of stack
             previous_key = self._current_model_key
             with contextlib.suppress(ValueError):
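The cached branch above is cut off mid-hunk, but the contextlib.suppress(ValueError) it opens is the usual idiom for a tolerant list.remove when reordering a most-recently-used stack. A sketch of that idiom under that assumption (the stack variable itself is not shown in the diff):

    import contextlib

    stack = ['model-a', 'model-b', 'model-c']   # oldest first
    key = 'model-b'
    with contextlib.suppress(ValueError):       # ignore "x not in list"
        stack.remove(key)
    stack.append(key)                           # now the most recent entry
    print(stack)                                # ['model-a', 'model-c', 'model-b']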
@@ -105,9 +129,9 @@ class ModelCache(object):
             print(f'DEBUG: loading {key} from disk/net')
             model = self._load_model_from_storage(
                 repo_id_or_path=repo_id_or_path,
+                model_class=model_type.value,
                 subfolder=subfolder,
                 revision=revision,
-                model_class=model_class
             )
         if hasattr(model,'to'):
             self.model_group.install(model) # register with the model group
@@ -117,7 +141,7 @@ class ModelCache(object):
 
     @staticmethod
     def _model_key(path,model_class,revision,subfolder)->str:
-        return ':'.join([str(path),str(model_class),str(revision),str(subfolder)])
+        return ':'.join([str(path),model_class.__name__,str(revision or ''),str(subfolder or '')])
 
     def _make_cache_room(self):
         models_in_ram = len(self.models)
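The revised _model_key uses the class's __name__ instead of its full repr and normalizes a None revision or subfolder to the empty string, so keys stay short and a defaulted argument hashes the same as an explicit None. Worked example (AutoencoderKL is a stand-in class here):

    def _model_key(path, model_class, revision, subfolder) -> str:
        return ':'.join([str(path), model_class.__name__, str(revision or ''), str(subfolder or '')])

    class AutoencoderKL:   # stand-in for the diffusers class
        pass

    print(_model_key('stabilityai/sd-vae-ft-mse', AutoencoderKL, None, None))
    # stabilityai/sd-vae-ft-mse:AutoencoderKL::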
@@ -130,15 +154,7 @@ class ModelCache(object):
             gc.collect()
 
     @property
-    def current_model(self)->Union[
-        AutoencoderKL,
-        CLIPTokenizer,
-        CLIPFeatureExtractor,
-        CLIPTextModel,
-        UNet2DConditionModel,
-        StableDiffusionSafetyChecker,
-        StableDiffusionGeneratorPipeline,
-    ]:
+    def current_model(self)->ModelClass:
         '''
         Returns current model.
         '''
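Only the tail of _make_cache_room is visible in these hunks (the length check above and the gc.collect() here), so the following eviction loop is an assumed reconstruction rather than the commit's code: drop the oldest cached entries until the count is back under the cap, then collect garbage so the freed weights are actually released.

    import gc

    def make_cache_room(models: dict, max_models_cached: int = 4) -> None:
        # assumed sketch: dicts preserve insertion order, so the first
        # key is the oldest entry
        while len(models) >= max_models_cached:
            del models[next(iter(models))]
        gc.collect()   # encourage prompt release of large tensors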
@@ -156,16 +172,8 @@ class ModelCache(object):
         repo_id_or_path: Union[str,Path],
         subfolder: Path=None,
         revision: str=None,
-        model_class: type=StableDiffusionGeneratorPipeline,
-    )->Union[
-        AutoencoderKL,
-        CLIPTokenizer,
-        CLIPFeatureExtractor,
-        CLIPTextModel,
-        UNet2DConditionModel,
-        StableDiffusionSafetyChecker,
-        StableDiffusionGeneratorPipeline,
-    ]:
+        model_class: ModelClass=StableDiffusionGeneratorPipeline,
+    )->ModelClass:
         '''
         Load and return a HuggingFace model.
         :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
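Putting the hunks together, get_model now behaves like a small keyed MRU cache: on a hit the model moves to the most-recent slot, on a miss room is made and the model is loaded and registered with the ModelGroup for device placement. A simplified composite sketch (TinyModelCache is illustrative and omits the GPU-offload bookkeeping entirely):

    class TinyModelCache:
        def __init__(self, max_models: int = 4):
            self.models = {}              # key -> model, insertion-ordered
            self.max_models = max_models

        def get(self, key, loader):
            if key in self.models:                            # cache hit
                model = self.models.pop(key)
            else:                                             # cache miss
                while len(self.models) >= self.max_models:
                    self.models.pop(next(iter(self.models)))  # evict oldest
                model = loader()                              # disk/net load
            self.models[key] = model                          # most recent
            return model

    cache = TinyModelCache()
    vae = cache.get('stabilityai/sd-vae-ft-mse:AutoencoderKL::', lambda: object())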
tests/test_model_cache.py (new file, 53 lines)
@@ -0,0 +1,53 @@
+import pytest
+import torch
+
+from invokeai.backend.model_management.model_cache import ModelCache, SDModelType
+from invokeai.backend.stable_diffusion import StableDiffusionGeneratorPipeline
+
+from diffusers import (
+    AutoencoderKL,
+    UNet2DConditionModel,
+    SchedulerMixin,
+)
+from transformers import (
+    CLIPTokenizer,
+    CLIPFeatureExtractor,
+    CLIPTextModel,
+)
+
+
+cache = ModelCache()
+
+def test_pipeline_fetch():
+    model0 = cache.get_model('stabilityai/sd-vae-ft-mse',SDModelType.vae)
+    model1 = cache.get_model('stabilityai/stable-diffusion-2-1',SDModelType.diffusion_pipeline)
+    model1_2 = cache.get_model('stabilityai/stable-diffusion-2-1')
+    assert model1==model1_2
+    assert model1.device==torch.device('cuda')
+    model2 = cache.get_model('runwayml/stable-diffusion-v1-5')
+    assert model2.device==torch.device('cuda')
+    assert model1.device==torch.device('cpu')
+    model1 = cache.get_model('stabilityai/stable-diffusion-2-1')
+    assert model1.device==torch.device('cuda')
+
+def test_submodel_fetch():
+    model1_vae = cache.get_submodel('stabilityai/stable-diffusion-2-1',SDModelType.vae)
+    assert isinstance(model1_vae,AutoencoderKL)
+    model1 = cache.get_model('stabilityai/stable-diffusion-2-1',SDModelType.diffusion_pipeline)
+    assert model1_vae == model1.vae
+    model1_vae_2 = cache.get_submodel('stabilityai/stable-diffusion-2-1')
+    assert model1_vae == model1_vae_2
+
+def test_transformer_fetch():
+    model4 = cache.get_model('openai/clip-vit-large-patch14',SDModelType.tokenizer)
+    assert isinstance(model4,CLIPTokenizer)
+
+    model5 = cache.get_model('openai/clip-vit-large-patch14',SDModelType.text_encoder)
+    assert isinstance(model5,CLIPTextModel)
+
+def test_subfolder_fetch():
+    model6 = cache.get_model('stabilityai/stable-diffusion-2',SDModelType.tokenizer,subfolder="tokenizer")
+    assert isinstance(model6,CLIPTokenizer)
+
+    model7 = cache.get_model('stabilityai/stable-diffusion-2',SDModelType.text_encoder,subfolder="text_encoder")
+    assert isinstance(model7,CLIPTextModel)
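These tests download real weights from the HuggingFace Hub and assert CUDA placement, so they assume network access and an available GPU. One way to run just this module (a plain pytest invocation, not part of the commit):

    import pytest
    pytest.main(['-q', 'tests/test_model_cache.py'])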