diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management/__init__.py
index 69f449dd45..1d290050d4 100644
--- a/invokeai/backend/model_management/__init__.py
+++ b/invokeai/backend/model_management/__init__.py
@@ -7,3 +7,4 @@ from .convert_ckpt_to_diffusers import (
 )
 
 from .model_manager import ModelManager
+
diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py
index 4a2bb56270..a51a2fec22 100644
--- a/invokeai/backend/model_management/model_manager.py
+++ b/invokeai/backend/model_management/model_manager.py
@@ -1,4 +1,4 @@
-"""
+"""
 Manage a cache of Stable Diffusion model files for fast switching.
 They are moved between GPU and CPU as necessary. If CPU memory falls below
 a preset minimum, the least recently used model will be
@@ -15,7 +15,7 @@ import sys
 import textwrap
 import time
 import warnings
-from enum import Enum
+from enum import Enum, auto
 from pathlib import Path
 from shutil import move, rmtree
 from typing import Any, Optional, Union, Callable
@@ -24,8 +24,12 @@ import safetensors
 import safetensors.torch
 import torch
 import transformers
-from diffusers import AutoencoderKL
-from diffusers import logging as dlogging
+from diffusers import (
+    AutoencoderKL,
+    UNet2DConditionModel,
+    SchedulerMixin,
+    logging as dlogging,
+)
 from huggingface_hub import scan_cache_dir
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
@@ -33,37 +37,58 @@ from picklescan.scanner import scan_file_path
 
 from invokeai.backend.globals import Globals, global_cache_dir
 
-from ..stable_diffusion import StableDiffusionGeneratorPipeline
+from transformers import (
+    CLIPTextModel,
+    CLIPTokenizer,
+    CLIPFeatureExtractor,
+)
+from diffusers.pipelines.stable_diffusion.safety_checker import (
+    StableDiffusionSafetyChecker,
+)
+from ..stable_diffusion import (
+    StableDiffusionGeneratorPipeline,
+)
 from ..util import CUDA_DEVICE, ask_user, download_with_resume
 
-class SDLegacyType(Enum):
-    V1 = 1
-    V1_INPAINT = 2
-    V2 = 3
-    V2_e = 4
-    V2_v = 5
-    UNKNOWN = 99
+class SDLegacyType(Enum):
+    V1 = auto()
+    V1_INPAINT = auto()
+    V2 = auto()
+    V2_e = auto()
+    V2_v = auto()
+    UNKNOWN = auto()
+
+class SDModelComponent(Enum):
+    vae = "vae"
+    text_encoder = "text_encoder"
+    tokenizer = "tokenizer"
+    unet = "unet"
+    scheduler = "scheduler"
+    safety_checker = "safety_checker"
+    feature_extractor = "feature_extractor"
+
 DEFAULT_MAX_MODELS = 2
 
 class ModelManager(object):
-    '''
+    """
     Model manager handles loading, caching, importing, deleting,
     converting, and editing models.
-    '''
+    """
+
     def __init__(
-            self,
-            config: OmegaConf|Path,
-            device_type: torch.device = CUDA_DEVICE,
-            precision: str = "float16",
-            max_loaded_models=DEFAULT_MAX_MODELS,
-            sequential_offload=False,
-            embedding_path: Path=None,
+        self,
+        config: OmegaConf | Path,
+        device_type: torch.device = CUDA_DEVICE,
+        precision: str = "float16",
+        max_loaded_models=DEFAULT_MAX_MODELS,
+        sequential_offload=False,
+        embedding_path: Path = None,
     ):
         """
         Initialize with the path to the models.yaml config file or
         an initialized OmegaConf dictionary. Optional parameters
         are the torch device type, precision, max_loaded_models,
-        and sequential_offload boolean. Note that the default device
+        and sequential_offload boolean. Note that the default device
         type and precision are set up for a CUDA system running
         at half precision.
""" # prevent nasty-looking CLIP log message @@ -87,15 +112,25 @@ class ModelManager(object): """ return model_name in self.config - def get_model(self, model_name: str=None)->dict: - """ - Given a model named identified in models.yaml, return - the model object. If in RAM will load into GPU VRAM. - If on disk, will load from there. + def get_model(self, model_name: str = None) -> dict: + """Given a model named identified in models.yaml, return a dict + containing the model object and some of its key features. If + in RAM will load into GPU VRAM. If on disk, will load from + there. + The dict has the following keys: + 'model': The StableDiffusionGeneratorPipeline object + 'model_name': The name of the model in models.yaml + 'width': The width of images trained by this model + 'height': The height of images trained by this model + 'hash': A unique hash of this model's files on disk. """ if not model_name: - return self.get_model(self.current_model) if self.current_model else self.get_model(self.default_model()) - + return ( + self.get_model(self.current_model) + if self.current_model + else self.get_model(self.default_model()) + ) + if not self.valid_model(model_name): print( f'** "{model_name}" is not a known model name. Please check your models.yaml file' @@ -135,6 +170,81 @@ class ModelManager(object): "hash": hash, } + def get_model_vae(self, model_name: str=None)->AutoencoderKL: + """Given a model name identified in models.yaml, load the model into + GPU if necessary and return its assigned VAE as an + AutoencoderKL object. If no model name is provided, return the + vae from the model currently in the GPU. + """ + return self._get_sub_model(model_name, SDModelComponent.vae) + + def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer: + """Given a model name identified in models.yaml, load the model into + GPU if necessary and return its assigned CLIPTokenizer. If no + model name is provided, return the tokenizer from the model + currently in the GPU. + """ + return self._get_sub_model(model_name, SDModelComponent.tokenizer) + + def get_model_unet(self, model_name: str=None)->UNet2DConditionModel: + """Given a model name identified in models.yaml, load the model into + GPU if necessary and return its assigned UNet2DConditionModel. If no model + name is provided, return the UNet from the model + currently in the GPU. + """ + return self._get_sub_model(model_name, SDModelComponent.unet) + + def get_model_text_encoder(self, model_name: str=None)->CLIPTextModel: + """Given a model name identified in models.yaml, load the model into + GPU if necessary and return its assigned CLIPTextModel. If no + model name is provided, return the text encoder from the model + currently in the GPU. + """ + return self._get_sub_model(model_name, SDModelComponent.text_encoder) + + def get_model_feature_extractor(self, model_name: str=None)->CLIPFeatureExtractor: + """Given a model name identified in models.yaml, load the model into + GPU if necessary and return its assigned CLIPFeatureExtractor. If no + model name is provided, return the text encoder from the model + currently in the GPU. + """ + return self._get_sub_model(model_name, SDModelComponent.feature_extractor) + + def get_model_scheduler(self, model_name: str=None)->SchedulerMixin: + """Given a model name identified in models.yaml, load the model into + GPU if necessary and return its assigned scheduler. If no + model name is provided, return the text encoder from the model + currently in the GPU. 
+ """ + return self._get_sub_model(model_name, SDModelComponent.scheduler) + + def _get_sub_model( + self, + model_name: str=None, + model_part: SDModelComponent=SDModelComponent.vae, + ) -> Union[ + AutoencoderKL, + CLIPTokenizer, + CLIPFeatureExtractor, + UNet2DConditionModel, + CLIPTextModel, + StableDiffusionSafetyChecker, + ]: + """Given a model name identified in models.yaml, and the part of the + model you wish to retrieve, return that part. Parts are in an Enum + class named SDModelComponent, and consist of: + SDModelComponent.vae + SDModelComponent.text_encoder + SDModelComponent.tokenizer + SDModelComponent.unet + SDModelComponent.scheduler + SDModelComponent.safety_checker + SDModelComponent.feature_extractor + """ + model_dict = self.get_model(model_name) + model = model_dict["model"] + return getattr(model, model_part.value) + def default_model(self) -> str | None: """ Returns the name of the default model, or None @@ -360,7 +470,7 @@ class ModelManager(object): f"Unknown model format {model_name}: {model_format}" ) self._add_embeddings_to_model(model) - + # usage statistics toc = time.time() print(">> Model loaded in", "%4.2fs" % (toc - tic)) @@ -433,7 +543,7 @@ class ModelManager(object): width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor height = width print(f" | Default image dimensions = {width} x {height}") - + return pipeline, width, height, model_hash def _load_ckpt_model(self, model_name, mconfig): @@ -454,14 +564,18 @@ class ModelManager(object): from . import load_pipeline_from_original_stable_diffusion_ckpt try: - if self.list_models()[self.current_model]['status'] == 'active': + if self.list_models()[self.current_model]["status"] == "active": self.offload_model(self.current_model) except Exception as e: pass - + vae_path = None if vae: - vae_path = vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae)) + vae_path = ( + vae + if os.path.isabs(vae) + else os.path.normpath(os.path.join(Globals.root, vae)) + ) if self._has_cuda(): torch.cuda.empty_cache() pipeline = load_pipeline_from_original_stable_diffusion_ckpt( @@ -571,9 +685,7 @@ class ModelManager(object): models.yaml file. """ model_name = model_name or Path(repo_or_path).stem - model_description = ( - description or f"Imported diffusers model {model_name}" - ) + model_description = description or f"Imported diffusers model {model_name}" new_config = dict( description=model_description, vae=vae, @@ -602,7 +714,7 @@ class ModelManager(object): SDLegacyType.V2_v (V2 using 'v_prediction' prediction type) SDLegacyType.UNKNOWN """ - global_step = checkpoint.get('global_step') + global_step = checkpoint.get("global_step") state_dict = checkpoint.get("state_dict") or checkpoint try: @@ -628,13 +740,13 @@ class ModelManager(object): return SDLegacyType.UNKNOWN def heuristic_import( - self, - path_url_or_repo: str, - model_name: str = None, - description: str = None, - model_config_file: Path = None, - commit_to_conf: Path = None, - config_file_callback: Callable[[Path], Path] = None, + self, + path_url_or_repo: str, + model_name: str = None, + description: str = None, + model_config_file: Path = None, + commit_to_conf: Path = None, + config_file_callback: Callable[[Path], Path] = None, ) -> str: """Accept a string which could be: - a HF diffusers repo_id @@ -738,8 +850,8 @@ class ModelManager(object): # another round of heuristics to guess the correct config file. 
         checkpoint = None
-        if model_path.suffix in [".ckpt",".pt"]:
-            self.scan_model(model_path,model_path)
+        if model_path.suffix in [".ckpt", ".pt"]:
+            self.scan_model(model_path, model_path)
             checkpoint = torch.load(model_path)
         else:
             checkpoint = safetensors.torch.load_file(model_path)
@@ -761,19 +873,16 @@ class ModelManager(object):
         elif model_type == SDLegacyType.V1_INPAINT:
             print(" | SD-v1 inpainting model detected")
             model_config_file = Path(
-                Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml"
+                Globals.root,
+                "configs/stable-diffusion/v1-inpainting-inference.yaml",
             )
         elif model_type == SDLegacyType.V2_v:
-            print(
-                " | SD-v2-v model detected"
-            )
+            print(" | SD-v2-v model detected")
             model_config_file = Path(
                 Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
             )
         elif model_type == SDLegacyType.V2_e:
-            print(
-                " | SD-v2-e model detected"
-            )
+            print(" | SD-v2-e model detected")
             model_config_file = Path(
                 Globals.root, "configs/stable-diffusion/v2-inference.yaml"
             )
@@ -820,16 +929,16 @@ class ModelManager(object):
         return model_name
 
     def convert_and_import(
-            self,
-            ckpt_path: Path,
-            diffusers_path: Path,
-            model_name=None,
-            model_description=None,
-            vae:dict=None,
-            vae_path:Path=None,
-            original_config_file: Path = None,
-            commit_to_conf: Path = None,
-            scan_needed: bool=True,
+        self,
+        ckpt_path: Path,
+        diffusers_path: Path,
+        model_name=None,
+        model_description=None,
+        vae: dict = None,
+        vae_path: Path = None,
+        original_config_file: Path = None,
+        commit_to_conf: Path = None,
+        scan_needed: bool = True,
     ) -> str:
         """
         Convert a legacy ckpt weights file to diffuser model and import
@@ -857,10 +966,10 @@ class ModelManager(object):
         try:
             # By passing the specified VAE to the conversion function, the autoencoder
             # will be built into the model rather than tacked on afterward via the config file
-            vae_model=None
+            vae_model = None
             if vae:
-                vae_model=self._load_vae(vae)
-                vae_path=None
+                vae_model = self._load_vae(vae)
+                vae_path = None
             convert_ckpt_to_diffusers(
                 ckpt_path,
                 diffusers_path,
@@ -976,16 +1085,16 @@ class ModelManager(object):
         legacy_locations = [
             Path(
                 models_dir,
-                "CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker"
+                "CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker",
             ),
             Path(models_dir, "bert-base-uncased/models--bert-base-uncased"),
             Path(
                 models_dir,
-                "openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14"
+                "openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14",
             ),
         ]
-        legacy_locations.extend(list(global_cache_dir("diffusers").glob('*')))
-
+        legacy_locations.extend(list(global_cache_dir("diffusers").glob("*")))
+
         legacy_layout = False
         for model in legacy_locations:
             legacy_layout = legacy_layout or model.exists()
@@ -1003,7 +1112,7 @@ class ModelManager(object):
        >> make adjustments, please press ctrl-C now to abort and relaunch InvokeAI when you are ready.
        >> Otherwise press <enter> to continue."""
             )
-            input('continue> ')
+            input("continue> ")
 
         # transformer files get moved into the hub directory
         if cls._is_huggingface_hub_directory_present():
@@ -1090,7 +1199,7 @@ class ModelManager(object):
             print(
                 f'>> Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
             )
-
+
 
     def _has_cuda(self) -> bool:
         return self.device.type == "cuda"
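
A minimal usage sketch of the accessors added above. It is illustrative only: the config path "configs/models.yaml" and the model key "stable-diffusion-1.5" are placeholders for whatever entries exist in your own models.yaml, and the manager is constructed with the defaults shown in __init__.

    # Hypothetical usage of the new ModelManager component accessors.
    # The config path and model name below are placeholders, not part of this change.
    from pathlib import Path

    from invokeai.backend.model_management.model_manager import (
        ModelManager,
        SDModelComponent,
    )

    manager = ModelManager(Path("configs/models.yaml"), precision="float16")

    # get_model() returns a dict with 'model', 'model_name', 'width', 'height', 'hash'.
    info = manager.get_model("stable-diffusion-1.5")
    pipeline = info["model"]
    print(info["width"], info["height"], info["hash"])

    # The component accessors load the model into the GPU if necessary and
    # hand back just the requested part of the pipeline.
    vae = manager.get_model_vae("stable-diffusion-1.5")
    tokenizer = manager.get_model_tokenizer("stable-diffusion-1.5")
    unet = manager.get_model_unet("stable-diffusion-1.5")
    text_encoder = manager.get_model_text_encoder("stable-diffusion-1.5")
    scheduler = manager.get_model_scheduler("stable-diffusion-1.5")

    # Parts without a dedicated helper can be fetched through the (private)
    # _get_sub_model() dispatcher and the SDModelComponent enum.
    safety_checker = manager._get_sub_model(
        "stable-diffusion-1.5", SDModelComponent.safety_checker
    )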
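The dispatch works because each SDModelComponent value matches an attribute name on the StableDiffusionGeneratorPipeline stored under get_model()["model"], so _get_sub_model() only needs a single getattr(model, model_part.value) call; get_model_vae(name), for example, is equivalent to get_model(name)["model"].vae. Keeping the mapping in an Enum rather than in raw strings gives callers a typo-safe way to request a component.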