generalize model loading support, include loras/embeds

Lincoln Stein 2023-05-06 15:58:44 -04:00
parent a8cfa3565c
commit 05a27bda5e
2 changed files with 182 additions and 43 deletions


@@ -20,13 +20,13 @@ import contextlib
import gc
import hashlib
import warnings
from collections.abc import Generator
from collections import Counter
from enum import Enum
from pathlib import Path
from typing import Sequence, Union, Tuple, types
import torch
import safetensors.torch
from diffusers import StableDiffusionPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel
from diffusers import logging as diffusers_logging
from diffusers.pipelines.stable_diffusion.safety_checker import \
@@ -46,13 +46,18 @@ MAX_MODELS = 4
# This is the mapping from the stable diffusion submodel dict key to the class
class SDModelType(Enum):
    diffusion_pipeline=StableDiffusionGeneratorPipeline  # whole thing
    vae=AutoencoderKL                                    # parts
    diffusers=StableDiffusionGeneratorPipeline           # same thing
    vae=AutoencoderKL                                    # diffusers parts
    text_encoder=CLIPTextModel
    tokenizer=CLIPTokenizer
    unet=UNet2DConditionModel
    scheduler=SchedulerMixin
    safety_checker=StableDiffusionSafetyChecker
    feature_extractor=CLIPFeatureExtractor
    # These are all loaded as dicts of tensors
    lora=dict
    textual_inversion=dict
    ckpt=dict

class ModelStatus(Enum):
    unknown='unknown'
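Since the new formats all load as plain dicts of tensors, the enum doubles as a lookup table from a models.yaml format string to the class the loader should hand back. A minimal, self-contained sketch of that lookup (the two stand-in classes replace the real diffusers/transformers imports):

```python
from enum import Enum

# Stand-ins for the real diffusers/transformers classes imported in model_cache.py.
class StableDiffusionGeneratorPipeline: ...
class AutoencoderKL: ...

class SDModelType(Enum):
    diffusers = StableDiffusionGeneratorPipeline   # the whole pipeline
    vae = AutoencoderKL                            # an individual diffusers part
    lora = dict                                    # these formats load as dicts of tensors
    textual_inversion = dict
    ckpt = dict

# A models.yaml "format" or "submodel" string resolves to a member by name,
# and member.value is the class (or dict) the cache is expected to return.
assert SDModelType["vae"].value is AutoencoderKL
assert SDModelType["lora"].value is dict
assert SDModelType["textual_inversion"].value is dict
# Note: members that share a value (dict) are aliases in Python's Enum, so
# iterating over SDModelType yields lora but not textual_inversion or ckpt.
```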
@@ -192,6 +197,15 @@ class ModelCache(object):
        return self.ModelLocker(self, key, model, gpu_load)

    def uncache_model(self, key: str):
        '''Remove corresponding model from the cache'''
        if key is not None and key in self.models:
            with contextlib.suppress(ValueError):
                del self.models[key]
                del self.locked_models[key]
                self.stack.remove(key)
                self.loaded_models.remove(key)

    class ModelLocker(object):
        def __init__(self, cache, key, model, gpu_load):
            self.gpu_load = gpu_load
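uncache_model() only clears the cache's own bookkeeping; the manager calls it when a models.yaml entry is clobbered, as the add_model() hunk further down shows. A toy, self-contained sketch of that bookkeeping (the attribute types are assumptions, since their definitions are not part of this diff):

```python
import contextlib

class ToyCache:
    """Stand-in for ModelCache's bookkeeping structures (types assumed)."""
    def __init__(self):
        self.models = {}            # key -> loaded model
        self.locked_models = {}     # key -> lock count
        self.stack = []             # recency ordering of keys
        self.loaded_models = set()  # keys currently resident on the GPU

    def uncache_model(self, key: str):
        if key is not None and key in self.models:
            # suppress ValueError (list.remove) and KeyError (dict/set) for missing entries
            with contextlib.suppress(ValueError, KeyError):
                del self.models[key]
                del self.locked_models[key]
                self.stack.remove(key)
                self.loaded_models.remove(key)

cache = ToyCache()
cache.models["abc123"] = object()
cache.locked_models["abc123"] = 0
cache.stack.append("abc123")
cache.loaded_models.add("abc123")
cache.uncache_model("abc123")
assert "abc123" not in cache.models and "abc123" not in cache.stack
```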
@@ -301,7 +315,7 @@ class ModelCache(object):
        :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
        '''
        path = Path(repo_id_or_path)
        return path.is_file() and path.suffix in [".ckpt",".safetensors"]
        return path.suffix in [".ckpt",".safetensors",".pt"]

    @classmethod
    def scan_model(cls, model_name, checkpoint):
@@ -368,6 +382,8 @@ class ModelCache(object):
        :param model_class: class of model to return, defaults to StableDiffusionGeneratorPipeline
        :param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
        '''
        # silence transformer and diffuser warnings
        with SilenceWarnings():
            if self.is_legacy_ckpt(repo_id_or_path):
                model = self._load_ckpt_from_storage(repo_id_or_path, legacy_info)
            else:
@@ -404,8 +420,6 @@ class ModelCache(object):
            if model_class in DiffusionClasses\
            else {}

        # silence transformer and diffuser warnings
        with SilenceWarnings():
        for rev in revisions:
            try:
                model = model_class.from_pretrained(
@@ -429,11 +443,14 @@ class ModelCache(object):
        :param ckpt_path: string or Path pointing to the weights file (.ckpt or .safetensors)
        :param legacy_info: LegacyInfo object containing paths to legacy config file and alternate vae if required
        '''
        assert legacy_info is not None
        if legacy_info is None or legacy_info.config_file is None:
            if Path(ckpt_path).suffix == '.safetensors':
                return safetensors.torch.load_file(ckpt_path)
            else:
                return torch.load(ckpt_path)
        else:
            # deferred loading to avoid circular import errors
            from .convert_ckpt_to_diffusers import load_pipeline_from_original_stable_diffusion_ckpt
        with SilenceWarnings():
            pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
                checkpoint_path=ckpt_path,
                original_config_file=legacy_info.config_file,
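The new early-return branch is what lets LoRA, textual-inversion and bare checkpoint files come back as raw state dicts whenever no legacy config file is supplied. A standalone sketch of just that branch (the map_location argument is added for illustration and is not in the diff):

```python
from pathlib import Path

import torch
import safetensors.torch

def load_tensor_dict(ckpt_path: str) -> dict:
    """Return the raw weights of a .safetensors/.ckpt/.pt file as {name: tensor}."""
    path = Path(ckpt_path)
    if path.suffix == ".safetensors":
        # pickle-free loading of a safetensors state dict
        return safetensors.torch.load_file(path)
    # .ckpt and .pt files are ordinary pickled checkpoints
    return torch.load(path, map_location="cpu")
```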
@@ -446,7 +463,7 @@ class ModelCache(object):
    def _legacy_model_hash(self, checkpoint_path: Union[str,Path])->str:
        sha = hashlib.sha256()
        path = Path(checkpoint_path)
        assert path.is_file()
        assert path.is_file(),f"File {checkpoint_path} not found"
        hashpath = path.parent / f"{path.name}.sha256"
        if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
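The tail of this hunk shows the sidecar-hash pattern used for legacy checkpoints: the SHA-256 digest is stored next to the weights file and reused as long as the weights have not been modified since. A hedged sketch of the whole pattern (the chunked read and the write-back step are assumptions; the diff only shows the mtime check):

```python
import hashlib
from pathlib import Path

def legacy_model_hash(checkpoint_path: str) -> str:
    path = Path(checkpoint_path)
    assert path.is_file(), f"File {checkpoint_path} not found"
    hashpath = path.parent / f"{path.name}.sha256"
    # Reuse the cached digest unless the weights file is newer than the sidecar.
    if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
        return hashpath.read_text().strip()
    sha = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # 1 MiB chunks (illustrative)
            sha.update(chunk)
    digest = sha.hexdigest()
    hashpath.write_text(digest)
    return digest
```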


@@ -32,6 +32,55 @@ the manager's "cache" attribute.
Other methods provided by ModelManager support importing, editing,
converting and deleting models.

The general format of a models.yaml section is:

name-of-model:
  format: diffusers|ckpt|vae|text_encoder|tokenizer...
  repo_id: owner/repo
  path: /path/to/local/file/or/directory
  subfolder: subfolder-name
  submodel: vae|text_encoder|tokenizer...

The format is one of {diffusers, ckpt, vae, text_encoder, tokenizer,
unet, scheduler, safety_checker, feature_extractor}, and corresponds to
an item in the SDModelType enum defined in model_cache.py.

One, but not both, of repo_id and path is provided. repo_id is the
HuggingFace repository ID of the model, and path points to the file or
directory on disk.

If subfolder is provided, then the model exists in a subdirectory of
the main model. These are usually named after the model type, such as
"unet".

Finally, if submodel is provided, then the path/repo_id is treated as
a diffusers model, the whole thing is read into memory, and then the
requested part (e.g. "unet") is retrieved.

This summarizes the three ways of getting a non-diffusers model:

clip-test-1:
  format: text_encoder
  repo_id: openai/clip-vit-large-patch14
  description: Returns standalone CLIPTextModel

clip-test-2:
  format: diffusers
  repo_id: stabilityai/stable-diffusion-2
  submodel: text_encoder
  description: Returns the text_encoder part of the whole diffusers model (whole thing in RAM)

clip-test-3:
  format: text_encoder
  repo_id: stabilityai/stable-diffusion-2
  subfolder: text_encoder
  description: Returns the text_encoder in the subfolder of the diffusers model (just the encoder in RAM)

clip-token:
  format: tokenizer
  repo_id: openai/clip-vit-large-patch14
  description: Returns standalone tokenizer
"""

from __future__ import annotations
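Given the lora and textual_inversion members added to SDModelType and the import_lora()/import_embedding() helpers added later in this commit, entries for the new formats would presumably follow the same pattern as the examples in the docstring above (model names and paths here are hypothetical):

my-lora:
  format: lora
  weights: /path/to/loras/my-lora.safetensors
  description: LoRA model my-lora

my-embedding:
  format: textual_inversion
  weights: /path/to/embeddings/learned_embeds.bin
  description: Textual embedding model my-embedding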
@@ -53,14 +102,14 @@ from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from invokeai.backend.globals import Globals, global_cache_dir, global_resolve_path
from .model_cache import Generator, ModelClass, ModelCache, SDModelType, ModelStatus, LegacyInfo
from .model_cache import ModelClass, ModelCache, ModelLocker, SDModelType, ModelStatus, LegacyInfo
from ..util import CUDA_DEVICE

# wanted to use pydantic here, but Generator objects not supported
@dataclass
class SDModelInfo():
    context: Generator[ModelClass, None, None]
    context: ModelLocker
    name: str
    hash: str
    location: Union[Path,str]
@@ -125,6 +174,7 @@ class ModelManager(object):
            sequential_offload = sequential_offload,
            logger = logger,
        )
        self.cache_keys = dict()
        self.logger = logger

    def valid_model(self, model_name: str) -> bool:
@@ -154,8 +204,12 @@ class ModelManager(object):
        # get the required loading info out of the config file
        mconfig = self.config[model_name]
        format = mconfig.get('format','diffusers')

        model_type = SDModelType.diffusion_pipeline
        model_parts = dict([(x.name,x) for x in SDModelType])

        legacy = None
        if format=='ckpt':
            location = global_resolve_path(mconfig.weights)
            legacy = LegacyInfo(
@@ -165,14 +219,22 @@ class ModelManager(object):
                legacy.vae_file = global_resolve_path(mconfig.vae)
        elif format=='diffusers':
            location = mconfig.get('repo_id') or mconfig.get('path')
            revision = mconfig.get('revision')
            if sm := mconfig.get('submodel'):
                submodel = model_parts[sm]
        elif format in model_parts:
            location = mconfig.get('repo_id') or mconfig.get('path') or mconfig.get('weights')
            model_type = model_parts[format]
        else:
            raise InvalidModelError(
                f'"{model_name}" has an unknown format {format}'
            )

        subfolder = mconfig.get('subfolder')
        revision = mconfig.get('revision')
        hash = self.cache.model_hash(location,revision)

        # to support the traditional way of attaching a VAE
        # to a model, we hacked in `attach_model_part`
        vae = (None,None)
        try:
            vae_id = mconfig.vae.repo_id
@@ -181,12 +243,19 @@ class ModelManager(object):
            pass

        model_context = self.cache.get_model(
            location,
            model_type = model_type,
            revision = revision,
            subfolder = subfolder,
            legacy_info = legacy,
            submodel = submodel,
            attach_model_part=vae,
        )

        # in case we need to communicate information about this model
        # to the cache manager later, remember the cache key
        self.cache_keys[model_name] = model_context.key

        return SDModelInfo(
            context = model_context,
            name = model_name,
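With model_type, submodel and the cache key threaded through, a caller can request a standalone part the same way it requests a whole pipeline. A hedged usage sketch follows; the manager variable, the enclosing method name get_model(), the token batch, and the assumption that the returned ModelLocker works as a with-block are all illustrative rather than taken from this diff:

```python
# "clip-test-1" is the standalone text_encoder example from the module docstring.
info = manager.get_model("clip-test-1")          # manager: a ModelManager instance (assumed)
print(info.name, info.hash, info.location)

# If the ModelLocker stored in info.context acts as a context manager that holds
# the model on the GPU while locked (not shown in this diff), usage would be roughly:
with info.context as text_encoder:
    outputs = text_encoder(input_ids=token_ids)  # token_ids: hypothetical tokenized batch
```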
@@ -365,6 +434,7 @@ class ModelManager(object):
        attributes are incorrect or the model name is missing.
        """
        omega = self.config

        assert "format" in model_attributes, 'missing required field "format"'
        if model_attributes["format"] == "diffusers":
            assert (
@@ -373,9 +443,11 @@ class ModelManager(object):
            assert (
                "path" in model_attributes or "repo_id" in model_attributes
            ), 'model must have either the "path" or "repo_id" fields defined'
        else:
        elif model_attributes["format"] == "ckpt":
            for field in ("description", "weights", "height", "width", "config"):
                assert field in model_attributes, f"required field {field} is missing"
        else:
            assert "weights" in model_attributes and "description" in model_attributes

        assert (
            clobber or model_name not in omega
@@ -386,8 +458,9 @@ class ModelManager(object):
        if "weights" in omega[model_name]:
            omega[model_name]["weights"].replace("\\", "/")

        if clobber:
        if clobber and model_name in self.cache_keys:
            self._invalidate_cached_model(model_name)
            self.cache.uncache_model(self.cache_keys[model_name])
            del self.cache_keys[model_name]

    def import_diffuser_model(
        self,
@@ -425,6 +498,55 @@ class ModelManager(object):
        self.commit(commit_to_conf)
        return model_name

    def import_lora(
        self,
        path: Path,
        model_name: str=None,
        description: str=None,
    ):
        """
        Creates an entry for the indicated lora file. Call
        mgr.commit() to write out the configuration to models.yaml
        """
        path = Path(path)
        model_name = model_name or path.stem
        model_description = description or f"LoRA model {model_name}"
        self.add_model(model_name,
            dict(
                format="lora",
                weights=str(path),
                description=model_description,
            ),
            True
        )
    def import_embedding(
        self,
        path: Path,
        model_name: str=None,
        description: str=None,
    ):
        """
        Creates an entry for the indicated textual inversion embedding file. Call
        mgr.commit() to write out the configuration to models.yaml
        """
        path = Path(path)
        if path.is_dir() and (path / "learned_embeds.bin").exists():
            weights = path / "learned_embeds.bin"
        else:
            weights = path

        model_name = model_name or path.stem
        model_description = description or f"Textual embedding model {model_name}"
        self.add_model(model_name,
            dict(
                format="textual_inversion",
                weights=str(weights),
                description=model_description,
            ),
            True
        )
    @classmethod
    def probe_model_type(self, checkpoint: dict) -> SDLegacyType:
        """