generalize model loading support, include loras/embeds

This commit is contained in:
Lincoln Stein 2023-05-06 15:58:44 -04:00
parent a8cfa3565c
commit 05a27bda5e
2 changed files with 182 additions and 43 deletions

View File

@@ -20,13 +20,13 @@ import contextlib
import gc
import hashlib
import warnings
from collections.abc import Generator
from collections import Counter
from enum import Enum
from pathlib import Path
from typing import Sequence, Union, Tuple, types
import torch
import safetensors.torch
from diffusers import StableDiffusionPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel
from diffusers import logging as diffusers_logging
from diffusers.pipelines.stable_diffusion.safety_checker import \
@@ -46,13 +46,18 @@ MAX_MODELS = 4
# This is the mapping from the stable diffusion submodel dict key to the class
class SDModelType(Enum):
diffusion_pipeline=StableDiffusionGeneratorPipeline # whole thing
vae=AutoencoderKL # parts
diffusers=StableDiffusionGeneratorPipeline # same thing
vae=AutoencoderKL # diffusers parts
text_encoder=CLIPTextModel
tokenizer=CLIPTokenizer
unet=UNet2DConditionModel
scheduler=SchedulerMixin
safety_checker=StableDiffusionSafetyChecker
feature_extractor=CLIPFeatureExtractor
# These are all loaded as dicts of tensors
lora=dict
textual_inversion=dict
ckpt=dict
class ModelStatus(Enum):
unknown='unknown'
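To make the new dict-valued members concrete, here is a minimal sketch of resolving a models.yaml "format" string to its loader class through this enum; the fmt value is illustrative and not part of the diff:

    # Resolve a models.yaml "format" string to a loader class via SDModelType.
    # Because lora, textual_inversion and ckpt all share the value `dict`, the Enum
    # machinery treats the latter two as aliases of the first; lookup by name still
    # works for every spelling.
    fmt = "lora"                       # hypothetical value read from models.yaml
    model_type = SDModelType[fmt]      # member lookup by name
    loader_class = model_type.value    # dict -> the weights load as a plain tensor dictionary
    assert loader_class is dict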
@@ -192,6 +197,15 @@ class ModelCache(object):
return self.ModelLocker(self, key, model, gpu_load)
def uncache_model(self, key: str):
'''Remove corresponding model from the cache'''
if key is not None and key in self.models:
with contextlib.suppress(ValueError):
del self.models[key]
del self.locked_models[key]
self.stack.remove(key)
self.loaded_models.remove(key)
class ModelLocker(object):
def __init__(self, cache, key, model, gpu_load):
self.gpu_load = gpu_load
@@ -301,7 +315,7 @@ class ModelCache(object):
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
'''
path = Path(repo_id_or_path)
return path.is_file() and path.suffix in [".ckpt",".safetensors"]
return path.suffix in [".ckpt",".safetensors",".pt"]
@classmethod
def scan_model(cls, model_name, checkpoint):
@@ -368,15 +382,17 @@ class ModelCache(object):
:param model_class: class of model to return, defaults to StableDiffusionGeneratorPipeline
:param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
'''
if self.is_legacy_ckpt(repo_id_or_path):
model = self._load_ckpt_from_storage(repo_id_or_path, legacy_info)
else:
model = self._load_diffusers_from_storage(
repo_id_or_path,
subfolder,
revision,
model_class,
)
# silence transformer and diffuser warnings
with SilenceWarnings():
if self.is_legacy_ckpt(repo_id_or_path):
model = self._load_ckpt_from_storage(repo_id_or_path, legacy_info)
else:
model = self._load_diffusers_from_storage(
repo_id_or_path,
subfolder,
revision,
model_class,
)
if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
model.enable_offload_submodels(self.execution_device)
return model
@@ -404,21 +420,19 @@ class ModelCache(object):
if model_class in DiffusionClasses\
else {}
# silence transformer and diffuser warnings
with SilenceWarnings():
for rev in revisions:
try:
model = model_class.from_pretrained(
repo_id_or_path,
revision=rev,
subfolder=subfolder or '.',
cache_dir=global_cache_dir('hub'),
**extra_args,
)
self.logger.debug(f'Found revision {rev}')
break
except OSError:
pass
for rev in revisions:
try:
model = model_class.from_pretrained(
repo_id_or_path,
revision=rev,
subfolder=subfolder or '.',
cache_dir=global_cache_dir('hub'),
**extra_args,
)
self.logger.debug(f'Found revision {rev}')
break
except OSError:
pass
return model
def _load_ckpt_from_storage(self,
@@ -429,24 +443,27 @@ class ModelCache(object):
:param ckpt_path: string or Path pointing to the weights file (.ckpt or .safetensors)
:param legacy_info: LegacyInfo object containing paths to legacy config file and alternate vae if required
'''
assert legacy_info is not None
# deferred loading to avoid circular import errors
from .convert_ckpt_to_diffusers import load_pipeline_from_original_stable_diffusion_ckpt
with SilenceWarnings():
if legacy_info is None or legacy_info.config_file is None:
if Path(ckpt_path).suffix == '.safetensors':
return safetensors.torch.load_file(ckpt_path)
else:
return torch.load(ckpt_path)
else:
# deferred loading to avoid circular import errors
from .convert_ckpt_to_diffusers import load_pipeline_from_original_stable_diffusion_ckpt
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path=ckpt_path,
original_config_file=legacy_info.config_file,
vae_path=legacy_info.vae_file,
return_generator_pipeline=True,
precision=self.precision,
)
return pipeline
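A short sketch of the two branches above (the cache instance, the file paths and the LegacyInfo construction are illustrative assumptions; only the return-type distinction comes from the code):

    # Without a legacy config file the weights come back as a raw dict of tensors
    # (via safetensors.torch.load_file or torch.load) -- the path used for
    # lora / textual_inversion / ckpt entries.
    state_dict = cache._load_ckpt_from_storage("/some/path/my_lora.safetensors", legacy_info=None)
    assert isinstance(state_dict, dict)

    # With a LegacyInfo carrying a config file, the legacy .ckpt is converted into
    # a full StableDiffusionGeneratorPipeline instead.
    pipeline = cache._load_ckpt_from_storage(
        ckpt_path="/some/path/model.ckpt",
        legacy_info=LegacyInfo(config_file="/some/path/v1-inference.yaml"),   # hypothetical construction
    )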
def _legacy_model_hash(self, checkpoint_path: Union[str,Path])->str:
sha = hashlib.sha256()
path = Path(checkpoint_path)
assert path.is_file()
assert path.is_file(),f"File {checkpoint_path} not found"
hashpath = path.parent / f"{path.name}.sha256"
if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:

View File

@@ -31,7 +31,56 @@ generation operations. The ModelCache object can be accessed using
the manager's "cache" attribute.
Other methods provided by ModelManager support importing, editing,
converting and deleting models.
The general format of a models.yaml section is:
name-of-model:
  format: diffusers|ckpt|vae|text_encoder|tokenizer...
  repo_id: owner/repo
  path: /path/to/local/file/or/directory
  subfolder: subfolder-name
  submodel: vae|text_encoder|tokenizer...
The format is one of {diffusers, ckpt, vae, text_encoder, tokenizer,
unet, scheduler, safety_checker, feature_extractor, lora, textual_inversion};
each corresponds to an item in the SDModelType enum defined in model_cache.py.
One, but not both, of repo_id and path is provided. repo_id is the
HuggingFace repository ID of the model, and path points to the file or
directory on disk.
If subfolder is provided, then the model exists in a subdirectory of
the main model. These are usually named after the model type, such as
"unet".
Finally, if submodel is provided, then the path/repo_id is treated as
a diffusers model, the whole thing is read into memory, and then the
requested part (e.g. "unet") is retrieved.
The following examples summarize the three ways of getting a non-diffusers model:
clip-test-1:
  format: text_encoder
  repo_id: openai/clip-vit-large-patch14
  description: Returns a standalone CLIPTextModel

clip-test-2:
  format: diffusers
  repo_id: stabilityai/stable-diffusion-2
  submodel: text_encoder
  description: Returns the text_encoder part of the whole diffusers model (the whole model is loaded into RAM)

clip-test-3:
  format: text_encoder
  repo_id: stabilityai/stable-diffusion-2
  subfolder: text_encoder
  description: Returns the text_encoder from the subfolder of the diffusers model (only the encoder is loaded into RAM)

clip-token:
  format: tokenizer
  repo_id: openai/clip-vit-large-patch14
  description: Returns a standalone tokenizer
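A sketch of the intended retrieval path for any of these entries (variable
names are illustrative; this assumes the manager exposes a get_model(name)
lookup returning the SDModelInfo described below, whose .context ModelLocker
is usable as a context manager):

    info = manager.get_model('clip-test-2')   # -> SDModelInfo
    with info.context as model:
        ...  # model is the requested part, held on the execution device for the duration of the block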
"""
from __future__ import annotations
@@ -53,14 +102,14 @@ from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from invokeai.backend.globals import Globals, global_cache_dir, global_resolve_path
from .model_cache import Generator, ModelClass, ModelCache, SDModelType, ModelStatus, LegacyInfo
from .model_cache import ModelClass, ModelCache, ModelLocker, SDModelType, ModelStatus, LegacyInfo
from ..util import CUDA_DEVICE
# wanted to use pydantic here, but Generator objects not supported
@dataclass
class SDModelInfo():
context: Generator[ModelClass, None, None]
context: ModelLocker
name: str
hash: str
location: Union[Path,str]
@@ -125,6 +174,7 @@ class ModelManager(object):
sequential_offload = sequential_offload,
logger = logger,
)
self.cache_keys = dict()
self.logger = logger
def valid_model(self, model_name: str) -> bool:
@@ -154,8 +204,12 @@
# get the required loading info out of the config file
mconfig = self.config[model_name]
format = mconfig.get('format','diffusers')
model_type = SDModelType.diffusion_pipeline
model_parts = dict([(x.name,x) for x in SDModelType])
legacy = None
if format=='ckpt':
location = global_resolve_path(mconfig.weights)
legacy = LegacyInfo(
@@ -165,14 +219,22 @@
legacy.vae_file = global_resolve_path(mconfig.vae)
elif format=='diffusers':
location = mconfig.get('repo_id') or mconfig.get('path')
revision = mconfig.get('revision')
if sm := mconfig.get('submodel'):
submodel = model_parts[sm]
elif format in model_parts:
location = mconfig.get('repo_id') or mconfig.get('path') or mconfig.get('weights')
model_type = model_parts[format]
else:
raise InvalidModelError(
f'"{model_name}" has an unknown format {format}'
)
subfolder = mconfig.get('subfolder')
revision = mconfig.get('revision')
hash = self.cache.model_hash(location,revision)
# to support the traditional way of attaching a VAE
# to a model, we hacked in `attach_model_part`
vae = (None,None)
try:
vae_id = mconfig.vae.repo_id
@@ -181,12 +243,19 @@
pass
model_context = self.cache.get_model(
location,
model_type = model_type,
revision = revision,
subfolder = subfolder,
legacy_info = legacy,
submodel = submodel,
attach_model_part=vae,
)
# in case we need to communicate information about this
# model to the cache manager, we remember the cache key
self.cache_keys[model_name] = model_context.key
return SDModelInfo(
context = model_context,
name = model_name,
@@ -365,6 +434,7 @@ class ModelManager(object):
attributes are incorrect or the model name is missing.
"""
omega = self.config
assert "format" in model_attributes, 'missing required field "format"'
if model_attributes["format"] == "diffusers":
assert (
@@ -373,9 +443,11 @@
assert (
"path" in model_attributes or "repo_id" in model_attributes
), 'model must have either the "path" or "repo_id" fields defined'
else:
elif model_attributes["format"] == "ckpt":
for field in ("description", "weights", "height", "width", "config"):
assert field in model_attributes, f"required field {field} is missing"
else:
assert "weights" in model_attributes and "description" in model_attributes
assert (
clobber or model_name not in omega
@@ -385,9 +457,10 @@
if "weights" in omega[model_name]:
omega[model_name]["weights"].replace("\\", "/")
if clobber:
self._invalidate_cached_model(model_name)
if clobber and model_name in self.cache_keys:
self.cache.uncache_model(self.cache_keys[model_name])
del self.cache_keys[model_name]
def import_diffuser_model(
self,
@@ -425,6 +498,55 @@
self.commit(commit_to_conf)
return model_name
def import_lora(
self,
path: Path,
model_name: str=None,
description: str=None,
):
"""
Creates an entry for the indicated LoRA file. Call
mgr.commit() to write out the configuration to models.yaml
"""
path = Path(path)
model_name = model_name or path.stem
model_description = description or f"LoRA model {model_name}"
self.add_model(model_name,
dict(
format="lora",
weights=str(path),
description=model_description,
),
True
)
def import_embedding(
self,
path: Path,
model_name: str=None,
description: str=None,
):
"""
Creates an entry for the indicated textual inversion embedding file or directory. Call
mgr.commit() to write out the configuration to models.yaml
"""
path = Path(path)
if path.is_dir() and (path / "learned_embeds.bin").exists():
weights = path / "learned_embeds.bin"
else:
weights = path
model_name = model_name or path.stem
model_description = description or f"Textual embedding model {model_name}"
self.add_model(model_name,
dict(
format="textual_inversion",
weights=str(weights),
description=model_description,
),
True
)
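A usage sketch for the two new import helpers (the mgr instance, the file paths and the target config file name are illustrative; the commit() call mirrors its use elsewhere in this file):

    # Register a LoRA file and a textual inversion embedding, then persist
    # both entries to models.yaml.
    mgr.import_lora(Path("/path/to/loras/detail_tweaker.safetensors"))
    mgr.import_embedding(Path("/path/to/embeddings/my-style"))   # a directory containing learned_embeds.bin also works
    mgr.commit("models.yaml")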
@classmethod
def probe_model_type(self, checkpoint: dict) -> SDLegacyType:
"""