mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
generalize model loading support, include loras/embeds
This commit is contained in:
parent a8cfa3565c
commit 05a27bda5e
@@ -20,13 +20,13 @@ import contextlib
 import gc
 import hashlib
 import warnings
-from collections.abc import Generator
+from collections import Counter
 from enum import Enum
 from pathlib import Path
 from typing import Sequence, Union, Tuple, types
 
 import torch
 import safetensors.torch
 from diffusers import StableDiffusionPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel
 from diffusers import logging as diffusers_logging
 from diffusers.pipelines.stable_diffusion.safety_checker import \
@@ -46,13 +46,18 @@ MAX_MODELS = 4
 # This is the mapping from the stable diffusion submodel dict key to the class
 class SDModelType(Enum):
     diffusion_pipeline=StableDiffusionGeneratorPipeline # whole thing
-    vae=AutoencoderKL # parts
+    diffusers=StableDiffusionGeneratorPipeline # same thing
+    vae=AutoencoderKL # diffusers parts
     text_encoder=CLIPTextModel
     tokenizer=CLIPTokenizer
     unet=UNet2DConditionModel
     scheduler=SchedulerMixin
     safety_checker=StableDiffusionSafetyChecker
     feature_extractor=CLIPFeatureExtractor
+    # These are all loaded as dicts of tensors
+    lora=dict
+    textual_inversion=dict
+    ckpt=dict
 
 class ModelStatus(Enum):
     unknown='unknown'
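Because `diffusers` is given the same value as `diffusion_pipeline`, Python's Enum treats it as an alias, and aliases are skipped during iteration. A minimal sketch of the semantics this relies on (the class below is an illustrative stand-in, not the real SDModelType):

    from enum import Enum

    class SDModelTypeDemo(Enum):          # simplified stand-in for SDModelType
        diffusion_pipeline = "pipeline"   # canonical member
        diffusers = "pipeline"            # alias: same value resolves to the same member
        vae = "vae"

    # An alias is the canonical member:
    assert SDModelTypeDemo.diffusers is SDModelTypeDemo.diffusion_pipeline

    # Iteration skips aliases, so a name->member map such as
    # model_parts = dict([(x.name, x) for x in SDModelType]) (built later
    # in this diff) contains "diffusion_pipeline" and "vae" but not "diffusers":
    parts = {m.name: m for m in SDModelTypeDemo}
    assert "diffusers" not in parts and "vae" in parts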
@@ -192,6 +197,15 @@ class ModelCache(object):
 
         return self.ModelLocker(self, key, model, gpu_load)
 
+    def uncache_model(self, key: str):
+        '''Remove corresponding model from the cache'''
+        if key is not None and key in self.models:
+            with contextlib.suppress(ValueError):
+                del self.models[key]
+                del self.locked_models[key]
+                self.stack.remove(key)
+                self.loaded_models.remove(key)
+
     class ModelLocker(object):
         def __init__(self, cache, key, model, gpu_load):
            self.gpu_load = gpu_load
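The method has to keep four bookkeeping structures consistent when an entry is evicted. A runnable toy version under assumed types (the real types of these attributes are not visible in this diff):

    import contextlib

    # Toy stand-ins for the cache's internals; names follow the method above,
    # list/MRU-stack types are assumptions.
    models = {"k1": object()}
    locked_models = {"k1": 0}
    stack = ["k1"]            # most-recently-used ordering
    loaded_models = ["k1"]    # keys currently loaded onto the GPU

    def uncache_model(key: str):
        if key is not None and key in models:
            # suppress(ValueError) covers list.remove() on a missing key,
            # so a partially-tracked model still gets evicted
            with contextlib.suppress(ValueError):
                del models[key]
                del locked_models[key]
                stack.remove(key)
                loaded_models.remove(key)

    uncache_model("k1")
    assert "k1" not in models and "k1" not in stack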
@@ -301,7 +315,7 @@ class ModelCache(object):
         :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
         '''
         path = Path(repo_id_or_path)
-        return path.is_file() and path.suffix in [".ckpt",".safetensors"]
+        return path.suffix in [".ckpt",".safetensors",".pt"]
 
     @classmethod
     def scan_model(cls, model_name, checkpoint):
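The new check classifies by suffix alone: the file no longer has to exist locally, and `.pt` files (the usual extension for textual-inversion embeddings) now count as legacy checkpoints. A self-contained sketch of the resulting behavior:

    from pathlib import Path

    def is_legacy_ckpt(repo_id_or_path: str) -> bool:
        # suffix-only: works before the file is downloaded
        return Path(repo_id_or_path).suffix in [".ckpt", ".safetensors", ".pt"]

    assert is_legacy_ckpt("models/foo.pt")
    assert not is_legacy_ckpt("stabilityai/stable-diffusion-2")  # repo_ids have no suffix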
@@ -368,15 +382,17 @@ class ModelCache(object):
         :param model_class: class of model to return, defaults to StableDiffusionGeneratorPipeline
         :param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
         '''
-        if self.is_legacy_ckpt(repo_id_or_path):
-            model = self._load_ckpt_from_storage(repo_id_or_path, legacy_info)
-        else:
-            model = self._load_diffusers_from_storage(
-                repo_id_or_path,
-                subfolder,
-                revision,
-                model_class,
-            )
+        # silence transformer and diffuser warnings
+        with SilenceWarnings():
+            if self.is_legacy_ckpt(repo_id_or_path):
+                model = self._load_ckpt_from_storage(repo_id_or_path, legacy_info)
+            else:
+                model = self._load_diffusers_from_storage(
+                    repo_id_or_path,
+                    subfolder,
+                    revision,
+                    model_class,
+                )
         if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
             model.enable_offload_submodels(self.execution_device)
         return model
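This hunk hoists the warning suppression up to the caller, so both loading paths run inside one `SilenceWarnings()` block. The class itself is defined outside the visible hunks; a minimal sketch of what such a context manager can look like, built from the `warnings` and `diffusers_logging` imports in the first hunk (the `transformers` logging import and all details here are assumptions, not the file's actual implementation):

    import warnings
    from diffusers import logging as diffusers_logging
    from transformers import logging as transformers_logging

    class SilenceWarnings(object):
        def __init__(self):
            # remember current verbosity so it can be restored on exit
            self.transformers_verbosity = transformers_logging.get_verbosity()
            self.diffusers_verbosity = diffusers_logging.get_verbosity()

        def __enter__(self):
            transformers_logging.set_verbosity_error()
            diffusers_logging.set_verbosity_error()
            warnings.simplefilter('ignore')

        def __exit__(self, type, value, traceback):
            transformers_logging.set_verbosity(self.transformers_verbosity)
            diffusers_logging.set_verbosity(self.diffusers_verbosity)
            warnings.simplefilter('default')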
@@ -404,21 +420,19 @@ class ModelCache(object):
             if model_class in DiffusionClasses\
             else {}
 
-        # silence transformer and diffuser warnings
-        with SilenceWarnings():
-            for rev in revisions:
-                try:
-                    model = model_class.from_pretrained(
-                        repo_id_or_path,
-                        revision=rev,
-                        subfolder=subfolder or '.',
-                        cache_dir=global_cache_dir('hub'),
-                        **extra_args,
-                    )
-                    self.logger.debug(f'Found revision {rev}')
-                    break
-                except OSError:
-                    pass
+        for rev in revisions:
+            try:
+                model = model_class.from_pretrained(
+                    repo_id_or_path,
+                    revision=rev,
+                    subfolder=subfolder or '.',
+                    cache_dir=global_cache_dir('hub'),
+                    **extra_args,
+                )
+                self.logger.debug(f'Found revision {rev}')
+                break
+            except OSError:
+                pass
         return model
 
     def _load_ckpt_from_storage(self,
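The loop tries each candidate revision in order and keeps the first one that HuggingFace can actually serve. A standalone sketch of the same fallback pattern; the example `revisions` contents and the `subfolder` value are assumptions, since the list's construction is outside this hunk:

    from diffusers import UNet2DConditionModel

    def load_first_available_revision(repo_id: str, revisions: list[str]):
        """Return the first revision of `repo_id` that can be fetched;
        `revisions` might be e.g. ['fp16', 'main']."""
        for rev in revisions:
            try:
                # missing revisions/files surface as OSError subclasses
                return UNet2DConditionModel.from_pretrained(
                    repo_id,
                    revision=rev,
                    subfolder='unet',
                )
            except OSError:
                continue
        raise OSError(f'no loadable revision of {repo_id} among {revisions}')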
@@ -429,24 +443,27 @@ class ModelCache(object):
         :param ckpt_path: string or Path pointing to the weights file (.ckpt or .safetensors)
         :param legacy_info: LegacyInfo object containing paths to legacy config file and alternate vae if required
         '''
-        assert legacy_info is not None
-
-        # deferred loading to avoid circular import errors
-        from .convert_ckpt_to_diffusers import load_pipeline_from_original_stable_diffusion_ckpt
-        with SilenceWarnings():
+        if legacy_info is None or legacy_info.config_file is None:
+            if Path(ckpt_path).suffix == '.safetensors':
+                return safetensors.torch.load_file(ckpt_path)
+            else:
+                return torch.load(ckpt_path)
+        else:
+            # deferred loading to avoid circular import errors
+            from .convert_ckpt_to_diffusers import load_pipeline_from_original_stable_diffusion_ckpt
             pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
                 checkpoint_path=ckpt_path,
                 original_config_file=legacy_info.config_file,
                 vae_path=legacy_info.vae_file,
                 return_generator_pipeline=True,
                 precision=self.precision,
             )
             return pipeline
 
     def _legacy_model_hash(self, checkpoint_path: Union[str,Path])->str:
         sha = hashlib.sha256()
         path = Path(checkpoint_path)
-        assert path.is_file()
+        assert path.is_file(),f"File {checkpoint_path} not found"
 
         hashpath = path.parent / f"{path.name}.sha256"
         if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
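`_legacy_model_hash` caches the digest in a `<name>.sha256` sidecar next to the weights file and reuses it while the weights are unmodified; the hunk cuts off before the hashing itself. A sketch of the full pattern, under the assumption that the hidden tail simply hashes the file and writes the sidecar:

    import hashlib
    from pathlib import Path

    def legacy_model_hash(checkpoint_path: str) -> str:
        """Return the sha256 of a weights file, reusing a `<name>.sha256`
        sidecar when it is newer than the weights."""
        path = Path(checkpoint_path)
        assert path.is_file(), f"File {checkpoint_path} not found"

        hashpath = path.parent / f"{path.name}.sha256"
        if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
            return hashpath.read_text()

        sha = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(2**20), b""):  # 1 MiB chunks
                sha.update(chunk)
        digest = sha.hexdigest()
        hashpath.write_text(digest)   # cache for next time
        return digest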
@@ -31,7 +31,56 @@ generation operations. The ModelCache object can be accessed using
 the manager's "cache" attribute.
 
 Other methods provided by ModelManager support importing, editing,
-converting and deleting models.
+converting and deleting models.
+
+The general format of a models.yaml section is:
+
+ name-of-model:
+     format: diffusers|ckpt|vae|text_encoder|tokenizer...
+     repo_id: owner/repo
+     path: /path/to/local/file/or/directory
+     subfolder: subfolder-name
+     submodel: vae|text_encoder|tokenizer...
+
+The format is one of {diffusers, ckpt, vae, text_encoder, tokenizer,
+unet, scheduler, safety_checker, feature_extractor}, and corresponds to
+items in the SDModelType enum defined in model_cache.py.
+
+One, but not both, of repo_id and path is provided. repo_id is the
+HuggingFace repository ID of the model, and path points to the file or
+directory on disk.
+
+If subfolder is provided, then the model exists in a subdirectory of
+the main model. These are usually named after the model type, such as
+"unet".
+
+Finally, if submodel is provided, then the path/repo_id is treated as
+a diffusers model, the whole thing is read into memory, and then the
+requested part (e.g. "unet") is retrieved.
+
+This summarizes the three ways of getting a non-diffusers model:
+
+ clip-test-1:
+     format: text_encoder
+     repo_id: openai/clip-vit-large-patch14
+     description: Returns standalone CLIPTextModel
+
+ clip-test-2:
+     format: diffusers
+     repo_id: stabilityai/stable-diffusion-2
+     submodel: text_encoder
+     description: Returns the text_encoder part of whole diffusers model (whole thing in RAM)
+
+ clip-test-3:
+     format: text_encoder
+     repo_id: stabilityai/stable-diffusion-2
+     subfolder: text_encoder
+     description: Returns the text_encoder in the subfolder of the diffusers model (just the encoder in RAM)
+
+ clip-token:
+     format: tokenizer
+     repo_id: openai/clip-vit-large-patch14
+     description: Returns standalone tokenizer
 """
 from __future__ import annotations
 
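The docstring stops at the CLIP examples, but the lora and textual_inversion formats added to SDModelType are written out by import_lora() and import_embedding() later in this commit. Entries of that shape would look like this (paths hypothetical; field names taken from those methods):

 my-lora:
     format: lora
     weights: /path/to/loras/my-lora.safetensors
     description: LoRA model my-lora

 my-embedding:
     format: textual_inversion
     weights: /path/to/embeddings/my-embedding.pt
     description: Textual embedding model my-embedding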
@@ -53,14 +102,14 @@ from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
 
 from invokeai.backend.globals import Globals, global_cache_dir, global_resolve_path
-from .model_cache import Generator, ModelClass, ModelCache, SDModelType, ModelStatus, LegacyInfo
+from .model_cache import ModelClass, ModelCache, ModelLocker, SDModelType, ModelStatus, LegacyInfo
 
 from ..util import CUDA_DEVICE
 
 # wanted to use pydantic here, but Generator objects not supported
 @dataclass
 class SDModelInfo():
-    context: Generator[ModelClass, None, None]
+    context: ModelLocker
     name: str
     hash: str
     location: Union[Path,str]
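The context field changes from a generator to a ModelLocker, the object returned by the cache in an earlier hunk. Only its constructor signature is visible here; a sketch of what the class presumably provides, namely a context manager that pins the model on the execution device for the duration of a generation step (an assumption, not the file's implementation):

    class ModelLocker(object):
        '''Sketch of the locking context manager.'''
        def __init__(self, cache, key, model, gpu_load):
            self.cache, self.key, self.model, self.gpu_load = cache, key, model, gpu_load

        def __enter__(self):
            if self.gpu_load:
                self.model.to(self.cache.execution_device)  # pin weights on the GPU
            return self.model

        def __exit__(self, *args):
            pass  # the real class presumably reference-counts locks before offloading

    # which is what makes the dataclass field usable as:
    #     with model_info.context as model:
    #         ... run inference ...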
@@ -125,6 +174,7 @@ class ModelManager(object):
             sequential_offload = sequential_offload,
             logger = logger,
         )
+        self.cache_keys = dict()
         self.logger = logger
 
     def valid_model(self, model_name: str) -> bool:
@@ -154,8 +204,12 @@ class ModelManager(object):
 
         # get the required loading info out of the config file
         mconfig = self.config[model_name]
 
         format = mconfig.get('format','diffusers')
+        model_type = SDModelType.diffusion_pipeline
+        model_parts = dict([(x.name,x) for x in SDModelType])
         legacy = None
 
         if format=='ckpt':
             location = global_resolve_path(mconfig.weights)
             legacy = LegacyInfo(
@@ -165,14 +219,22 @@ class ModelManager(object):
             legacy.vae_file = global_resolve_path(mconfig.vae)
         elif format=='diffusers':
             location = mconfig.get('repo_id') or mconfig.get('path')
+            revision = mconfig.get('revision')
+            if sm := mconfig.get('submodel'):
+                submodel = model_parts[sm]
+        elif format in model_parts:
+            location = mconfig.get('repo_id') or mconfig.get('path') or mconfig.get('weights')
+            model_type = model_parts[format]
         else:
             raise InvalidModelError(
                 f'"{model_name}" has an unknown format {format}'
             )
 
         subfolder = mconfig.get('subfolder')
-        revision = mconfig.get('revision')
         hash = self.cache.model_hash(location,revision)
 
+        # to support the traditional way of attaching a VAE
+        # to a model, we hacked in `attach_model_part`
+        vae = (None,None)
         try:
             vae_id = mconfig.vae.repo_id
@@ -181,12 +243,19 @@ class ModelManager(object):
             pass
 
         model_context = self.cache.get_model(
             location,
+            model_type = model_type,
             revision = revision,
             subfolder = subfolder,
             legacy_info = legacy,
+            submodel = submodel,
+            attach_model_part=vae,
         )
 
+        # in case we need to communicate information about this
+        # model to the cache manager, then we need to remember
+        # the cache key
+        self.cache_keys[model_name] = model_context.key
+
         return SDModelInfo(
             context = model_context,
             name = model_name,
@@ -365,6 +434,7 @@ class ModelManager(object):
         attributes are incorrect or the model name is missing.
         """
         omega = self.config
 
+        assert "format" in model_attributes, 'missing required field "format"'
         if model_attributes["format"] == "diffusers":
             assert (
@@ -373,9 +443,11 @@ class ModelManager(object):
             assert (
                 "path" in model_attributes or "repo_id" in model_attributes
             ), 'model must have either the "path" or "repo_id" fields defined'
-        else:
+        elif model_attributes["format"] == "ckpt":
             for field in ("description", "weights", "height", "width", "config"):
                 assert field in model_attributes, f"required field {field} is missing"
+        else:
+            assert "weights" in model_attributes and "description" in model_attributes
 
         assert (
             clobber or model_name not in omega
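The validation now has three branches: diffusers entries need a path or repo_id, ckpt entries need the full legacy field set, and any other format (lora, textual_inversion, ...) needs only weights and description. Attribute dicts that would pass each branch (values illustrative; only the required keys come from the asserts above):

    diffusers_attrs = dict(
        format="diffusers",
        repo_id="stabilityai/stable-diffusion-2",
        description="a diffusers folder-style model",
    )
    ckpt_attrs = dict(
        format="ckpt",
        description="a legacy checkpoint",
        weights="/path/to/model.ckpt",
        height=512, width=512,
        config="/path/to/v1-inference.yaml",
    )
    lora_attrs = dict(  # any other format only needs weights + description
        format="lora",
        weights="/path/to/my-lora.safetensors",
        description="LoRA model my-lora",
    )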
@@ -385,9 +457,10 @@ class ModelManager(object):
 
         if "weights" in omega[model_name]:
             omega[model_name]["weights"].replace("\\", "/")
 
-        if clobber:
-            self._invalidate_cached_model(model_name)
+        if clobber and model_name in self.cache_keys:
+            self.cache.uncache_model(self.cache_keys[model_name])
+            del self.cache_keys[model_name]
 
     def import_diffuser_model(
         self,
@@ -425,6 +498,55 @@ class ModelManager(object):
         self.commit(commit_to_conf)
         return model_name
 
+    def import_lora(
+        self,
+        path: Path,
+        model_name: str=None,
+        description: str=None,
+    ):
+        """
+        Creates an entry for the indicated lora file. Call
+        mgr.commit() to write out the configuration to models.yaml
+        """
+        path = Path(path)
+        model_name = model_name or path.stem
+        model_description = description or f"LoRA model {model_name}"
+        self.add_model(model_name,
+                       dict(
+                           format="lora",
+                           weights=str(path),
+                           description=model_description,
+                       ),
+                       True
+        )
+
+    def import_embedding(
+        self,
+        path: Path,
+        model_name: str=None,
+        description: str=None,
+    ):
+        """
+        Creates an entry for the indicated textual inversion embedding file.
+        Call mgr.commit() to write out the configuration to models.yaml
+        """
+        path = Path(path)
+        if path.is_dir() and (path / "learned_embeds.bin").exists():
+            weights = path / "learned_embeds.bin"
+        else:
+            weights = path
+
+        model_name = model_name or path.stem
+        model_description = description or f"Textual embedding model {model_name}"
+        self.add_model(model_name,
+                       dict(
+                           format="textual_inversion",
+                           weights=str(weights),
+                           description=model_description,
+                       ),
+                       True
+        )
+
     @classmethod
     def probe_model_type(self, checkpoint: dict) -> SDLegacyType:
         """
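A hypothetical session using the two new import helpers; the ModelManager constructor arguments are abbreviated here, but commit() and both helpers appear in this diff:

    from omegaconf import OmegaConf

    mgr = ModelManager(OmegaConf.load('configs/models.yaml'))
    mgr.import_lora('/path/to/loras/detail-tweaker.safetensors')
    mgr.import_embedding('/path/to/embeddings/my-style')  # a directory holding learned_embeds.bin also works
    mgr.commit('configs/models.yaml')  # persist the new lora/textual_inversion entries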