generalize model loading support, include loras/embeds

commit 05a27bda5e (parent a8cfa3565c)
model_cache.py

@@ -20,13 +20,13 @@ import contextlib
 import gc
 import hashlib
 import warnings
-from collections.abc import Generator
 from collections import Counter
 from enum import Enum
 from pathlib import Path
 from typing import Sequence, Union, Tuple, types
 
 import torch
+import safetensors.torch
 from diffusers import StableDiffusionPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel
 from diffusers import logging as diffusers_logging
 from diffusers.pipelines.stable_diffusion.safety_checker import \
@@ -46,13 +46,18 @@ MAX_MODELS = 4
 # This is the mapping from the stable diffusion submodel dict key to the class
 class SDModelType(Enum):
     diffusion_pipeline=StableDiffusionGeneratorPipeline # whole thing
-    vae=AutoencoderKL # parts
+    diffusers=StableDiffusionGeneratorPipeline # same thing
+    vae=AutoencoderKL # diffusers parts
     text_encoder=CLIPTextModel
     tokenizer=CLIPTokenizer
     unet=UNet2DConditionModel
     scheduler=SchedulerMixin
     safety_checker=StableDiffusionSafetyChecker
     feature_extractor=CLIPFeatureExtractor
+    # These are all loaded as dicts of tensors
+    lora=dict
+    textual_inversion=dict
+    ckpt=dict
 
 class ModelStatus(Enum):
     unknown='unknown'
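
One subtlety in the enum above: Python Enum members that share a value become aliases of the first member with that value, and aliases are skipped during iteration. That applies to diffusers/diffusion_pipeline as well as the three dict-valued entries. A minimal sketch of the behavior (the Demo enum is illustrative, not from the commit):

    from enum import Enum

    class Demo(Enum):
        lora = dict
        textual_inversion = dict   # same value as lora -> an alias, not a new member

    print(Demo.textual_inversion is Demo.lora)   # True
    print([m.name for m in Demo])                # ['lora'] -- aliases skipped
    print(Demo["textual_inversion"].name)        # 'lora'

Any name-based lookup built by iterating SDModelType will therefore only contain the first member of each value group.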
@@ -192,6 +197,15 @@ class ModelCache(object):
 
         return self.ModelLocker(self, key, model, gpu_load)
 
+    def uncache_model(self, key: str):
+        '''Remove corresponding model from the cache'''
+        if key is not None and key in self.models:
+            with contextlib.suppress(ValueError):
+                del self.models[key]
+                del self.locked_models[key]
+                self.stack.remove(key)
+                self.loaded_models.remove(key)
+
     class ModelLocker(object):
         def __init__(self, cache, key, model, gpu_load):
             self.gpu_load = gpu_load
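
Note that contextlib.suppress exits its block at the first suppressed exception, so if one of the removals above raises ValueError the later ones are skipped. A short demonstration (the lists are stand-ins for the cache's bookkeeping structures):

    import contextlib

    stack, loaded = [1, 2], [2]
    with contextlib.suppress(ValueError):
        stack.remove(3)    # raises ValueError -> the block exits here
        loaded.remove(2)   # never reached
    print(loaded)          # [2] -- still holds the key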
@@ -301,7 +315,7 @@ class ModelCache(object):
         :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
         '''
         path = Path(repo_id_or_path)
-        return path.is_file() and path.suffix in [".ckpt",".safetensors"]
+        return path.suffix in [".ckpt",".safetensors",".pt"]
 
     @classmethod
     def scan_model(cls, model_name, checkpoint):
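
With is_file() dropped, the check now classifies purely by suffix, so a path that is not (or not yet) on disk still counts as a legacy checkpoint. A quick illustration with a made-up filename:

    from pathlib import Path

    path = Path("not-downloaded-yet.pt")   # hypothetical; nothing on disk
    print(path.suffix in [".ckpt", ".safetensors", ".pt"])   # True
    print(path.is_file())                                    # False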
@@ -368,6 +382,8 @@ class ModelCache(object):
         :param model_class: class of model to return, defaults to StableDiffusionGeneratorPIpeline
         :param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
         '''
+        # silence transformer and diffuser warnings
+        with SilenceWarnings():
             if self.is_legacy_ckpt(repo_id_or_path):
                 model = self._load_ckpt_from_storage(repo_id_or_path, legacy_info)
             else:
@@ -404,8 +420,6 @@ class ModelCache(object):
             if model_class in DiffusionClasses\
             else {}
 
-        # silence transformer and diffuser warnings
-        with SilenceWarnings():
         for rev in revisions:
             try:
                 model = model_class.from_pretrained(
@@ -429,11 +443,14 @@ class ModelCache(object):
         :param ckpt_path: string or Path pointing to the weights file (.ckpt or .safetensors)
         :param legacy_info: LegacyInfo object containing paths to legacy config file and alternate vae if required
         '''
-        assert legacy_info is not None
+        if legacy_info is None or legacy_info.config_file is None:
+            if Path(ckpt_path).suffix == '.safetensors':
+                return safetensors.torch.load_file(ckpt_path)
+            else:
+                return torch.load(ckpt_path)
+        else:
             # deferred loading to avoid circular import errors
             from .convert_ckpt_to_diffusers import load_pipeline_from_original_stable_diffusion_ckpt
-        with SilenceWarnings():
             pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
                 checkpoint_path=ckpt_path,
                 original_config_file=legacy_info.config_file,
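
The new early-return branch amounts to a small suffix-dispatched state-dict loader. A self-contained sketch of the same idea (the map_location argument is an assumption, not in the commit):

    from pathlib import Path

    import safetensors.torch
    import torch

    def load_state_dict(ckpt_path: str) -> dict:
        if Path(ckpt_path).suffix == ".safetensors":
            # safetensors never unpickles code, so this is safe on untrusted files
            return safetensors.torch.load_file(ckpt_path)
        # torch.load runs the pickle machinery; only use on trusted checkpoints
        return torch.load(ckpt_path, map_location="cpu")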
@@ -446,7 +463,7 @@ class ModelCache(object):
     def _legacy_model_hash(self, checkpoint_path: Union[str,Path])->str:
         sha = hashlib.sha256()
         path = Path(checkpoint_path)
-        assert path.is_file()
+        assert path.is_file(),f"File {checkpoint_path} not found"
 
         hashpath = path.parent / f"{path.name}.sha256"
         if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
model_manager.py

@@ -32,6 +32,55 @@ the manager's "cache" attribute.
 
 Other methods provided by ModelManager support importing, editing,
 converting and deleting models.
+
+The general format of a models.yaml section is:
+
+ name-of-model:
+   format: diffusers|ckpt|vae|text_encoder|tokenizer...
+   repo_id: owner/repo
+   path: /path/to/local/file/or/directory
+   subfolder: subfolder-name
+   submodel: vae|text_encoder|tokenizer...
+
+The format is one of {diffusers, ckpt, vae, text_encoder, tokenizer,
+unet, scheduler, safety_checker, feature_extractor}, and corresponds to
+items in the SDModelType enum defined in model_cache.py
+
+One, but not both, of repo_id and path is provided. repo_id is the
+HuggingFace repository ID of the model, and path points to the file or
+directory on disk.
+
+If subfolder is provided, then the model exists in a subdirectory of
+the main model. These are usually named after the model type, such as
+"unet".
+
+Finally, if submodel is provided, then the path/repo_id is treated as
+a diffusers model, the whole thing is read into memory, and then the
+requested part (e.g. "unet") is retrieved.
+
+This summarizes the three ways of getting a non-diffusers model:
+
+ clip-test-1:
+   format: text_encoder
+   repo_id: openai/clip-vit-large-patch14
+   description: Returns standalone CLIPTextModel
+
+ clip-test-2:
+   format: diffusers
+   repo_id: stabilityai/stable-diffusion-2
+   submodel: text_encoder
+   description: Returns the text_encoder part of whole diffusers model (whole thing in RAM)
+
+ clip-test-3:
+   format: text_encoder
+   repo_id: stabilityai/stable-diffusion-2
+   subfolder: text_encoder
+   description: Returns the text_encoder in the subfolder of the diffusers model (just the encoder in RAM)
+
+ clip-token:
+   format: tokenizer
+   repo_id: openai/clip-vit-large-patch14
+   description: Returns standalone tokenizer
 """
 from __future__ import annotations
 
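
A hedged sketch of how a section like the ones documented above is consumed (the YAML entry and variable names are illustrative; OmegaConf is the library the manager actually uses):

    from omegaconf import OmegaConf

    config = OmegaConf.create("""
    lora-test:
      format: lora
      weights: /path/to/pytorch_lora_weights.safetensors
      description: Example LoRA entry
    """)

    mconfig = config["lora-test"]
    fmt = mconfig.get("format", "diffusers")   # a missing format falls back to diffusers
    print(fmt, mconfig.get("weights"))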
@@ -53,14 +102,14 @@ from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
 
 from invokeai.backend.globals import Globals, global_cache_dir, global_resolve_path
-from .model_cache import Generator, ModelClass, ModelCache, SDModelType, ModelStatus, LegacyInfo
+from .model_cache import ModelClass, ModelCache, ModelLocker, SDModelType, ModelStatus, LegacyInfo
 
 from ..util import CUDA_DEVICE
 
 # wanted to use pydantic here, but Generator objects not supported
 @dataclass
 class SDModelInfo():
-    context: Generator[ModelClass, None, None]
+    context: ModelLocker
     name: str
     hash: str
     location: Union[Path,str]
@@ -125,6 +174,7 @@ class ModelManager(object):
             sequential_offload = sequential_offload,
             logger = logger,
         )
+        self.cache_keys = dict()
         self.logger = logger
 
     def valid_model(self, model_name: str) -> bool:
@@ -154,8 +204,12 @@ class ModelManager(object):
 
         # get the required loading info out of the config file
         mconfig = self.config[model_name]
+
         format = mconfig.get('format','diffusers')
+        model_type = SDModelType.diffusion_pipeline
+        model_parts = dict([(x.name,x) for x in SDModelType])
         legacy = None
+
         if format=='ckpt':
             location = global_resolve_path(mconfig.weights)
             legacy = LegacyInfo(
@@ -165,14 +219,22 @@ class ModelManager(object):
             legacy.vae_file = global_resolve_path(mconfig.vae)
         elif format=='diffusers':
             location = mconfig.get('repo_id') or mconfig.get('path')
-            revision = mconfig.get('revision')
+            if sm := mconfig.get('submodel'):
+                submodel = model_parts[sm]
+        elif format in model_parts:
+            location = mconfig.get('repo_id') or mconfig.get('path') or mconfig.get('weights')
+            model_type = model_parts[format]
         else:
             raise InvalidModelError(
                 f'"{model_name}" has an unknown format {format}'
             )
 
         subfolder = mconfig.get('subfolder')
+        revision = mconfig.get('revision')
         hash = self.cache.model_hash(location,revision)
+
+        # to support the traditional way of attaching a VAE
+        # to a model, we hacked in `attach_model_part`
         vae = (None,None)
         try:
             vae_id = mconfig.vae.repo_id
@@ -181,12 +243,19 @@ class ModelManager(object):
             pass
         model_context = self.cache.get_model(
             location,
+            model_type = model_type,
             revision = revision,
             subfolder = subfolder,
             legacy_info = legacy,
             submodel = submodel,
             attach_model_part=vae,
         )
+
+        # in case we need to communicate information about this
+        # model to the cache manager, then we need to remember
+        # the cache key
+        self.cache_keys[model_name] = model_context.key
+
         return SDModelInfo(
             context = model_context,
             name = model_name,
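
The dispatch above reduces to three branches keyed off format, with the walrus operator narrowing a diffusers load to a single submodel. A standalone sketch under assumed names (model_parts mirrors the name-to-member lookup built earlier; the return values are just labels):

    model_parts = {"vae": "AutoencoderKL", "lora": dict}   # stand-in lookup

    def resolve(mconfig: dict):
        fmt = mconfig.get("format", "diffusers")
        submodel = None
        if fmt == "ckpt":
            return ("legacy checkpoint", None)
        elif fmt == "diffusers":
            if sm := mconfig.get("submodel"):   # walrus: bind and test in one step
                submodel = model_parts[sm]
            return ("whole pipeline", submodel)
        elif fmt in model_parts:
            return ("single part", model_parts[fmt])
        raise ValueError(f"unknown format {fmt}")

    print(resolve({"format": "lora", "weights": "x.safetensors"}))
    # ('single part', <class 'dict'>)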
@@ -365,6 +434,7 @@ class ModelManager(object):
         attributes are incorrect or the model name is missing.
         """
         omega = self.config
+
         assert "format" in model_attributes, 'missing required field "format"'
         if model_attributes["format"] == "diffusers":
             assert (
@@ -373,9 +443,11 @@ class ModelManager(object):
             assert (
                 "path" in model_attributes or "repo_id" in model_attributes
             ), 'model must have either the "path" or "repo_id" fields defined'
-        else:
+        elif model_attributes["format"] == "ckpt":
             for field in ("description", "weights", "height", "width", "config"):
                 assert field in model_attributes, f"required field {field} is missing"
+        else:
+            assert "weights" in model_attributes and "description" in model_attributes
 
         assert (
             clobber or model_name not in omega
@@ -386,8 +458,9 @@ class ModelManager(object):
         if "weights" in omega[model_name]:
             omega[model_name]["weights"].replace("\\", "/")
 
-        if clobber:
-            self._invalidate_cached_model(model_name)
+        if clobber and model_name in self.cache_keys:
+            self.cache.uncache_model(self.cache_keys[model_name])
+            del self.cache_keys[model_name]
 
     def import_diffuser_model(
         self,
@@ -425,6 +498,55 @@ class ModelManager(object):
         self.commit(commit_to_conf)
         return model_name
 
+    def import_lora(
+        self,
+        path: Path,
+        model_name: str=None,
+        description: str=None,
+    ):
+        """
+        Creates an entry for the indicated lora file. Call
+        mgr.commit() to write out the configuration to models.yaml
+        """
+        path = Path(path)
+        model_name = model_name or path.stem
+        model_description = description or f"LoRA model {model_name}"
+        self.add_model(model_name,
+                       dict(
+                           format="lora",
+                           weights=str(path),
+                           description=model_description,
+                       ),
+                       True
+        )
+
+    def import_embedding(
+        self,
+        path: Path,
+        model_name: str=None,
+        description: str=None,
+    ):
+        """
+        Creates an entry for the indicated embedding file. Call
+        mgr.commit() to write out the configuration to models.yaml
+        """
+        path = Path(path)
+        if path.is_dir() and (path / "learned_embeds.bin").exists():
+            weights = path / "learned_embeds.bin"
+        else:
+            weights = path
+
+        model_name = model_name or path.stem
+        model_description = description or f"Textual embedding model {model_name}"
+        self.add_model(model_name,
+                       dict(
+                           format="textual_inversion",
+                           weights=str(weights),
+                           description=model_description,
+                       ),
+                       True
+        )
+
     @classmethod
     def probe_model_type(self, checkpoint: dict) -> SDLegacyType:
         """
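
Hypothetical usage of the two new importers (method and field names are real; the paths, file names, and constructor arguments are made up). Both helpers only stage a models.yaml entry; commit() persists it:

    from omegaconf import OmegaConf

    # ModelManager comes from the module patched above; constructor args are assumed.
    mgr = ModelManager(OmegaConf.load("models.yaml"))
    mgr.import_lora("/loras/detail-tweaker.safetensors")   # entry named after the file stem
    mgr.import_embedding(
        "/embeddings/night-style.pt",
        model_name="night-style",
        description="example textual inversion",
    )
    mgr.commit("models.yaml")   # write the configuration to disk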