Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

work on model cache and its regression test finished

This commit is contained in: parent bb959448c1, commit e1fed52c66
@@ -1,55 +1,46 @@
"""
Manage a cache of Stable Diffusion model files for fast switching.
They are moved between GPU and CPU as necessary. If the cache
Manage a RAM cache of diffusion/transformer models for fast switching.
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.

The cache returns context manager generators designed to load the
model into the GPU within the context, and unload outside the
context. Use like this:

    cache = ModelCache(max_models_cached=6)
    with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
         cache.get_model('stabilityai/stable-diffusion-2') as SD2:
        do_something_in_GPU(SD1,SD2)


"""

import contextlib
import hashlib
import gc
import time
import os
import psutil

import safetensors
import safetensors.torch
import torch
import transformers
import hashlib
import warnings

from collections.abc import Generator
from enum import Enum
from pathlib import Path
from pydantic import BaseModel
from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
    SchedulerMixin,
    logging as diffusers_logging,
)
from huggingface_hub import list_repo_refs,HfApi
from transformers import(
    CLIPTokenizer,
    CLIPFeatureExtractor,
    CLIPTextModel,
    logging as transformers_logging,
)
from huggingface_hub import scan_cache_dir
from picklescan.scanner import scan_file_path
from typing import Sequence, Union

from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)
from . import load_pipeline_from_original_stable_diffusion_ckpt
from ..globals import Globals, global_cache_dir
from ..stable_diffusion import (
    StableDiffusionGeneratorPipeline,
)
from ..stable_diffusion.offloading import ModelGroup, FullyLoadedModelGroup
from ..util import CUDA_DEVICE, ask_user, download_with_resume
import torch
from diffusers import AutoencoderKL, SchedulerMixin, UNet2DConditionModel
from diffusers import logging as diffusers_logging
from diffusers.pipelines.stable_diffusion.safety_checker import \
    StableDiffusionSafetyChecker
from huggingface_hub import HfApi
from picklescan.scanner import scan_file_path
from pydantic import BaseModel
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import logging as transformers_logging

MAX_MODELS_CACHED = 4
from ..globals import global_cache_dir
from ..stable_diffusion import StableDiffusionGeneratorPipeline
from . import load_pipeline_from_original_stable_diffusion_ckpt

MAX_MODELS = 4

# This is the mapping from the stable diffusion submodel dict key to the class
class SDModelType(Enum):
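For orientation, here is a minimal usage sketch consistent with the renamed constructor argument (max_models instead of max_models_cached) and the import path used by the regression test below; the repo id and the do_something_in_GPU() call are illustrative, not part of this diff:

# Sketch only: assumes the new ModelCache(max_models=...) signature introduced in this diff.
from invokeai.backend.model_management.model_cache import ModelCache, SDModelType

cache = ModelCache(max_models=4)
with cache.get_model('stabilityai/stable-diffusion-2') as SD2:
    do_something_in_GPU(SD2)   # hypothetical caller-supplied function
# On exit the model moves back to the storage device (CPU) but stays cached in RAM.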
@@ -68,70 +59,69 @@ ModelClass = Union[tuple([x.value for x in SDModelType])]
# Legacy information needed to load a legacy checkpoint file
class LegacyInfo(BaseModel):
    config_file: Path
    vae_file: Path
    vae_file: Path = None

class UnsafeModelException(Exception):
    "Raised when a legacy model file fails the picklescan test"
    pass

class UnscannableModelException(Exception):
    "Raised when picklescan is unable to scan a legacy model file"
    pass

class ModelCache(object):
    def __init__(
        self,
        max_models_cached: int=MAX_MODELS_CACHED,
        max_models: int=MAX_MODELS,
        execution_device: torch.device=torch.device('cuda'),
        storage_device: torch.device=torch.device('cpu'),
        precision: torch.dtype=torch.float16,
        sequential_offload: bool=False,
        sha_chunksize: int = 16777216,
    ):
        '''
        :param max_models_cached: Maximum number of models to cache in CPU RAM [4]
        :param max_models: Maximum number of models to cache in CPU RAM [4]
        :param execution_device: Torch device to load active model into [torch.device('cuda')]
        :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
        :param precision: Precision for loaded models [torch.float16]
        :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
        :param sha_chunksize: Chunksize to use when calculating sha256 model hash
        '''
        self.model_group: ModelGroup=FullyLoadedModelGroup(execution_device)
        self.models: dict = dict()
        self.stack: Sequence = list()
        self.sequential_offload: bool=sequential_offload
        self.precision: torch.dtype=precision
        self.max_models_cached: int=max_models_cached
        self.device: torch.device=execution_device
        self.max_models: int=max_models
        self.execution_device: torch.device=execution_device
        self.storage_device: torch.device=storage_device
        self.sha_chunksize=sha_chunksize

    def get_submodel(
        self,
        repo_id_or_path: Union[str,Path],
        submodel: SDModelType=SDModelType.vae,
        subfolder: Path=None,
        revision: str=None,
        legacy_info: LegacyInfo=None,
    )->ModelClass:
        '''
        Load and return a HuggingFace model, with RAM caching.
        :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
        :param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.vae
        :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
        :param revision: model revision name
        :param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
        '''
        parent_model = self.get_model(
            repo_id_or_path=repo_id_or_path,
            subfolder=subfolder,
            revision=revision,
        )
        return getattr(parent_model, submodel.name)

    @contextlib.contextmanager
    def get_model(
        self,
        repo_id_or_path: Union[str,Path],
        model_type: SDModelType=SDModelType.diffusion_pipeline,
        subfolder: Path=None,
        submodel: SDModelType=None,
        revision: str=None,
        legacy_info: LegacyInfo=None,
    )->ModelClass:
        gpu_load: bool=True,
    )->Generator[ModelClass, None, None]:
        '''
        Load and return a HuggingFace model, with RAM caching.
        Load and return a HuggingFace model wrapped in a context manager generator, with RAM caching.
        Use like this:

            cache = ModelCache()
            with cache.get_model('stabilityai/stable-diffusion-2') as SD2:
                do_something_with_the_model(SD2)

        The model will be locked into GPU VRAM for the duration of the context.
        :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
        :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
        :param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.vae
        :param revision: model revision
        :param model_class: class of model to return
        :param gpu_load: load the model into GPU [default True]
        :param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
        '''
        key = self._model_key( # internal unique identifier for the model
@@ -141,19 +131,12 @@ class ModelCache(object):
            subfolder
        )
        if key in self.models: # cached - move to bottom of stack
            previous_key = self._current_model_key
            with contextlib.suppress(ValueError):
                self.stack.remove(key)
            self.stack.append(key)
            if previous_key != key:
                if hasattr(self.current_model,'to'):
                    print(f' | loading {key} into GPU')
                    self.model_group.offload_current()
                    self.model_group.load(self.models[key])
            model = self.models[key]
        else: # not cached -load
            self._make_cache_room()
            self.model_group.offload_current()
            print(f' | loading model {key} from disk/net')
            model = self._load_model_from_storage(
                repo_id_or_path=repo_id_or_path,
                model_class=model_type.value,
@@ -161,14 +144,29 @@ class ModelCache(object):
                revision=revision,
                legacy_info=legacy_info,
            )
            if hasattr(model,'to'):
                self.model_group.install(model) # register with the model group
            self.stack.append(key) # add to LRU cache
            self.models[key]=model # keep copy of model in dict
        return self.models[key]

        if submodel:
            model = getattr(model, submodel.name)
            debugging_name = f'{submodel.name} submodel of {repo_id_or_path}'
        else:
            debugging_name = repo_id_or_path

        try:
            if gpu_load and hasattr(model,'to'):
                print(f' | Loading {debugging_name} into GPU')
                model.to(self.execution_device) # move into GPU
                self._print_cuda_stats()
            yield model
        finally:
            if gpu_load and hasattr(model,'to'):
                print(f' | Unloading {debugging_name} from GPU')
                model.to(self.storage_device)
                self._print_cuda_stats()

    @staticmethod
    def model_hash(repo_id_or_path: Union[str,Path],
    def model_hash(self,
                   repo_id_or_path: Union[str,Path],
                   revision: str=None)->str:
        '''
        Given the HF repo id or path to a model on disk, returns a unique
@ -183,16 +181,28 @@ class ModelCache(object):
|
||||
else:
|
||||
return self._hf_commit_hash(repo_id_or_path,revision)
|
||||
|
||||
def cache_size(self)->int:
|
||||
"Return the current number of models cached."
|
||||
return len(self.models)
|
||||
|
||||
@staticmethod
|
||||
def _model_key(path,model_class,revision,subfolder)->str:
|
||||
return ':'.join([str(path),model_class.__name__,str(revision or ''),str(subfolder or '')])
|
||||
|
||||
def _has_cuda(self)->bool:
|
||||
return self.execution_device.type == 'cuda'
|
||||
|
||||
def _print_cuda_stats(self):
|
||||
print(
|
||||
" | Current VRAM usage:",
|
||||
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
|
||||
)
|
||||
|
||||
def _make_cache_room(self):
|
||||
models_in_ram = len(self.models)
|
||||
while models_in_ram >= self.max_models_cached:
|
||||
while models_in_ram >= self.max_models:
|
||||
if least_recently_used_key := self.stack.pop(0):
|
||||
print(f' | maximum cache size reached: cache_size={models_in_ram}; unloading model {least_recently_used_key}')
|
||||
self.model_group.uninstall(self.models[least_recently_used_key])
|
||||
print(f' | Maximum cache size reached: cache_size={models_in_ram}; unloading model {least_recently_used_key}')
|
||||
del self.models[least_recently_used_key]
|
||||
models_in_ram = len(self.models)
|
||||
gc.collect()
|
||||
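_make_cache_room() implements a plain LRU policy: the front of self.stack holds the least recently used key, and entries are dropped until the count falls below max_models. A self-contained sketch of the same idea (the function and variable names here are illustrative, not the module's API):

import gc

def make_room(models: dict, stack: list, max_models: int) -> None:
    # Evict least-recently-used entries until the cache is under its limit.
    while len(models) >= max_models and stack:
        lru_key = stack.pop(0)   # oldest key sits at the front of the stack
        print(f'unloading model {lru_key}')
        del models[lru_key]      # drop the cached reference so it can be reclaimed
    gc.collect()                 # encourage prompt release of the model's tensors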
@@ -239,9 +249,9 @@ class ModelCache(object):
            model_class,
        )
        if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
            model.enable_offload_submodels(self.device)
            model.enable_offload_submodels(self.execution_device)
        elif hasattr(model,'to'):
            model.to(self.device)
            model.to(self.execution_device)
        return model

    def _load_diffusers_from_storage(
@@ -274,6 +284,23 @@ class ModelCache(object):
        path = Path(repo_id_or_path)
        return path.is_file() and path.suffix in [".ckpt",".safetensors"]

    @classmethod
    def scan_model(cls, model_name, checkpoint):
        """
        Apply picklescanner to the indicated checkpoint and issue a warning
        and option to exit if an infected file is identified.
        """
        # scan model
        print(f" | Scanning Model: {model_name}")
        scan_result = scan_file_path(checkpoint)
        if scan_result.infected_files != 0:
            if scan_result.infected_files == 1:
                raise UnsafeModelException("The legacy model you are trying to load may contain malware. Aborting.")
            else:
                raise UnscannableModelException("InvokeAI was unable to scan the legacy model you requested. Aborting")
        else:
            print(" | Model scanned ok")

    def _load_ckpt_from_storage(self,
                                ckpt_path: Union[str,Path],
                                legacy_info:LegacyInfo)->StableDiffusionGeneratorPipeline:
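scan_model() converts picklescan findings into the two exception classes added earlier in this diff; a hedged sketch of how a caller might handle them (the model name and checkpoint path below are illustrative):

# Sketch only: reacting to the exceptions raised by ModelCache.scan_model().
try:
    ModelCache.scan_model('my-legacy-model', '/path/to/model.ckpt')  # illustrative arguments
except UnsafeModelException as e:
    print(f'Refusing to load: {e}')
except UnscannableModelException as e:
    print(f'Could not verify the checkpoint: {e}')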
@@ -1,53 +1,84 @@
import pytest
import torch

from invokeai.backend.model_management.model_cache import ModelCache, SDModelType
from invokeai.backend.stable_diffusion import StableDiffusionGeneratorPipeline
from enum import Enum
from invokeai.backend.model_management.model_cache import ModelCache

from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
    SchedulerMixin,
)
from transformers import (
    CLIPTokenizer,
    CLIPFeatureExtractor,
    CLIPTextModel,
)
class DummyModelBase(object):
    '''Base class for dummy component of a diffusers model'''
    def __init__(self, repo_id):
        self.repo_id = repo_id
        self.device = torch.device('cpu')

    @classmethod
    def from_pretrained(cls,
                        repo_id:str,
                        revision:str=None,
                        subfolder:str=None,
                        cache_dir:str=None,
                        ):
        return cls(repo_id)

    def to(self, device):
        self.device = device

cache = ModelCache()
class DummyModelType1(DummyModelBase):
    pass

class DummyModelType2(DummyModelBase):
    pass

class DummyPipeline(DummyModelBase):
    '''Dummy pipeline object is a composite of several types'''
    def __init__(self,repo_id):
        super().__init__(repo_id)
        self.type1 = DummyModelType1('dummy/type1')
        self.type2 = DummyModelType2('dummy/type2')

class DMType(Enum):
    dummy_pipeline = DummyPipeline
    type1 = DummyModelType1
    type2 = DummyModelType2

cache = ModelCache(max_models=4)

def test_pipeline_fetch():
    model0 = cache.get_model('stabilityai/sd-vae-ft-mse',SDModelType.vae)
    model1 = cache.get_model('stabilityai/stable-diffusion-2-1',SDModelType.diffusion_pipeline)
    model1_2 = cache.get_model('stabilityai/stable-diffusion-2-1')
    assert model1==model1_2
    assert model1.device==torch.device('cuda')
    model2 = cache.get_model('runwayml/stable-diffusion-v1-5')
    assert model2.device==torch.device('cuda')
    assert model1.device==torch.device('cpu')
    model1 = cache.get_model('stabilityai/stable-diffusion-2-1')
    assert model1.device==torch.device('cuda')
    assert cache.cache_size()==0
    with cache.get_model('dummy/pipeline1',DMType.dummy_pipeline) as pipeline1,\
         cache.get_model('dummy/pipeline1',DMType.dummy_pipeline) as pipeline1a,\
         cache.get_model('dummy/pipeline2',DMType.dummy_pipeline) as pipeline2:
        assert pipeline1 is not None, 'get_model() should not return None'
        assert pipeline1a is not None, 'get_model() should not return None'
        assert pipeline2 is not None, 'get_model() should not return None'
        assert type(pipeline1)==DMType.dummy_pipeline.value,'get_model() did not return model of expected type'
        assert pipeline1==pipeline1a,'pipelines with the same repo_id should be the same'
        assert pipeline1!=pipeline2,'pipelines with different repo_ids should not be the same'
    assert cache.cache_size()==2,'cache should uniquely cache models with same identity'
    with cache.get_model('dummy/pipeline3',DMType.dummy_pipeline) as pipeline3,\
         cache.get_model('dummy/pipeline4',DMType.dummy_pipeline) as pipeline4:
        assert cache.cache_size()==4,'cache did not grow as expected'
    with cache.get_model('dummy/pipeline5',DMType.dummy_pipeline) as pipeline5:
        assert cache.cache_size()==4,'cache did not free space as expected'

def test_signatures():
    with cache.get_model('dummy/pipeline',DMType.dummy_pipeline,revision='main') as pipeline1,\
         cache.get_model('dummy/pipeline',DMType.dummy_pipeline,revision='fp16') as pipeline2,\
         cache.get_model('dummy/pipeline',DMType.dummy_pipeline,revision='main',subfolder='foo') as pipeline3:
        assert pipeline1 != pipeline2,'models are distinguished by their revision'
        assert pipeline1 != pipeline3,'models are distinguished by their subfolder'

def test_pipeline_device():
    with cache.get_model('dummy/pipeline1',DMType.type1) as model1:
        assert model1.device==torch.device('cuda'),'when in context, model device should be in GPU'
    with cache.get_model('dummy/pipeline1',DMType.type1, gpu_load=False) as model1:
        assert model1.device==torch.device('cpu'),'when gpu_load=False, model device should be CPU'

def test_submodel_fetch():
    model1_vae = cache.get_submodel('stabilityai/stable-diffusion-2-1',SDModelType.vae)
    assert isinstance(model1_vae,AutoencoderKL)
    model1 = cache.get_model('stabilityai/stable-diffusion-2-1',SDModelType.diffusion_pipeline)
    assert model1_vae == model1.vae
    model1_vae_2 = cache.get_submodel('stabilityai/stable-diffusion-2-1')
    assert model1_vae == model1_vae_2
    with cache.get_model(repo_id_or_path='dummy/pipeline1',model_type=DMType.dummy_pipeline) as pipeline,\
         cache.get_model(repo_id_or_path='dummy/pipeline1',model_type=DMType.dummy_pipeline,submodel=DMType.type1) as part1,\
         cache.get_model(repo_id_or_path='dummy/pipeline2',model_type=DMType.dummy_pipeline,submodel=DMType.type1) as part2:
        assert type(part1)==DummyModelType1,'returned submodel is not of expected type'
        assert part1.device==torch.device('cuda'),'returned submodel should be in the GPU when in context'
        assert pipeline.type1==part1,'returned submodel should match the corresponding subpart of parent model'
        assert pipeline.type1!=part2,'returned submodel should not match the subpart of a different parent'

def test_transformer_fetch():
    model4 = cache.get_model('openai/clip-vit-large-patch14',SDModelType.tokenizer)
    assert isinstance(model4,CLIPTokenizer)

    model5 = cache.get_model('openai/clip-vit-large-patch14',SDModelType.text_encoder)
    assert isinstance(model5,CLIPTextModel)

def test_subfolder_fetch():
    model6 = cache.get_model('stabilityai/stable-diffusion-2',SDModelType.tokenizer,subfolder="tokenizer")
    assert isinstance(model6,CLIPTokenizer)

    model7 = cache.get_model('stabilityai/stable-diffusion-2',SDModelType.text_encoder,subfolder="text_encoder")
    assert isinstance(model7,CLIPTextModel)
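The dummy classes above can stand in for real diffusers components because the cache only relies on a from_pretrained(...) constructor and a .to(device) method. A short illustrative check of that contract, exercising the classes defined in the test module (not part of the committed test):

# Sketch only: exercising the dummy stand-in the same way the cache does.
model = DummyModelType1.from_pretrained('dummy/type1', revision='main')
assert model.device == torch.device('cpu')    # dummies start on the CPU storage device
model.to(torch.device('cuda'))                # the cache calls .to() to move models onto the GPU
assert model.device == torch.device('cuda')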