mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
caching of subparts working
This commit is contained in:
parent
956ad6bcf5
commit
2e2abf6ea6
@ -18,6 +18,7 @@ import torch
|
|||||||
import transformers
|
import transformers
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
@ -32,8 +33,6 @@ from transformers import(
|
|||||||
logging as transformers_logging,
|
logging as transformers_logging,
|
||||||
)
|
)
|
||||||
from huggingface_hub import scan_cache_dir
|
from huggingface_hub import scan_cache_dir
|
||||||
from omegaconf import OmegaConf
|
|
||||||
from omegaconf.dictconfig import DictConfig
|
|
||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
@ -48,7 +47,21 @@ from ..stable_diffusion.offloading import ModelGroup, FullyLoadedModelGroup
|
|||||||
from ..util import CUDA_DEVICE, ask_user, download_with_resume
|
from ..util import CUDA_DEVICE, ask_user, download_with_resume
|
||||||
|
|
||||||
MAX_MODELS_CACHED = 4
|
MAX_MODELS_CACHED = 4
|
||||||
|
|
||||||
|
# This is the mapping from the stable diffusion submodel dict key to the class
|
||||||
|
class SDModelType(Enum):
|
||||||
|
diffusion_pipeline=StableDiffusionGeneratorPipeline # whole thing
|
||||||
|
vae=AutoencoderKL # parts
|
||||||
|
text_encoder=CLIPTextModel
|
||||||
|
tokenizer=CLIPTokenizer
|
||||||
|
unet=UNet2DConditionModel
|
||||||
|
scheduler=SchedulerMixin
|
||||||
|
safety_checker=StableDiffusionSafetyChecker
|
||||||
|
feature_extractor=CLIPFeatureExtractor
|
||||||
|
|
||||||
|
# List the model classes we know how to fetch
|
||||||
|
ModelClass = Union[tuple([x.value for x in SDModelType])]
|
||||||
|
|
||||||
class ModelCache(object):
|
class ModelCache(object):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -65,21 +78,27 @@ class ModelCache(object):
|
|||||||
self.max_models_cached: int=max_models_cached
|
self.max_models_cached: int=max_models_cached
|
||||||
self.device: torch.device=execution_device
|
self.device: torch.device=execution_device
|
||||||
|
|
||||||
|
def get_submodel(
|
||||||
|
self,
|
||||||
|
repo_id_or_path: Union[str,Path],
|
||||||
|
submodel: SDModelType=SDModelType.vae,
|
||||||
|
subfolder: Path=None,
|
||||||
|
revision: str=None,
|
||||||
|
)->ModelClass:
|
||||||
|
parent_model = self.get_model(
|
||||||
|
repo_id_or_path=repo_id_or_path,
|
||||||
|
subfolder=subfolder,
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
return getattr(parent_model, submodel.name)
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
self,
|
self,
|
||||||
repo_id_or_path: Union[str,Path],
|
repo_id_or_path: Union[str,Path],
|
||||||
model_class: type=StableDiffusionGeneratorPipeline,
|
model_type: SDModelType=SDModelType.diffusion_pipeline,
|
||||||
subfolder: Path=None,
|
subfolder: Path=None,
|
||||||
revision: str=None,
|
revision: str=None,
|
||||||
)->Union[
|
)->ModelClass:
|
||||||
AutoencoderKL,
|
|
||||||
CLIPTokenizer,
|
|
||||||
CLIPFeatureExtractor,
|
|
||||||
CLIPTextModel,
|
|
||||||
UNet2DConditionModel,
|
|
||||||
StableDiffusionSafetyChecker,
|
|
||||||
StableDiffusionGeneratorPipeline,
|
|
||||||
]:
|
|
||||||
'''
|
'''
|
||||||
Load and return a HuggingFace model, with RAM caching.
|
Load and return a HuggingFace model, with RAM caching.
|
||||||
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
|
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
|
||||||
@ -87,7 +106,12 @@ class ModelCache(object):
|
|||||||
:param revision: model revision
|
:param revision: model revision
|
||||||
:param model_class: class of model to return
|
:param model_class: class of model to return
|
||||||
'''
|
'''
|
||||||
key = self._model_key(repo_id_or_path,model_class,revision,subfolder) # internal unique identifier for the model
|
key = self._model_key(
|
||||||
|
repo_id_or_path,
|
||||||
|
model_type.value,
|
||||||
|
revision,
|
||||||
|
subfolder
|
||||||
|
) # internal unique identifier for the model
|
||||||
if key in self.models: # cached - move to bottom of stack
|
if key in self.models: # cached - move to bottom of stack
|
||||||
previous_key = self._current_model_key
|
previous_key = self._current_model_key
|
||||||
with contextlib.suppress(ValueError):
|
with contextlib.suppress(ValueError):
|
||||||
@ -105,9 +129,9 @@ class ModelCache(object):
|
|||||||
print(f'DEBUG: loading {key} from disk/net')
|
print(f'DEBUG: loading {key} from disk/net')
|
||||||
model = self._load_model_from_storage(
|
model = self._load_model_from_storage(
|
||||||
repo_id_or_path=repo_id_or_path,
|
repo_id_or_path=repo_id_or_path,
|
||||||
|
model_class=model_type.value,
|
||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
model_class=model_class
|
|
||||||
)
|
)
|
||||||
if hasattr(model,'to'):
|
if hasattr(model,'to'):
|
||||||
self.model_group.install(model) # register with the model group
|
self.model_group.install(model) # register with the model group
|
||||||
@ -117,7 +141,7 @@ class ModelCache(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _model_key(path,model_class,revision,subfolder)->str:
|
def _model_key(path,model_class,revision,subfolder)->str:
|
||||||
return ':'.join([str(path),str(model_class),str(revision),str(subfolder)])
|
return ':'.join([str(path),model_class.__name__,str(revision or ''),str(subfolder or '')])
|
||||||
|
|
||||||
def _make_cache_room(self):
|
def _make_cache_room(self):
|
||||||
models_in_ram = len(self.models)
|
models_in_ram = len(self.models)
|
||||||
@ -130,15 +154,7 @@ class ModelCache(object):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_model(self)->Union[
|
def current_model(self)->ModelClass:
|
||||||
AutoencoderKL,
|
|
||||||
CLIPTokenizer,
|
|
||||||
CLIPFeatureExtractor,
|
|
||||||
CLIPTextModel,
|
|
||||||
UNet2DConditionModel,
|
|
||||||
StableDiffusionSafetyChecker,
|
|
||||||
StableDiffusionGeneratorPipeline,
|
|
||||||
]:
|
|
||||||
'''
|
'''
|
||||||
Returns current model.
|
Returns current model.
|
||||||
'''
|
'''
|
||||||
@ -156,16 +172,8 @@ class ModelCache(object):
|
|||||||
repo_id_or_path: Union[str,Path],
|
repo_id_or_path: Union[str,Path],
|
||||||
subfolder: Path=None,
|
subfolder: Path=None,
|
||||||
revision: str=None,
|
revision: str=None,
|
||||||
model_class: type=StableDiffusionGeneratorPipeline,
|
model_class: ModelClass=StableDiffusionGeneratorPipeline,
|
||||||
)->Union[
|
)->ModelClass:
|
||||||
AutoencoderKL,
|
|
||||||
CLIPTokenizer,
|
|
||||||
CLIPFeatureExtractor,
|
|
||||||
CLIPTextModel,
|
|
||||||
UNet2DConditionModel,
|
|
||||||
StableDiffusionSafetyChecker,
|
|
||||||
StableDiffusionGeneratorPipeline,
|
|
||||||
]:
|
|
||||||
'''
|
'''
|
||||||
Load and return a HuggingFace model.
|
Load and return a HuggingFace model.
|
||||||
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
|
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
|
||||||
|
53
tests/test_model_cache.py
Normal file
53
tests/test_model_cache.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from invokeai.backend.model_management.model_cache import ModelCache, SDModelType
|
||||||
|
from invokeai.backend.stable_diffusion import StableDiffusionGeneratorPipeline
|
||||||
|
|
||||||
|
from diffusers import (
|
||||||
|
AutoencoderKL,
|
||||||
|
UNet2DConditionModel,
|
||||||
|
SchedulerMixin,
|
||||||
|
)
|
||||||
|
from transformers import (
|
||||||
|
CLIPTokenizer,
|
||||||
|
CLIPFeatureExtractor,
|
||||||
|
CLIPTextModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
cache = ModelCache()
|
||||||
|
|
||||||
|
def test_pipeline_fetch():
|
||||||
|
model0 = cache.get_model('stabilityai/sd-vae-ft-mse',SDModelType.vae)
|
||||||
|
model1 = cache.get_model('stabilityai/stable-diffusion-2-1',SDModelType.diffusion_pipeline)
|
||||||
|
model1_2 = cache.get_model('stabilityai/stable-diffusion-2-1')
|
||||||
|
assert model1==model1_2
|
||||||
|
assert model1.device==torch.device('cuda')
|
||||||
|
model2 = cache.get_model('runwayml/stable-diffusion-v1-5')
|
||||||
|
assert model2.device==torch.device('cuda')
|
||||||
|
assert model1.device==torch.device('cpu')
|
||||||
|
model1 = cache.get_model('stabilityai/stable-diffusion-2-1')
|
||||||
|
assert model1.device==torch.device('cuda')
|
||||||
|
|
||||||
|
def test_submodel_fetch():
|
||||||
|
model1_vae = cache.get_submodel('stabilityai/stable-diffusion-2-1',SDModelType.vae)
|
||||||
|
assert isinstance(model1_vae,AutoencoderKL)
|
||||||
|
model1 = cache.get_model('stabilityai/stable-diffusion-2-1',SDModelType.diffusion_pipeline)
|
||||||
|
assert model1_vae == model1.vae
|
||||||
|
model1_vae_2 = cache.get_submodel('stabilityai/stable-diffusion-2-1')
|
||||||
|
assert model1_vae == model1_vae_2
|
||||||
|
|
||||||
|
def test_transformer_fetch():
|
||||||
|
model4 = cache.get_model('openai/clip-vit-large-patch14',SDModelType.tokenizer)
|
||||||
|
assert isinstance(model4,CLIPTokenizer)
|
||||||
|
|
||||||
|
model5 = cache.get_model('openai/clip-vit-large-patch14',SDModelType.text_encoder)
|
||||||
|
assert isinstance(model5,CLIPTextModel)
|
||||||
|
|
||||||
|
def test_subfolder_fetch():
|
||||||
|
model6 = cache.get_model('stabilityai/stable-diffusion-2',SDModelType.tokenizer,subfolder="tokenizer")
|
||||||
|
assert isinstance(model6,CLIPTokenizer)
|
||||||
|
|
||||||
|
model7 = cache.get_model('stabilityai/stable-diffusion-2',SDModelType.text_encoder,subfolder="text_encoder")
|
||||||
|
assert isinstance(model7,CLIPTextModel)
|
Loading…
Reference in New Issue
Block a user