From 05a27bda5e4375630d1ff8255bf2373d547f8a5f Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 6 May 2023 15:58:44 -0400 Subject: [PATCH] generalize model loading support, include loras/embeds --- .../backend/model_management/model_cache.py | 87 ++++++----- .../backend/model_management/model_manager.py | 138 +++++++++++++++++- 2 files changed, 182 insertions(+), 43 deletions(-) diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py index 265d363475..c316c4292f 100644 --- a/invokeai/backend/model_management/model_cache.py +++ b/invokeai/backend/model_management/model_cache.py @@ -20,13 +20,13 @@ import contextlib import gc import hashlib import warnings -from collections.abc import Generator from collections import Counter from enum import Enum from pathlib import Path from typing import Sequence, Union, Tuple, types import torch +import safetensors.torch from diffusers import StableDiffusionPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel from diffusers import logging as diffusers_logging from diffusers.pipelines.stable_diffusion.safety_checker import \ @@ -46,13 +46,18 @@ MAX_MODELS = 4 # This is the mapping from the stable diffusion submodel dict key to the class class SDModelType(Enum): diffusion_pipeline=StableDiffusionGeneratorPipeline # whole thing - vae=AutoencoderKL # parts + diffusers=StableDiffusionGeneratorPipeline # same thing + vae=AutoencoderKL # diffusers parts text_encoder=CLIPTextModel tokenizer=CLIPTokenizer unet=UNet2DConditionModel scheduler=SchedulerMixin safety_checker=StableDiffusionSafetyChecker feature_extractor=CLIPFeatureExtractor + # These are all loaded as dicts of tensors + lora=dict + textual_inversion=dict + ckpt=dict class ModelStatus(Enum): unknown='unknown' @@ -192,6 +197,15 @@ class ModelCache(object): return self.ModelLocker(self, key, model, gpu_load) + def uncache_model(self, key: str): + '''Remove corresponding model from the cache''' + if key is 
not None and key in self.models: + with contextlib.suppress(ValueError): + del self.models[key] + del self.locked_models[key] + self.stack.remove(key) + self.loaded_models.remove(key) + class ModelLocker(object): def __init__(self, cache, key, model, gpu_load): self.gpu_load = gpu_load @@ -301,7 +315,7 @@ class ModelCache(object): :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model ''' path = Path(repo_id_or_path) - return path.is_file() and path.suffix in [".ckpt",".safetensors"] + return path.suffix in [".ckpt",".safetensors",".pt"] @classmethod def scan_model(cls, model_name, checkpoint): @@ -368,15 +382,17 @@ class ModelCache(object): :param model_class: class of model to return, defaults to StableDiffusionGeneratorPIpeline :param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt ''' - if self.is_legacy_ckpt(repo_id_or_path): - model = self._load_ckpt_from_storage(repo_id_or_path, legacy_info) - else: - model = self._load_diffusers_from_storage( - repo_id_or_path, - subfolder, - revision, - model_class, - ) + # silence transformer and diffuser warnings + with SilenceWarnings(): + if self.is_legacy_ckpt(repo_id_or_path): + model = self._load_ckpt_from_storage(repo_id_or_path, legacy_info) + else: + model = self._load_diffusers_from_storage( + repo_id_or_path, + subfolder, + revision, + model_class, + ) if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline): model.enable_offload_submodels(self.execution_device) return model @@ -404,21 +420,19 @@ class ModelCache(object): if model_class in DiffusionClasses\ else {} - # silence transformer and diffuser warnings - with SilenceWarnings(): - for rev in revisions: - try: - model = model_class.from_pretrained( - repo_id_or_path, - revision=rev, - subfolder=subfolder or '.', - cache_dir=global_cache_dir('hub'), - **extra_args, - ) - self.logger.debug(f'Found revision {rev}') - break - except OSError: - pass + for rev in 
revisions: + try: + model = model_class.from_pretrained( + repo_id_or_path, + revision=rev, + subfolder=subfolder or '.', + cache_dir=global_cache_dir('hub'), + **extra_args, + ) + self.logger.debug(f'Found revision {rev}') + break + except OSError: + pass return model def _load_ckpt_from_storage(self, @@ -429,24 +443,27 @@ class ModelCache(object): :param ckpt_path: string or Path pointing to the weights file (.ckpt or .safetensors) :param legacy_info: LegacyInfo object containing paths to legacy config file and alternate vae if required ''' - assert legacy_info is not None - - # deferred loading to avoid circular import errors - from .convert_ckpt_to_diffusers import load_pipeline_from_original_stable_diffusion_ckpt - with SilenceWarnings(): + if legacy_info is None or legacy_info.config_file is None: + if Path(ckpt_path).suffix == '.safetensors': + return safetensors.torch.load_file(ckpt_path) + else: + return torch.load(ckpt_path) + else: + # deferred loading to avoid circular import errors + from .convert_ckpt_to_diffusers import load_pipeline_from_original_stable_diffusion_ckpt pipeline = load_pipeline_from_original_stable_diffusion_ckpt( checkpoint_path=ckpt_path, original_config_file=legacy_info.config_file, - vae_path=legacy_info.vae_file, + vae_path=legacy_info.vae_file, return_generator_pipeline=True, precision=self.precision, ) - return pipeline + return pipeline def _legacy_model_hash(self, checkpoint_path: Union[str,Path])->str: sha = hashlib.sha256() path = Path(checkpoint_path) - assert path.is_file() + assert path.is_file(),f"File {checkpoint_path} not found" hashpath = path.parent / f"{path.name}.sha256" if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime: diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 8b6704eb8a..8929d8cdd6 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ 
-31,7 +31,56 @@ generation operations. The ModelCache object can be accessed using the manager's "cache" attribute. Other methods provided by ModelManager support importing, editing, -converting and deleting models. +converting and deleting models. + +The general format of a models.yaml section is: + + name-of-model: + format: diffusers|ckpt|vae|text_encoder|tokenizer... + repo_id: owner/repo + path: /path/to/local/file/or/directory + subfolder: subfolder-name + submodel: vae|text_encoder|tokenizer... + +The format is one of {diffusers, ckpt, vae, text_encoder, tokenizer, +unet, scheduler, safety_checker, feature_extractor}, and corresponds to +items in the SDModelType enum defined in model_cache.py. + +One, but not both, of repo_id and path is provided. repo_id is the +HuggingFace repository ID of the model, and path points to the file or +directory on disk. + +If subfolder is provided, then the model exists in a subdirectory of +the main model. These are usually named after the model type, such as +"unet". + +Finally, if submodel is provided, then the path/repo_id is treated as +a diffusers model, the whole thing is read into memory, and then the +requested part (e.g. "unet") is retrieved.
+ +This summarizes the three ways of getting a non-diffuser model: + + clip-test-1: + format: text_encoder + repo_id: openai/clip-vit-large-patch14 + description: Returns standalone CLIPTextModel + + clip-test-2: + format: diffusers + repo_id: stabilityai/stable-diffusion-2 + submodel: text_encoder + description: Returns the text_encoder part of whole diffusers model (whole thing in RAM) + + clip-test-3: + format: text_encoder + repo_id: stabilityai/stable-diffusion-2 + subfolder: text_encoder + description: Returns the text_encoder in the subfolder of the diffusers model (just the encoder in RAM) + + clip-token: + format: tokenizer + repo_id: openai/clip-vit-large-patch14 + description: Returns standalone tokenizer """ from __future__ import annotations @@ -53,14 +102,14 @@ from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig from invokeai.backend.globals import Globals, global_cache_dir, global_resolve_path -from .model_cache import Generator, ModelClass, ModelCache, SDModelType, ModelStatus, LegacyInfo +from .model_cache import ModelClass, ModelCache, ModelLocker, SDModelType, ModelStatus, LegacyInfo from ..util import CUDA_DEVICE # wanted to use pydantic here, but Generator objects not supported @dataclass class SDModelInfo(): - context: Generator[ModelClass, None, None] + context: ModelLocker name: str hash: str location: Union[Path,str] @@ -125,6 +174,7 @@ class ModelManager(object): sequential_offload = sequential_offload, logger = logger, ) + self.cache_keys = dict() self.logger = logger def valid_model(self, model_name: str) -> bool: @@ -154,8 +204,12 @@ class ModelManager(object): # get the required loading info out of the config file mconfig = self.config[model_name] + format = mconfig.get('format','diffusers') + model_type = SDModelType.diffusion_pipeline + model_parts = dict([(x.name,x) for x in SDModelType]) legacy = None + if format=='ckpt': location = global_resolve_path(mconfig.weights) legacy = LegacyInfo( @@ -165,14 +219,22 
@@ class ModelManager(object): legacy.vae_file = global_resolve_path(mconfig.vae) elif format=='diffusers': location = mconfig.get('repo_id') or mconfig.get('path') - revision = mconfig.get('revision') + if sm := mconfig.get('submodel'): + submodel = model_parts[sm] + elif format in model_parts: + location = mconfig.get('repo_id') or mconfig.get('path') or mconfig.get('weights') + model_type = model_parts[format] else: raise InvalidModelError( f'"{model_name}" has an unknown format {format}' ) subfolder = mconfig.get('subfolder') + revision = mconfig.get('revision') hash = self.cache.model_hash(location,revision) + + # to support the traditional way of attaching a VAE + # to a model, we hacked in `attach_model_part` vae = (None,None) try: vae_id = mconfig.vae.repo_id @@ -181,12 +243,19 @@ class ModelManager(object): pass model_context = self.cache.get_model( location, + model_type = model_type, revision = revision, subfolder = subfolder, legacy_info = legacy, submodel = submodel, attach_model_part=vae, ) + + # in case we need to communicate information about this + # model to the cache manager, then we need to remember + # the cache key + self.cache_keys[model_name] = model_context.key + return SDModelInfo( context = model_context, name = model_name, @@ -365,6 +434,7 @@ class ModelManager(object): attributes are incorrect or the model name is missing. 
""" omega = self.config + assert "format" in model_attributes, 'missing required field "format"' if model_attributes["format"] == "diffusers": assert ( @@ -373,9 +443,11 @@ class ModelManager(object): assert ( "path" in model_attributes or "repo_id" in model_attributes ), 'model must have either the "path" or "repo_id" fields defined' - else: + elif model_attributes["format"] == "ckpt": for field in ("description", "weights", "height", "width", "config"): assert field in model_attributes, f"required field {field} is missing" + else: + assert "weights" in model_attributes and "description" in model_attributes assert ( clobber or model_name not in omega @@ -385,9 +457,10 @@ class ModelManager(object): if "weights" in omega[model_name]: omega[model_name]["weights"].replace("\\", "/") - - if clobber: - self._invalidate_cached_model(model_name) + + if clobber and model_name in self.cache_keys: + self.cache.uncache_model(self.cache_keys[model_name]) + del self.cache_keys[model_name] def import_diffuser_model( self, @@ -425,6 +498,55 @@ class ModelManager(object): self.commit(commit_to_conf) return model_name + def import_lora( + self, + path: Path, + model_name: str=None, + description: str=None, + ): + """ + Creates an entry for the indicated lora file. Call + mgr.commit() to write out the configuration to models.yaml + """ + path = Path(path) + model_name = model_name or path.stem + model_description = description or f"LoRA model {model_name}" + self.add_model(model_name, + dict( + format="lora", + weights=str(path), + description=model_description, + ), + True + ) + + def import_embedding( + self, + path: Path, + model_name: str=None, + description: str=None, + ): + """ + Creates an entry for the indicated lora file. 
Call + mgr.commit() to write out the configuration to models.yaml + """ + path = Path(path) + if path.is_dir() and (path / "learned_embeds.bin").exists(): + weights = path / "learned_embeds.bin" + else: + weights = path + + model_name = model_name or path.stem + model_description = description or f"Textual embedding model {model_name}" + self.add_model(model_name, + dict( + format="textual_inversion", + weights=str(weights), + description=model_description, + ), + True + ) + @classmethod def probe_model_type(self, checkpoint: dict) -> SDLegacyType: """