work on model cache and its regression test finished

Lincoln Stein 2023-05-03 12:38:18 -04:00
parent bb959448c1
commit e1fed52c66
2 changed files with 188 additions and 130 deletions

View File

@@ -1,55 +1,46 @@
 """
-Manage a cache of Stable Diffusion model files for fast switching.
-They are moved between GPU and CPU as necessary. If the cache
+Manage a RAM cache of diffusion/transformer models for fast switching.
+They are moved between GPU VRAM and CPU RAM as necessary. If the cache
 grows larger than a preset maximum, then the least recently used
 model will be cleared and (re)loaded from disk when next needed.
-
-The cache returns context manager generators designed to load the
-model into the GPU within the context, and unload outside the
-context. Use like this:
-
-    cache = ModelCache(max_models_cached=6)
-    with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
-         cache.get_model('stabilityai/stable-diffusion-2') as SD2:
-        do_something_in_GPU(SD1,SD2)
 """
 
 import contextlib
-import hashlib
 import gc
-import time
-import os
-import psutil
-import safetensors
-import safetensors.torch
-import torch
-import transformers
+import hashlib
 import warnings
+from collections.abc import Generator
 from enum import Enum
 from pathlib import Path
-from pydantic import BaseModel
-from diffusers import (
-    AutoencoderKL,
-    UNet2DConditionModel,
-    SchedulerMixin,
-    logging as diffusers_logging,
-)
-from huggingface_hub import list_repo_refs,HfApi
-from transformers import(
-    CLIPTokenizer,
-    CLIPFeatureExtractor,
-    CLIPTextModel,
-    logging as transformers_logging,
-)
-from huggingface_hub import scan_cache_dir
-from picklescan.scanner import scan_file_path
 from typing import Sequence, Union
-from diffusers.pipelines.stable_diffusion.safety_checker import (
-    StableDiffusionSafetyChecker,
-)
-from . import load_pipeline_from_original_stable_diffusion_ckpt
-from ..globals import Globals, global_cache_dir
-from ..stable_diffusion import (
-    StableDiffusionGeneratorPipeline,
-)
-from ..stable_diffusion.offloading import ModelGroup, FullyLoadedModelGroup
-from ..util import CUDA_DEVICE, ask_user, download_with_resume
-
-MAX_MODELS_CACHED = 4
+
+import torch
+from diffusers import AutoencoderKL, SchedulerMixin, UNet2DConditionModel
+from diffusers import logging as diffusers_logging
+from diffusers.pipelines.stable_diffusion.safety_checker import \
+    StableDiffusionSafetyChecker
+from huggingface_hub import HfApi
+from picklescan.scanner import scan_file_path
+from pydantic import BaseModel
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import logging as transformers_logging
+
+from ..globals import global_cache_dir
+from ..stable_diffusion import StableDiffusionGeneratorPipeline
+from . import load_pipeline_from_original_stable_diffusion_ckpt
+
+MAX_MODELS = 4
 
 # This is the mapping from the stable diffusion submodel dict key to the class
 class SDModelType(Enum):
@@ -68,70 +59,69 @@ ModelClass = Union[tuple([x.value for x in SDModelType])]
 
 # Legacy information needed to load a legacy checkpoint file
 class LegacyInfo(BaseModel):
     config_file: Path
-    vae_file: Path
+    vae_file: Path = None
+
+class UnsafeModelException(Exception):
+    "Raised when a legacy model file fails the picklescan test"
+    pass
+
+class UnscannableModelException(Exception):
+    "Raised when picklescan is unable to scan a legacy model file"
+    pass
 
 class ModelCache(object):
     def __init__(
         self,
-        max_models_cached: int=MAX_MODELS_CACHED,
+        max_models: int=MAX_MODELS,
         execution_device: torch.device=torch.device('cuda'),
+        storage_device: torch.device=torch.device('cpu'),
         precision: torch.dtype=torch.float16,
         sequential_offload: bool=False,
         sha_chunksize: int = 16777216,
     ):
         '''
-        :param max_models_cached: Maximum number of models to cache in CPU RAM [4]
+        :param max_models: Maximum number of models to cache in CPU RAM [4]
         :param execution_device: Torch device to load active model into [torch.device('cuda')]
+        :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
         :param precision: Precision for loaded models [torch.float16]
         :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
        :param sha_chunksize: Chunksize to use when calculating sha256 model hash
         '''
-        self.model_group: ModelGroup=FullyLoadedModelGroup(execution_device)
         self.models: dict = dict()
         self.stack: Sequence = list()
         self.sequential_offload: bool=sequential_offload
         self.precision: torch.dtype=precision
-        self.max_models_cached: int=max_models_cached
-        self.device: torch.device=execution_device
+        self.max_models: int=max_models
+        self.execution_device: torch.device=execution_device
+        self.storage_device: torch.device=storage_device
         self.sha_chunksize=sha_chunksize
 
-    def get_submodel(
-        self,
-        repo_id_or_path: Union[str,Path],
-        submodel: SDModelType=SDModelType.vae,
-        subfolder: Path=None,
-        revision: str=None,
-        legacy_info: LegacyInfo=None,
-    )->ModelClass:
-        '''
-        Load and return a HuggingFace model, with RAM caching.
-        :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
-        :param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.vae
-        :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
-        :param revision: model revision name
-        :param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
-        '''
-        parent_model = self.get_model(
-            repo_id_or_path=repo_id_or_path,
-            subfolder=subfolder,
-            revision=revision,
-        )
-        return getattr(parent_model, submodel.name)
-
+    @contextlib.contextmanager
     def get_model(
         self,
         repo_id_or_path: Union[str,Path],
         model_type: SDModelType=SDModelType.diffusion_pipeline,
         subfolder: Path=None,
+        submodel: SDModelType=None,
         revision: str=None,
         legacy_info: LegacyInfo=None,
-    )->ModelClass:
+        gpu_load: bool=True,
+    )->Generator[ModelClass, None, None]:
         '''
-        Load and return a HuggingFace model, with RAM caching.
+        Load and return a HuggingFace model wrapped in a context manager generator, with RAM caching.
+        Use like this:
+
+            cache = ModelCache()
+            with cache.get_model('stabilityai/stable-diffusion-2') as SD2:
+                do_something_with_the_model(SD2)
+
+        The model will be locked into GPU VRAM for the duration of the context.
         :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
         :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
+        :param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.vae
         :param revision: model revision
         :param model_class: class of model to return
+        :param gpu_load: load the model into GPU [default True]
         :param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
         '''
         key = self._model_key( # internal unique identifier for the model
@@ -141,19 +131,12 @@ class ModelCache(object):
             subfolder
         )
         if key in self.models: # cached - move to bottom of stack
-            previous_key = self._current_model_key
             with contextlib.suppress(ValueError):
                 self.stack.remove(key)
             self.stack.append(key)
-            if previous_key != key:
-                if hasattr(self.current_model,'to'):
-                    print(f' | loading {key} into GPU')
-                    self.model_group.offload_current()
-                    self.model_group.load(self.models[key])
+            model = self.models[key]
         else: # not cached -load
             self._make_cache_room()
-            self.model_group.offload_current()
-            print(f' | loading model {key} from disk/net')
             model = self._load_model_from_storage(
                 repo_id_or_path=repo_id_or_path,
                 model_class=model_type.value,
@@ -161,14 +144,29 @@ class ModelCache(object):
                 revision=revision,
                 legacy_info=legacy_info,
             )
-            if hasattr(model,'to'):
-                self.model_group.install(model) # register with the model group
             self.stack.append(key)  # add to LRU cache
             self.models[key]=model  # keep copy of model in dict
-        return self.models[key]
 
-    @staticmethod
-    def model_hash(repo_id_or_path: Union[str,Path],
+        if submodel:
+            model = getattr(model, submodel.name)
+            debugging_name = f'{submodel.name} submodel of {repo_id_or_path}'
+        else:
+            debugging_name = repo_id_or_path
+
+        try:
+            if gpu_load and hasattr(model,'to'):
+                print(f' | Loading {debugging_name} into GPU')
+                model.to(self.execution_device)  # move into GPU
+                self._print_cuda_stats()
+            yield model
+        finally:
+            if gpu_load and hasattr(model,'to'):
+                print(f' | Unloading {debugging_name} from GPU')
+                model.to(self.storage_device)
+                self._print_cuda_stats()
+
+    def model_hash(self,
+                   repo_id_or_path: Union[str,Path],
                    revision: str=None)->str:
         '''
         Given the HF repo id or path to a model on disk, returns a unique
@@ -183,16 +181,28 @@ class ModelCache(object):
         else:
             return self._hf_commit_hash(repo_id_or_path,revision)
 
+    def cache_size(self)->int:
+        "Return the current number of models cached."
+        return len(self.models)
+
     @staticmethod
     def _model_key(path,model_class,revision,subfolder)->str:
         return ':'.join([str(path),model_class.__name__,str(revision or ''),str(subfolder or '')])
 
+    def _has_cuda(self)->bool:
+        return self.execution_device.type == 'cuda'
+
+    def _print_cuda_stats(self):
+        print(
+            " | Current VRAM usage:",
+            "%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
+        )
+
     def _make_cache_room(self):
         models_in_ram = len(self.models)
-        while models_in_ram >= self.max_models_cached:
+        while models_in_ram >= self.max_models:
             if least_recently_used_key := self.stack.pop(0):
-                print(f' | maximum cache size reached: cache_size={models_in_ram}; unloading model {least_recently_used_key}')
-                self.model_group.uninstall(self.models[least_recently_used_key])
+                print(f' | Maximum cache size reached: cache_size={models_in_ram}; unloading model {least_recently_used_key}')
                 del self.models[least_recently_used_key]
             models_in_ram = len(self.models)
         gc.collect()
@@ -239,9 +249,9 @@ class ModelCache(object):
                 model_class,
             )
             if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
-                model.enable_offload_submodels(self.device)
+                model.enable_offload_submodels(self.execution_device)
             elif hasattr(model,'to'):
-                model.to(self.device)
+                model.to(self.execution_device)
         return model
 
     def _load_diffusers_from_storage(
@@ -274,6 +284,23 @@ class ModelCache(object):
         path = Path(repo_id_or_path)
         return path.is_file() and path.suffix in [".ckpt",".safetensors"]
 
+    @classmethod
+    def scan_model(cls, model_name, checkpoint):
+        """
+        Apply picklescanner to the indicated checkpoint and issue a warning
+        and option to exit if an infected file is identified.
+        """
+        # scan model
+        print(f" | Scanning Model: {model_name}")
+        scan_result = scan_file_path(checkpoint)
+        if scan_result.infected_files != 0:
+            if scan_result.infected_files == 1:
+                raise UnsafeModelException("The legacy model you are trying to load may contain malware. Aborting.")
+            else:
+                raise UnscannableModelException("InvokeAI was unable to scan the legacy model you requested. Aborting")
+        else:
+            print(" | Model scanned ok")
+
     def _load_ckpt_from_storage(self,
                                 ckpt_path: Union[str,Path],
                                 legacy_info:LegacyInfo)->StableDiffusionGeneratorPipeline:
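
Taken together, these hunks turn get_model() from a plain accessor into a context manager that locks the model into VRAM for the duration of the `with` block and moves it back to the storage device on exit. The sketch below is based only on the signatures and docstrings shown above; the repo id and the bodies of the `with` blocks are illustrative placeholders, not part of the commit.

    # Usage sketch of the post-commit ModelCache API (assumes this branch of invokeai).
    from invokeai.backend.model_management.model_cache import ModelCache, SDModelType

    cache = ModelCache(max_models=4)   # LRU cache of up to 4 models held in CPU RAM

    # The pipeline lives on the execution (GPU) device only inside the context,
    # and is returned to the storage (CPU) device when the block exits.
    with cache.get_model('stabilityai/stable-diffusion-2') as pipeline:
        pass  # run inference with `pipeline` here

    # Fetch just the VAE submodel of the same cached pipeline, without touching the GPU.
    with cache.get_model('stabilityai/stable-diffusion-2',
                         submodel=SDModelType.vae,
                         gpu_load=False) as vae:
        pass  # inspect or reuse `vae` here

    print(cache.cache_size())  # -> 1 here: both fetches resolve to the same cache key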

View File

@@ -1,53 +1,84 @@
 import pytest
 import torch
-from invokeai.backend.model_management.model_cache import ModelCache, SDModelType
-from invokeai.backend.stable_diffusion import StableDiffusionGeneratorPipeline
-from diffusers import (
-    AutoencoderKL,
-    UNet2DConditionModel,
-    SchedulerMixin,
-)
-from transformers import (
-    CLIPTokenizer,
-    CLIPFeatureExtractor,
-    CLIPTextModel,
-)
+from enum import Enum
+from invokeai.backend.model_management.model_cache import ModelCache
 
-cache = ModelCache()
+class DummyModelBase(object):
+    '''Base class for dummy component of a diffusers model'''
+    def __init__(self, repo_id):
+        self.repo_id = repo_id
+        self.device = torch.device('cpu')
+
+    @classmethod
+    def from_pretrained(cls,
+                        repo_id:str,
+                        revision:str=None,
+                        subfolder:str=None,
+                        cache_dir:str=None,
+                        ):
+        return cls(repo_id)
+
+    def to(self, device):
+        self.device = device
+
+class DummyModelType1(DummyModelBase):
+    pass
+
+class DummyModelType2(DummyModelBase):
+    pass
+
+class DummyPipeline(DummyModelBase):
+    '''Dummy pipeline object is a composite of several types'''
+    def __init__(self,repo_id):
+        super().__init__(repo_id)
+        self.type1 = DummyModelType1('dummy/type1')
+        self.type2 = DummyModelType2('dummy/type2')
+
+class DMType(Enum):
+    dummy_pipeline = DummyPipeline
+    type1 = DummyModelType1
+    type2 = DummyModelType2
+
+cache = ModelCache(max_models=4)
 
 def test_pipeline_fetch():
-    model0 = cache.get_model('stabilityai/sd-vae-ft-mse',SDModelType.vae)
-    model1 = cache.get_model('stabilityai/stable-diffusion-2-1',SDModelType.diffusion_pipeline)
-    model1_2 = cache.get_model('stabilityai/stable-diffusion-2-1')
-    assert model1==model1_2
-    assert model1.device==torch.device('cuda')
-    model2 = cache.get_model('runwayml/stable-diffusion-v1-5')
-    assert model2.device==torch.device('cuda')
-    assert model1.device==torch.device('cpu')
-    model1 = cache.get_model('stabilityai/stable-diffusion-2-1')
-    assert model1.device==torch.device('cuda')
+    assert cache.cache_size()==0
+    with cache.get_model('dummy/pipeline1',DMType.dummy_pipeline) as pipeline1,\
+         cache.get_model('dummy/pipeline1',DMType.dummy_pipeline) as pipeline1a,\
+         cache.get_model('dummy/pipeline2',DMType.dummy_pipeline) as pipeline2:
+        assert pipeline1 is not None, 'get_model() should not return None'
+        assert pipeline1a is not None, 'get_model() should not return None'
+        assert pipeline2 is not None, 'get_model() should not return None'
+        assert type(pipeline1)==DMType.dummy_pipeline.value,'get_model() did not return model of expected type'
+        assert pipeline1==pipeline1a,'pipelines with the same repo_id should be the same'
+        assert pipeline1!=pipeline2,'pipelines with different repo_ids should not be the same'
+    assert cache.cache_size()==2,'cache should uniquely cache models with same identity'
+    with cache.get_model('dummy/pipeline3',DMType.dummy_pipeline) as pipeline3,\
+         cache.get_model('dummy/pipeline4',DMType.dummy_pipeline) as pipeline4:
+        assert cache.cache_size()==4,'cache did not grow as expected'
+    with cache.get_model('dummy/pipeline5',DMType.dummy_pipeline) as pipeline5:
+        assert cache.cache_size()==4,'cache did not free space as expected'
+
+def test_signatures():
+    with cache.get_model('dummy/pipeline',DMType.dummy_pipeline,revision='main') as pipeline1,\
+         cache.get_model('dummy/pipeline',DMType.dummy_pipeline,revision='fp16') as pipeline2,\
+         cache.get_model('dummy/pipeline',DMType.dummy_pipeline,revision='main',subfolder='foo') as pipeline3:
+        assert pipeline1 != pipeline2,'models are distinguished by their revision'
+        assert pipeline1 != pipeline3,'models are distinguished by their subfolder'
+
+def test_pipeline_device():
+    with cache.get_model('dummy/pipeline1',DMType.type1) as model1:
+        assert model1.device==torch.device('cuda'),'when in context, model device should be in GPU'
+    with cache.get_model('dummy/pipeline1',DMType.type1, gpu_load=False) as model1:
+        assert model1.device==torch.device('cpu'),'when gpu_load=False, model device should be CPU'
 
 def test_submodel_fetch():
-    model1_vae = cache.get_submodel('stabilityai/stable-diffusion-2-1',SDModelType.vae)
-    assert isinstance(model1_vae,AutoencoderKL)
-    model1 = cache.get_model('stabilityai/stable-diffusion-2-1',SDModelType.diffusion_pipeline)
-    assert model1_vae == model1.vae
-    model1_vae_2 = cache.get_submodel('stabilityai/stable-diffusion-2-1')
-    assert model1_vae == model1_vae_2
-
-def test_transformer_fetch():
-    model4 = cache.get_model('openai/clip-vit-large-patch14',SDModelType.tokenizer)
-    assert isinstance(model4,CLIPTokenizer)
-    model5 = cache.get_model('openai/clip-vit-large-patch14',SDModelType.text_encoder)
-    assert isinstance(model5,CLIPTextModel)
-
-def test_subfolder_fetch():
-    model6 = cache.get_model('stabilityai/stable-diffusion-2',SDModelType.tokenizer,subfolder="tokenizer")
-    assert isinstance(model6,CLIPTokenizer)
-    model7 = cache.get_model('stabilityai/stable-diffusion-2',SDModelType.text_encoder,subfolder="text_encoder")
-    assert isinstance(model7,CLIPTextModel)
+    with cache.get_model(repo_id_or_path='dummy/pipeline1',model_type=DMType.dummy_pipeline) as pipeline,\
+         cache.get_model(repo_id_or_path='dummy/pipeline1',model_type=DMType.dummy_pipeline,submodel=DMType.type1) as part1,\
+         cache.get_model(repo_id_or_path='dummy/pipeline2',model_type=DMType.dummy_pipeline,submodel=DMType.type1) as part2:
+        assert type(part1)==DummyModelType1,'returned submodel is not of expected type'
+        assert part1.device==torch.device('cuda'),'returned submodel should be in the GPU when in context'
+        assert pipeline.type1==part1,'returned submodel should match the corresponding subpart of parent model'
+        assert pipeline.type1!=part2,'returned submodel should not match the subpart of a different parent'
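
The new test_signatures test exercises the cache-key scheme from the _model_key() staticmethod in the first file: a key is just the colon-joined path, class name, revision, and subfolder, so fetches that differ in any of those fields occupy separate cache slots. A standalone restatement of that scheme (the DummyPipeline class here is a local stand-in mirroring the test doubles above, not part of the commit):

    # Re-statement of the _model_key() key format shown in the diff above.
    def model_key(path, model_class, revision=None, subfolder=None) -> str:
        return ':'.join([str(path), model_class.__name__, str(revision or ''), str(subfolder or '')])

    class DummyPipeline:  # stand-in for the dummy pipeline class used by the tests
        pass

    print(model_key('dummy/pipeline', DummyPipeline, revision='main'))
    # dummy/pipeline:DummyPipeline:main:
    print(model_key('dummy/pipeline', DummyPipeline, revision='fp16'))
    # dummy/pipeline:DummyPipeline:fp16:
    print(model_key('dummy/pipeline', DummyPipeline, revision='main', subfolder='foo'))
    # dummy/pipeline:DummyPipeline:main:foo

Because the three keys are distinct strings, pipeline1, pipeline2, and pipeline3 in test_signatures end up as separate cache entries, which is exactly what the assertions check.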