mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
implement hashing for local & remote models
This commit is contained in:
parent
2e2abf6ea6
commit
bb959448c1
@ -20,12 +20,14 @@ import warnings
|
||||
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from pydantic import BaseModel
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
UNet2DConditionModel,
|
||||
SchedulerMixin,
|
||||
logging as diffusers_logging,
|
||||
)
|
||||
from huggingface_hub import list_repo_refs,HfApi
|
||||
from transformers import(
|
||||
CLIPTokenizer,
|
||||
CLIPFeatureExtractor,
|
||||
@ -36,10 +38,11 @@ from huggingface_hub import scan_cache_dir
|
||||
from picklescan.scanner import scan_file_path
|
||||
from typing import Sequence, Union
|
||||
|
||||
from invokeai.backend.globals import Globals, global_cache_dir
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from . import load_pipeline_from_original_stable_diffusion_ckpt
|
||||
from ..globals import Globals, global_cache_dir
|
||||
from ..stable_diffusion import (
|
||||
StableDiffusionGeneratorPipeline,
|
||||
)
|
||||
@ -59,9 +62,14 @@ class SDModelType(Enum):
|
||||
safety_checker=StableDiffusionSafetyChecker
|
||||
feature_extractor=CLIPFeatureExtractor
|
||||
|
||||
# List the model classes we know how to fetch
|
||||
# The list of model classes we know how to fetch, for typechecking
|
||||
ModelClass = Union[tuple([x.value for x in SDModelType])]
|
||||
|
||||
# Legacy information needed to load a legacy checkpoint file
|
||||
class LegacyInfo(BaseModel):
|
||||
config_file: Path
|
||||
vae_file: Path
|
||||
|
||||
class ModelCache(object):
|
||||
def __init__(
|
||||
self,
|
||||
@ -69,7 +77,15 @@ class ModelCache(object):
|
||||
execution_device: torch.device=torch.device('cuda'),
|
||||
precision: torch.dtype=torch.float16,
|
||||
sequential_offload: bool=False,
|
||||
sha_chunksize: int = 16777216,
|
||||
):
|
||||
'''
|
||||
:param max_models_cached: Maximum number of models to cache in CPU RAM [4]
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param precision: Precision for loaded models [torch.float16]
|
||||
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
|
||||
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
|
||||
'''
|
||||
self.model_group: ModelGroup=FullyLoadedModelGroup(execution_device)
|
||||
self.models: dict = dict()
|
||||
self.stack: Sequence = list()
|
||||
@ -77,6 +93,7 @@ class ModelCache(object):
|
||||
self.precision: torch.dtype=precision
|
||||
self.max_models_cached: int=max_models_cached
|
||||
self.device: torch.device=execution_device
|
||||
self.sha_chunksize=sha_chunksize
|
||||
|
||||
def get_submodel(
|
||||
self,
|
||||
@ -84,7 +101,16 @@ class ModelCache(object):
|
||||
submodel: SDModelType=SDModelType.vae,
|
||||
subfolder: Path=None,
|
||||
revision: str=None,
|
||||
legacy_info: LegacyInfo=None,
|
||||
)->ModelClass:
|
||||
'''
|
||||
Load and return a HuggingFace model, with RAM caching.
|
||||
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
|
||||
:param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.vae
|
||||
:param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
|
||||
:param revision: model revision name
|
||||
:param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
|
||||
'''
|
||||
parent_model = self.get_model(
|
||||
repo_id_or_path=repo_id_or_path,
|
||||
subfolder=subfolder,
|
||||
@ -98,6 +124,7 @@ class ModelCache(object):
|
||||
model_type: SDModelType=SDModelType.diffusion_pipeline,
|
||||
subfolder: Path=None,
|
||||
revision: str=None,
|
||||
legacy_info: LegacyInfo=None,
|
||||
)->ModelClass:
|
||||
'''
|
||||
Load and return a HuggingFace model, with RAM caching.
|
||||
@ -105,13 +132,14 @@ class ModelCache(object):
|
||||
:param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
|
||||
:param revision: model revision
|
||||
:param model_class: class of model to return
|
||||
:param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
|
||||
'''
|
||||
key = self._model_key(
|
||||
key = self._model_key( # internal unique identifier for the model
|
||||
repo_id_or_path,
|
||||
model_type.value,
|
||||
revision,
|
||||
subfolder
|
||||
) # internal unique identifier for the model
|
||||
)
|
||||
if key in self.models: # cached - move to bottom of stack
|
||||
previous_key = self._current_model_key
|
||||
with contextlib.suppress(ValueError):
|
||||
@ -119,19 +147,19 @@ class ModelCache(object):
|
||||
self.stack.append(key)
|
||||
if previous_key != key:
|
||||
if hasattr(self.current_model,'to'):
|
||||
print(f'DEBUG: loading {key} into GPU')
|
||||
print(f' | loading {key} into GPU')
|
||||
self.model_group.offload_current()
|
||||
self.model_group.load(self.models[key])
|
||||
|
||||
else: # not cached -load
|
||||
self._make_cache_room()
|
||||
self.model_group.offload_current()
|
||||
print(f'DEBUG: loading {key} from disk/net')
|
||||
print(f' | loading model {key} from disk/net')
|
||||
model = self._load_model_from_storage(
|
||||
repo_id_or_path=repo_id_or_path,
|
||||
model_class=model_type.value,
|
||||
subfolder=subfolder,
|
||||
revision=revision,
|
||||
legacy_info=legacy_info,
|
||||
)
|
||||
if hasattr(model,'to'):
|
||||
self.model_group.install(model) # register with the model group
|
||||
@ -139,6 +167,22 @@ class ModelCache(object):
|
||||
self.models[key]=model # keep copy of model in dict
|
||||
return self.models[key]
|
||||
|
||||
@staticmethod
|
||||
def model_hash(repo_id_or_path: Union[str,Path],
|
||||
revision: str=None)->str:
|
||||
'''
|
||||
Given the HF repo id or path to a model on disk, returns a unique
|
||||
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
|
||||
:param repo_id_or_path: repo_id string or Path to model file/directory on disk.
|
||||
:param revision: optional revision string (if fetching a HF repo_id)
|
||||
'''
|
||||
if self.is_legacy_ckpt(repo_id_or_path):
|
||||
return self._legacy_model_hash(repo_id_or_path)
|
||||
elif Path(repo_id_or_path).is_dir():
|
||||
return self._local_model_hash(repo_id_or_path)
|
||||
else:
|
||||
return self._hf_commit_hash(repo_id_or_path,revision)
|
||||
|
||||
@staticmethod
|
||||
def _model_key(path,model_class,revision,subfolder)->str:
|
||||
return ':'.join([str(path),model_class.__name__,str(revision or ''),str(subfolder or '')])
|
||||
@ -147,7 +191,7 @@ class ModelCache(object):
|
||||
models_in_ram = len(self.models)
|
||||
while models_in_ram >= self.max_models_cached:
|
||||
if least_recently_used_key := self.stack.pop(0):
|
||||
print(f'DEBUG: maximum cache size reached: cache_size={models_in_ram}; unloading model {least_recently_used_key}')
|
||||
print(f' | maximum cache size reached: cache_size={models_in_ram}; unloading model {least_recently_used_key}')
|
||||
self.model_group.uninstall(self.models[least_recently_used_key])
|
||||
del self.models[least_recently_used_key]
|
||||
models_in_ram = len(self.models)
|
||||
@ -173,28 +217,135 @@ class ModelCache(object):
|
||||
subfolder: Path=None,
|
||||
revision: str=None,
|
||||
model_class: ModelClass=StableDiffusionGeneratorPipeline,
|
||||
legacy_info: LegacyInfo=None,
|
||||
)->ModelClass:
|
||||
'''
|
||||
Load and return a HuggingFace model.
|
||||
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
|
||||
:param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
|
||||
:param revision: model revision
|
||||
:param model_class: class of model to return
|
||||
:param model_class: class of model to return, defaults to StableDiffusionGeneratorPIpeline
|
||||
:param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
|
||||
'''
|
||||
# silence transformer and diffuser warnings
|
||||
with SilenceWarnings():
|
||||
model = model_class.from_pretrained(
|
||||
repo_id_or_path,
|
||||
revision=revision,
|
||||
subfolder=subfolder or '.',
|
||||
cache_dir=global_cache_dir('hub'),
|
||||
)
|
||||
if self.is_legacy_ckpt(repo_id_or_path):
|
||||
model = self._load_ckpt_from_storage(repo_id_or_path, legacy_info)
|
||||
else:
|
||||
model = self._load_diffusers_from_storage(
|
||||
repo_id_or_path,
|
||||
subfolder,
|
||||
revision,
|
||||
model_class,
|
||||
)
|
||||
if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
|
||||
model.enable_offload_submodels(self.device)
|
||||
elif hasattr(model,'to'):
|
||||
model.to(self.device)
|
||||
return model
|
||||
|
||||
def _load_diffusers_from_storage(
|
||||
self,
|
||||
repo_id_or_path: Union[str,Path],
|
||||
subfolder: Path=None,
|
||||
revision: str=None,
|
||||
model_class: ModelClass=StableDiffusionGeneratorPipeline,
|
||||
)->ModelClass:
|
||||
'''
|
||||
Load and return a HuggingFace model using from_pretrained().
|
||||
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
|
||||
:param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
|
||||
:param revision: model revision
|
||||
:param model_class: class of model to return, defaults to StableDiffusionGeneratorPIpeline
|
||||
'''
|
||||
return model_class.from_pretrained(
|
||||
repo_id_or_path,
|
||||
revision=revision,
|
||||
subfolder=subfolder or '.',
|
||||
cache_dir=global_cache_dir('hub'),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_legacy_ckpt(cls, repo_id_or_path: Union[str,Path])->bool:
|
||||
'''
|
||||
Return true if the indicated path is a legacy checkpoint
|
||||
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
|
||||
'''
|
||||
path = Path(repo_id_or_path)
|
||||
return path.is_file() and path.suffix in [".ckpt",".safetensors"]
|
||||
|
||||
def _load_ckpt_from_storage(self,
|
||||
ckpt_path: Union[str,Path],
|
||||
legacy_info:LegacyInfo)->StableDiffusionGeneratorPipeline:
|
||||
'''
|
||||
Load a legacy checkpoint, convert it, and return a StableDiffusionGeneratorPipeline.
|
||||
:param ckpt_path: string or Path pointing to the weights file (.ckpt or .safetensors)
|
||||
:param legacy_info: LegacyInfo object containing paths to legacy config file and alternate vae if required
|
||||
'''
|
||||
assert legacy_info is not None
|
||||
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path=ckpt_path,
|
||||
original_config_file=legacy_info.config_file,
|
||||
vae_path=legacy_info.vae_file,
|
||||
return_generator_pipeline=True,
|
||||
precision=self.precision,
|
||||
)
|
||||
return pipeline
|
||||
|
||||
def _legacy_model_hash(self, checkpoint_path: Union[str,Path])->str:
|
||||
sha = hashlib.sha256()
|
||||
path = Path(checkpoint_path)
|
||||
assert path.is_file()
|
||||
|
||||
hashpath = path.parent / f"{path.name}.sha256"
|
||||
if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
|
||||
with open(hashpath) as f:
|
||||
hash = f.read()
|
||||
return hash
|
||||
|
||||
print(f' | computing hash of model {path.name}')
|
||||
with open(path, "rb") as f:
|
||||
while chunk := f.read(self.sha_chunksize):
|
||||
sha.update(chunk)
|
||||
hash = sha.hexdigest()
|
||||
|
||||
with open(hashpath, "w") as f:
|
||||
f.write(hash)
|
||||
return hash
|
||||
|
||||
def _local_model_hash(self, model_path: Union[str,Path])->str:
|
||||
sha = hashlib.sha256()
|
||||
path = Path(model_path)
|
||||
|
||||
hashpath = path / "checksum.sha256"
|
||||
if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
|
||||
with open(hashpath) as f:
|
||||
hash = f.read()
|
||||
return hash
|
||||
|
||||
print(f' | computing hash of model {path.name}')
|
||||
for file in list(path.rglob("*.ckpt")) \
|
||||
+ list(path.rglob("*.safetensors")) \
|
||||
+ list(path.rglob("*.pth")):
|
||||
with open(file, "rb") as f:
|
||||
while chunk := f.read(self.sha_chunksize):
|
||||
sha.update(chunk)
|
||||
hash = sha.hexdigest()
|
||||
with open(hashpath, "w") as f:
|
||||
f.write(hash)
|
||||
return hash
|
||||
|
||||
def _hf_commit_hash(self, repo_id: str, revision: str='main')->str:
|
||||
api = HfApi()
|
||||
info = api.list_repo_refs(
|
||||
repo_id=repo_id,
|
||||
repo_type='model',
|
||||
)
|
||||
desired_revisions = [branch for branch in info.branches if branch.name==revision]
|
||||
if not desired_revisions:
|
||||
raise KeyError(f"Revision '{revision}' not found in {repo_id}")
|
||||
return desired_revisions[0].target_commit
|
||||
|
||||
class SilenceWarnings(object):
|
||||
def __init__(self):
|
||||
self.transformers_verbosity = transformers_logging.get_verbosity()
|
||||
|
Loading…
Reference in New Issue
Block a user