diff --git a/invokeai/backend/model_management/README b/invokeai/backend/model_management/README new file mode 100644 index 0000000000..c7388df72e --- /dev/null +++ b/invokeai/backend/model_management/README @@ -0,0 +1 @@ +The contents of this directory are deprecated. model_manager.py is here only for reference. diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py deleted file mode 100644 index 184a5c7733..0000000000 --- a/invokeai/backend/model_management/model_probe.py +++ /dev/null @@ -1,605 +0,0 @@ -import json -import re -from dataclasses import dataclass -from pathlib import Path -from typing import Callable, Dict, Literal, Optional, Union - -import safetensors.torch -import torch -from diffusers import ConfigMixin, ModelMixin -from picklescan.scanner import scan_file_path - -from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat - -from .models import ( - BaseModelType, - InvalidModelException, - ModelType, - ModelVariantType, - SchedulerPredictionType, - SilenceWarnings, -) -from .models.base import read_checkpoint_meta -from .util import lora_token_vector_length - - -@dataclass -class ModelProbeInfo(object): - model_type: ModelType - base_type: BaseModelType - variant_type: ModelVariantType - prediction_type: SchedulerPredictionType - upcast_attention: bool - format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"] - image_size: int - - -class ProbeBase(object): - """forward declaration""" - - pass - - -class ModelProbe(object): - PROBES = { - "diffusers": {}, - "checkpoint": {}, - "onnx": {}, - } - - CLASS2TYPE = { - "StableDiffusionPipeline": ModelType.Main, - "StableDiffusionInpaintPipeline": ModelType.Main, - "StableDiffusionXLPipeline": ModelType.Main, - "StableDiffusionXLImg2ImgPipeline": ModelType.Main, - "StableDiffusionXLInpaintPipeline": ModelType.Main, - "AutoencoderKL": ModelType.Vae, - "AutoencoderTiny": ModelType.Vae, - "ControlNetModel": ModelType.ControlNet, - "CLIPVisionModelWithProjection": ModelType.CLIPVision, - } - - @classmethod - def register_probe( - cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase - ): - cls.PROBES[format][model_type] = probe_class - - @classmethod - def heuristic_probe( - cls, - model: Union[Dict, ModelMixin, Path], - prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None, - ) -> ModelProbeInfo: - if isinstance(model, Path): - return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper) - elif isinstance(model, (dict, ModelMixin, ConfigMixin)): - return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper) - else: - raise InvalidModelException("model parameter {model} is neither a Path, nor a model") - - @classmethod - def probe( - cls, - model_path: Path, - model: Optional[Union[Dict, ModelMixin]] = None, - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - ) -> Optional[ModelProbeInfo]: - """ - Probe the model at model_path and return sufficient information about it - to place it somewhere in the models directory hierarchy. If the model is - already loaded into memory, you may provide it as model in order to avoid - opening it a second time. The prediction_type_helper callable is a function that receives - the path to the model and returns the BaseModelType. It is called to distinguish - between V2-Base and V2-768 SD models. 
- """ - if model_path: - format_type = "diffusers" if model_path.is_dir() else "checkpoint" - else: - format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint" - model_info = None - try: - model_type = ( - cls.get_model_type_from_folder(model_path, model) - if format_type == "diffusers" - else cls.get_model_type_from_checkpoint(model_path, model) - ) - format_type = "onnx" if model_type == ModelType.ONNX else format_type - probe_class = cls.PROBES[format_type].get(model_type) - if not probe_class: - return None - probe = probe_class(model_path, model, prediction_type_helper) - base_type = probe.get_base_type() - variant_type = probe.get_variant_type() - prediction_type = probe.get_scheduler_prediction_type() - format = probe.get_format() - model_info = ModelProbeInfo( - model_type=model_type, - base_type=base_type, - variant_type=variant_type, - prediction_type=prediction_type, - upcast_attention=( - base_type == BaseModelType.StableDiffusion2 - and prediction_type == SchedulerPredictionType.VPrediction - ), - format=format, - image_size=( - 1024 - if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}) - else ( - 768 - if ( - base_type == BaseModelType.StableDiffusion2 - and prediction_type == SchedulerPredictionType.VPrediction - ) - else 512 - ) - ), - ) - except Exception: - raise - - return model_info - - @classmethod - def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType: - if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"): - return None - - if model_path.name == "learned_embeds.bin": - return ModelType.TextualInversion - - ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True) - ckpt = ckpt.get("state_dict", ckpt) - - for key in ckpt.keys(): - if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}): - return ModelType.Main - elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}): - return ModelType.Vae - elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}): - return ModelType.Lora - elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}): - return ModelType.Lora - elif any(key.startswith(v) for v in {"control_model", "input_blocks"}): - return ModelType.ControlNet - elif key in {"emb_params", "string_to_param"}: - return ModelType.TextualInversion - - else: - # diffusers-ti - if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()): - return ModelType.TextualInversion - - raise InvalidModelException(f"Unable to determine model type for {model_path}") - - @classmethod - def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType: - """ - Get the model type of a hugging-face style folder. 
- """ - class_name = None - error_hint = None - if model: - class_name = model.__class__.__name__ - else: - if (folder_path / "unet/model.onnx").exists(): - return ModelType.ONNX - if (folder_path / "learned_embeds.bin").exists(): - return ModelType.TextualInversion - if (folder_path / "pytorch_lora_weights.bin").exists(): - return ModelType.Lora - if (folder_path / "image_encoder.txt").exists(): - return ModelType.IPAdapter - - i = folder_path / "model_index.json" - c = folder_path / "config.json" - config_path = i if i.exists() else c if c.exists() else None - - if config_path: - with open(config_path, "r") as file: - conf = json.load(file) - if "_class_name" in conf: - class_name = conf["_class_name"] - elif "architectures" in conf: - class_name = conf["architectures"][0] - else: - class_name = None - else: - error_hint = f"No model_index.json or config.json found in {folder_path}." - - if class_name and (type := cls.CLASS2TYPE.get(class_name)): - return type - else: - error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]" - - # give up - raise InvalidModelException( - f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "") - ) - - @classmethod - def _scan_and_load_checkpoint(cls, model_path: Path) -> dict: - with SilenceWarnings(): - if model_path.suffix.endswith((".ckpt", ".pt", ".bin")): - cls._scan_model(model_path, model_path) - return torch.load(model_path) - else: - return safetensors.torch.load_file(model_path) - - @classmethod - def _scan_model(cls, model_name, checkpoint): - """ - Apply picklescanner to the indicated checkpoint and issue a warning - and option to exit if an infected file is identified. - """ - # scan model - scan_result = scan_file_path(checkpoint) - if scan_result.infected_files != 0: - raise "The model {model_name} is potentially infected by malware. Aborting import." 
- - -# ##################################################3 -# Checkpoint probing -# ##################################################3 -class ProbeBase(object): - def get_base_type(self) -> BaseModelType: - pass - - def get_variant_type(self) -> ModelVariantType: - pass - - def get_scheduler_prediction_type(self) -> SchedulerPredictionType: - pass - - def get_format(self) -> str: - pass - - -class CheckpointProbeBase(ProbeBase): - def __init__( - self, checkpoint_path: Path, checkpoint: dict, helper: Callable[[Path], SchedulerPredictionType] = None - ) -> BaseModelType: - self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path) - self.checkpoint_path = checkpoint_path - self.helper = helper - - def get_base_type(self) -> BaseModelType: - pass - - def get_format(self) -> str: - return "checkpoint" - - def get_variant_type(self) -> ModelVariantType: - model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint) - if model_type != ModelType.Main: - return ModelVariantType.Normal - state_dict = self.checkpoint.get("state_dict") or self.checkpoint - in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - if in_channels == 9: - return ModelVariantType.Inpaint - elif in_channels == 5: - return ModelVariantType.Depth - elif in_channels == 4: - return ModelVariantType.Normal - else: - raise InvalidModelException( - f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}" - ) - - -class PipelineCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - state_dict = self.checkpoint.get("state_dict") or checkpoint - key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 768: - return BaseModelType.StableDiffusion1 - if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: - return BaseModelType.StableDiffusion2 - key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 2048: - return BaseModelType.StableDiffusionXL - elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280: - return BaseModelType.StableDiffusionXLRefiner - else: - raise InvalidModelException("Cannot determine base type") - - def get_scheduler_prediction_type(self) -> SchedulerPredictionType: - type = self.get_base_type() - if type == BaseModelType.StableDiffusion1: - return SchedulerPredictionType.Epsilon - checkpoint = self.checkpoint - state_dict = self.checkpoint.get("state_dict") or checkpoint - key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: - if "global_step" in checkpoint: - if checkpoint["global_step"] == 220000: - return SchedulerPredictionType.Epsilon - elif checkpoint["global_step"] == 110000: - return SchedulerPredictionType.VPrediction - if ( - self.checkpoint_path and self.helper and not self.checkpoint_path.with_suffix(".yaml").exists() - ): # if a .yaml config file exists, then this step not needed - return self.helper(self.checkpoint_path) - else: - return None - - -class VaeCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - # I can't find any standalone 2.X VAEs to test with! 
- return BaseModelType.StableDiffusion1 - - -class LoRACheckpointProbe(CheckpointProbeBase): - def get_format(self) -> str: - return "lycoris" - - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - token_vector_length = lora_token_vector_length(checkpoint) - - if token_vector_length == 768: - return BaseModelType.StableDiffusion1 - elif token_vector_length == 1024: - return BaseModelType.StableDiffusion2 - elif token_vector_length == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}") - - -class TextualInversionCheckpointProbe(CheckpointProbeBase): - def get_format(self) -> str: - return None - - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - if "string_to_token" in checkpoint: - token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1] - elif "emb_params" in checkpoint: - token_dim = checkpoint["emb_params"].shape[-1] - else: - token_dim = list(checkpoint.values())[0].shape[0] - if token_dim == 768: - return BaseModelType.StableDiffusion1 - elif token_dim == 1024: - return BaseModelType.StableDiffusion2 - else: - return None - - -class ControlNetCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - for key_name in ( - "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", - "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", - ): - if key_name not in checkpoint: - continue - if checkpoint[key_name].shape[-1] == 768: - return BaseModelType.StableDiffusion1 - elif checkpoint[key_name].shape[-1] == 1024: - return BaseModelType.StableDiffusion2 - elif self.checkpoint_path and self.helper: - return self.helper(self.checkpoint_path) - raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}") - - -class IPAdapterCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -class CLIPVisionCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -######################################################## -# classes for probing folders -####################################################### -class FolderProbeBase(ProbeBase): - def __init__(self, folder_path: Path, model: ModelMixin = None, helper: Callable = None): # not used - self.model = model - self.folder_path = folder_path - - def get_variant_type(self) -> ModelVariantType: - return ModelVariantType.Normal - - def get_format(self) -> str: - return "diffusers" - - -class PipelineFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - if self.model: - unet_conf = self.model.unet.config - else: - with open(self.folder_path / "unet" / "config.json", "r") as file: - unet_conf = json.load(file) - if unet_conf["cross_attention_dim"] == 768: - return BaseModelType.StableDiffusion1 - elif unet_conf["cross_attention_dim"] == 1024: - return BaseModelType.StableDiffusion2 - elif unet_conf["cross_attention_dim"] == 1280: - return BaseModelType.StableDiffusionXLRefiner - elif unet_conf["cross_attention_dim"] == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelException(f"Unknown base model for {self.folder_path}") - - def get_scheduler_prediction_type(self) -> SchedulerPredictionType: - if self.model: - scheduler_conf = self.model.scheduler.config - else: - with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file: - 
scheduler_conf = json.load(file) - if scheduler_conf["prediction_type"] == "v_prediction": - return SchedulerPredictionType.VPrediction - elif scheduler_conf["prediction_type"] == "epsilon": - return SchedulerPredictionType.Epsilon - else: - return None - - def get_variant_type(self) -> ModelVariantType: - # This only works for pipelines! Any kind of - # exception results in our returning the - # "normal" variant type - try: - if self.model: - conf = self.model.unet.config - else: - config_file = self.folder_path / "unet" / "config.json" - with open(config_file, "r") as file: - conf = json.load(file) - - in_channels = conf["in_channels"] - if in_channels == 9: - return ModelVariantType.Inpaint - elif in_channels == 5: - return ModelVariantType.Depth - elif in_channels == 4: - return ModelVariantType.Normal - except Exception: - pass - return ModelVariantType.Normal - - -class VaeFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - if self._config_looks_like_sdxl(): - return BaseModelType.StableDiffusionXL - elif self._name_looks_like_sdxl(): - # but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down - # by a factor of 8), we can't necessarily tell them apart by config hyperparameters. - return BaseModelType.StableDiffusionXL - else: - return BaseModelType.StableDiffusion1 - - def _config_looks_like_sdxl(self) -> bool: - # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE. - config_file = self.folder_path / "config.json" - if not config_file.exists(): - raise InvalidModelException(f"Cannot determine base type for {self.folder_path}") - with open(config_file, "r") as file: - config = json.load(file) - return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024] - - def _name_looks_like_sdxl(self) -> bool: - return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE)) - - def _guess_name(self) -> str: - name = self.folder_path.name - if name == "vae": - name = self.folder_path.parent.name - return name - - -class TextualInversionFolderProbe(FolderProbeBase): - def get_format(self) -> str: - return None - - def get_base_type(self) -> BaseModelType: - path = self.folder_path / "learned_embeds.bin" - if not path.exists(): - return None - checkpoint = ModelProbe._scan_and_load_checkpoint(path) - return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type() - - -class ONNXFolderProbe(FolderProbeBase): - def get_format(self) -> str: - return "onnx" - - def get_base_type(self) -> BaseModelType: - return BaseModelType.StableDiffusion1 - - def get_variant_type(self) -> ModelVariantType: - return ModelVariantType.Normal - - -class ControlNetFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - config_file = self.folder_path / "config.json" - if not config_file.exists(): - raise InvalidModelException(f"Cannot determine base type for {self.folder_path}") - with open(config_file, "r") as file: - config = json.load(file) - # no obvious way to distinguish between sd2-base and sd2-768 - dimension = config["cross_attention_dim"] - base_model = ( - BaseModelType.StableDiffusion1 - if dimension == 768 - else ( - BaseModelType.StableDiffusion2 - if dimension == 1024 - else BaseModelType.StableDiffusionXL - if dimension == 2048 - else None - ) - ) - if not base_model: - raise InvalidModelException(f"Unable to determine model base for {self.folder_path}") - return base_model - - -class LoRAFolderProbe(FolderProbeBase): - def get_base_type(self) -> 
BaseModelType: - model_file = None - for suffix in ["safetensors", "bin"]: - base_file = self.folder_path / f"pytorch_lora_weights.{suffix}" - if base_file.exists(): - model_file = base_file - break - if not model_file: - raise InvalidModelException("Unknown LoRA format encountered") - return LoRACheckpointProbe(model_file, None).get_base_type() - - -class IPAdapterFolderProbe(FolderProbeBase): - def get_format(self) -> str: - return IPAdapterModelFormat.InvokeAI.value - - def get_base_type(self) -> BaseModelType: - model_file = self.folder_path / "ip_adapter.bin" - if not model_file.exists(): - raise InvalidModelException("Unknown IP-Adapter model format.") - - state_dict = torch.load(model_file, map_location="cpu") - cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1] - if cross_attention_dim == 768: - return BaseModelType.StableDiffusion1 - elif cross_attention_dim == 1024: - return BaseModelType.StableDiffusion2 - elif cross_attention_dim == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.") - - -class CLIPVisionFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Any - - -############## register probe classes ###### -ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe) - -ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe) - -ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe) diff --git a/invokeai/backend/model_management/model_search.py b/invokeai/backend/model_management/model_search.py deleted file mode 100644 index be969900ac..0000000000 --- a/invokeai/backend/model_management/model_search.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2023, Lincoln D. Stein and the InvokeAI Team -""" -Abstract base class for recursive directory search for models. -""" - -import os -from abc import ABC, abstractmethod -from pathlib import Path -from typing import List, Set, types - -import invokeai.backend.util.logging as logger - - -class ModelSearch(ABC): - def __init__(self, directories: List[Path], logger: types.ModuleType = logger): - """ - Initialize a recursive model directory search. 
- :param directories: List of directory Paths to recurse through - :param logger: Logger to use - """ - self.directories = directories - self.logger = logger - self._items_scanned = 0 - self._models_found = 0 - self._scanned_dirs = set() - self._scanned_paths = set() - self._pruned_paths = set() - - @abstractmethod - def on_search_started(self): - """ - Called before the scan starts. - """ - pass - - @abstractmethod - def on_model_found(self, model: Path): - """ - Process a found model. Raise an exception if something goes wrong. - :param model: Model to process - could be a directory or checkpoint. - """ - pass - - @abstractmethod - def on_search_completed(self): - """ - Perform some activity when the scan is completed. May use instance - variables, items_scanned and models_found - """ - pass - - def search(self): - self.on_search_started() - for dir in self.directories: - self.walk_directory(dir) - self.on_search_completed() - - def walk_directory(self, path: Path): - for root, dirs, files in os.walk(path, followlinks=True): - if str(Path(root).name).startswith("."): - self._pruned_paths.add(root) - if any([Path(root).is_relative_to(x) for x in self._pruned_paths]): - continue - - self._items_scanned += len(dirs) + len(files) - for d in dirs: - path = Path(root) / d - if path in self._scanned_paths or path.parent in self._scanned_dirs: - self._scanned_dirs.add(path) - continue - if any( - [ - (path / x).exists() - for x in ["config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"] - ] - ): - try: - self.on_model_found(path) - self._models_found += 1 - self._scanned_dirs.add(path) - except Exception as e: - self.logger.warning(f"Failed to process '{path}': {e}") - - for f in files: - path = Path(root) / f - if path.parent in self._scanned_dirs: - continue - if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}: - try: - self.on_model_found(path) - self._models_found += 1 - except Exception as e: - self.logger.warning(f"Failed to process '{path}': {e}") - - -class FindModels(ModelSearch): - def on_search_started(self): - self.models_found: Set[Path] = set() - - def on_model_found(self, model: Path): - self.models_found.add(model) - - def on_search_completed(self): - pass - - def list_models(self) -> List[Path]: - self.search() - return list(self.models_found) diff --git a/invokeai/backend/model_management/util.py b/invokeai/backend/model_management/util.py deleted file mode 100644 index 6d70107c93..0000000000 --- a/invokeai/backend/model_management/util.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) 2023 The InvokeAI Development Team -"""Utilities used by the Model Manager""" - - -def lora_token_vector_length(checkpoint: dict) -> int: - """ - Given a checkpoint in memory, return the lora token vector length - - :param checkpoint: The checkpoint - """ - - def _get_shape_1(key, tensor, checkpoint): - lora_token_vector_length = None - - if "." 
not in key: - return lora_token_vector_length # wrong key format - model_key, lora_key = key.split(".", 1) - - # check lora/locon - if lora_key == "lora_down.weight": - lora_token_vector_length = tensor.shape[1] - - # check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes) - elif lora_key in ["hada_w1_b", "hada_w2_b"]: - lora_token_vector_length = tensor.shape[1] - - # check lokr (don't worry about lokr_t2 as it used only in 4d shapes) - elif "lokr_" in lora_key: - if model_key + ".lokr_w1" in checkpoint: - _lokr_w1 = checkpoint[model_key + ".lokr_w1"] - elif model_key + "lokr_w1_b" in checkpoint: - _lokr_w1 = checkpoint[model_key + ".lokr_w1_b"] - else: - return lora_token_vector_length # unknown format - - if model_key + ".lokr_w2" in checkpoint: - _lokr_w2 = checkpoint[model_key + ".lokr_w2"] - elif model_key + "lokr_w2_b" in checkpoint: - _lokr_w2 = checkpoint[model_key + ".lokr_w2_b"] - else: - return lora_token_vector_length # unknown format - - lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1] - - elif lora_key == "diff": - lora_token_vector_length = tensor.shape[1] - - # ia3 can be detected only by shape[0] in text encoder - elif lora_key == "weight" and "lora_unet_" not in model_key: - lora_token_vector_length = tensor.shape[0] - - return lora_token_vector_length - - lora_token_vector_length = None - lora_te1_length = None - lora_te2_length = None - for key, tensor in checkpoint.items(): - if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key): - lora_token_vector_length = _get_shape_1(key, tensor, checkpoint) - elif key.startswith("lora_te") and "_self_attn_" in key: - tmp_length = _get_shape_1(key, tensor, checkpoint) - if key.startswith("lora_te_"): - lora_token_vector_length = tmp_length - elif key.startswith("lora_te1_"): - lora_te1_length = tmp_length - elif key.startswith("lora_te2_"): - lora_te2_length = tmp_length - - if lora_te1_length is not None and lora_te2_length is not None: - lora_token_vector_length = lora_te1_length + lora_te2_length - - if lora_token_vector_length is not None: - break - - return lora_token_vector_length diff --git a/invokeai/backend/model_manager/loader.py b/invokeai/backend/model_manager/loader.py index 222db42148..f1813b2b45 100644 --- a/invokeai/backend/model_manager/loader.py +++ b/invokeai/backend/model_manager/loader.py @@ -17,7 +17,7 @@ from .config import BaseModelType, ModelConfigBase, ModelType, SubModelType from .download import DownloadEventHandler from .install import ModelInstall, ModelInstallBase from .models import MODEL_CLASSES, InvalidModelException, ModelBase -from .storage import ModelConfigStore, get_config_store +from .storage import ConfigFileVersionMismatchException, ModelConfigStore, get_config_store, migrate_models_store @dataclass @@ -138,7 +138,12 @@ class ModelLoad(ModelLoadBase): models_file = config.model_conf_path else: models_file = config.root_path / "configs/models3.yaml" - store = get_config_store(models_file) + try: + store = get_config_store(models_file) + except ConfigFileVersionMismatchException: + migrate_models_store(config) + store = get_config_store(models_file) + if not store: raise ValueError(f"Invalid model configuration file: {models_file}") diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index d483184df8..ba73b25789 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -7,6 +7,7 @@ its base type, model type, format and variant. 
""" import json +import re from abc import ABC, abstractmethod from pathlib import Path from typing import Callable, Optional @@ -493,20 +494,33 @@ class PipelineFolderProbe(FolderProbeBase): class VaeFolderProbe(FolderProbeBase): - """Probe a diffusers-style VAE model.""" - def get_base_type(self) -> BaseModelType: - """Return the BaseModelType for a diffusers-style VAE.""" + if self._config_looks_like_sdxl(): + return BaseModelType.StableDiffusionXL + elif self._name_looks_like_sdxl(): + # but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down + # by a factor of 8), we can't necessarily tell them apart by config hyperparameters. + return BaseModelType.StableDiffusionXL + else: + return BaseModelType.StableDiffusion1 + + def _config_looks_like_sdxl(self) -> bool: + # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE. config_file = self.folder_path / "config.json" if not config_file.exists(): raise InvalidModelException(f"Cannot determine base type for {self.folder_path}") with open(config_file, "r") as file: config = json.load(file) - return ( - BaseModelType.StableDiffusionXL - if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024] - else BaseModelType.StableDiffusion1 - ) + return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024] + + def _name_looks_like_sdxl(self) -> bool: + return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE)) + + def _guess_name(self) -> str: + name = self.folder_path.name + if name == "vae": + name = self.folder_path.parent.name + return name class TextualInversionFolderProbe(FolderProbeBase): diff --git a/invokeai/backend/model_manager/search.py b/invokeai/backend/model_manager/search.py index a91c38e18a..f978b57dfc 100644 --- a/invokeai/backend/model_manager/search.py +++ b/invokeai/backend/model_manager/search.py @@ -168,7 +168,13 @@ class ModelSearch(ModelSearchBase): if any( [ (path / x).exists() - for x in ["config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"] + for x in [ + "config.json", + "model_index.json", + "learned_embeds.bin", + "pytorch_lora_weights.bin", + "image_encoder.txt", + ] ] ): self._scanned_dirs.add(path) diff --git a/invokeai/backend/model_manager/storage/__init__.py b/invokeai/backend/model_manager/storage/__init__.py index 3094b73e31..4280721cd6 100644 --- a/invokeai/backend/model_manager/storage/__init__.py +++ b/invokeai/backend/model_manager/storage/__init__.py @@ -3,7 +3,13 @@ Initialization file for invokeai.backend.model_manager.storage """ import pathlib -from .base import DuplicateModelException, ModelConfigStore, UnknownModelException # noqa F401 +from .base import ( # noqa F401 + ConfigFileVersionMismatchException, + DuplicateModelException, + ModelConfigStore, + UnknownModelException, +) +from .migrate import migrate_models_store # noqa F401 from .sql import ModelConfigStoreSQL # noqa F401 from .yaml import ModelConfigStoreYAML # noqa F401 diff --git a/invokeai/backend/model_manager/storage/base.py b/invokeai/backend/model_manager/storage/base.py index 1dc5289ecc..9597b26862 100644 --- a/invokeai/backend/model_manager/storage/base.py +++ b/invokeai/backend/model_manager/storage/base.py @@ -4,6 +4,7 @@ Abstract base class for storing and retrieving model configuration records. 
""" from abc import ABC, abstractmethod +from pathlib import Path from typing import List, Optional, Set, Union from ..config import BaseModelType, ModelConfigBase, ModelType @@ -24,6 +25,10 @@ class UnknownModelException(Exception): """Raised on an attempt to fetch or delete a model with a nonexistent key.""" +class ConfigFileVersionMismatchException(Exception): + """Raised on an attempt to open a config with an incompatible version.""" + + class ModelConfigStore(ABC): """Abstract base class for storage and retrieval of model configs.""" @@ -99,6 +104,16 @@ class ModelConfigStore(ABC): """ pass + @abstractmethod + def search_by_path( + self, + path: Union[str, Path], + ) -> Optional[ModelConfigBase]: + """ + Return the model having the indicated path. + """ + pass + @abstractmethod def search_by_name( self, diff --git a/invokeai/backend/model_manager/storage/migrate.py b/invokeai/backend/model_manager/storage/migrate.py new file mode 100644 index 0000000000..05ab46596e --- /dev/null +++ b/invokeai/backend/model_manager/storage/migrate.py @@ -0,0 +1,61 @@ +# Copyright (c) 2023 The InvokeAI Development Team + +import shutil +from pathlib import Path + +from omegaconf import OmegaConf + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.util.logging import InvokeAILogger + +from .base import CONFIG_FILE_VERSION + + +def migrate_models_store(config: InvokeAIAppConfig): + # avoid circular import + from invokeai.backend.model_manager import DuplicateModelException, InvalidModelException, ModelInstall + from invokeai.backend.model_manager.storage import get_config_store + + app_config = InvokeAIAppConfig.get_config() + logger = InvokeAILogger.getLogger() + old_file: Path = app_config.model_conf_path + new_file: Path = old_file.with_name("models3_2.yaml") + + old_conf = OmegaConf.load(old_file) + store = get_config_store(new_file) + installer = ModelInstall(store=store) + logger.info(f"Migrating old models file at {old_file} to new {CONFIG_FILE_VERSION} format") + + for model_key, stanza in old_conf.items(): + if model_key == "__metadata__": + assert ( + stanza["version"] == "3.0.0" + ), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version" + continue + + base_type, model_type, model_name = model_key.split("/") + + try: + path = app_config.models_path / stanza["path"] + new_key = installer.register_path(path) + except DuplicateModelException: + # if model already installed, then we just update its info + models = store.search_by_name(model_name=model_name, base_model=base_type, model_type=model_type) + if len(models) != 1: + continue + new_key = models[0].key + except Exception as excp: + print(str(excp)) + + model_info = store.get_model(new_key) + if vae := stanza.get("vae"): + model_info.vae = (app_config.models_path / vae).as_posix() + if model_config := stanza.get("config"): + model_info.config = (app_config.root_path / model_config).as_posix() + model_info.description = stanza.get("description") + store.update_model(new_key, model_info) + store.update_model(new_key, model_info) + + logger.info(f"Original version of models config file saved as {str(old_file) + '.orig'}") + shutil.move(old_file, str(old_file) + ".orig") + shutil.move(new_file, old_file) diff --git a/invokeai/backend/model_manager/storage/sql.py b/invokeai/backend/model_manager/storage/sql.py index 71b8332103..9487b755b5 100644 --- a/invokeai/backend/model_manager/storage/sql.py +++ b/invokeai/backend/model_manager/storage/sql.py @@ 
-477,3 +477,9 @@ class ModelConfigStoreSQL(ModelConfigStore): finally: self._lock.release() return results + + def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]: + """ + Return the model with the indicated path, or None.. + """ + raise NotImplementedError("search_by_path not implemented in storage.sql") diff --git a/invokeai/backend/model_manager/storage/yaml.py b/invokeai/backend/model_manager/storage/yaml.py index 1f56a2dbab..68ce9f14bb 100644 --- a/invokeai/backend/model_manager/storage/yaml.py +++ b/invokeai/backend/model_manager/storage/yaml.py @@ -50,7 +50,13 @@ from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig from ..config import BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType -from .base import CONFIG_FILE_VERSION, DuplicateModelException, ModelConfigStore, UnknownModelException +from .base import ( + CONFIG_FILE_VERSION, + ConfigFileVersionMismatchException, + DuplicateModelException, + ModelConfigStore, + UnknownModelException, +) class ModelConfigStoreYAML(ModelConfigStore): @@ -68,9 +74,8 @@ class ModelConfigStoreYAML(ModelConfigStore): if not self._filename.exists(): self._initialize_yaml() self._config = OmegaConf.load(self._filename) - assert ( - str(self.version) == CONFIG_FILE_VERSION - ), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}" + if str(self.version) != CONFIG_FILE_VERSION: + raise ConfigFileVersionMismatchException def _initialize_yaml(self): try: @@ -239,3 +244,67 @@ class ModelConfigStoreYAML(ModelConfigStore): finally: self._lock.release() return results + + def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]: + """ + Return the model with the indicated path, or None.. + """ + try: + self._lock.acquire() + for key, record in self._config.items(): + if key == "__metadata__": + continue + model = ModelConfigFactory.make_config(record, key) + if model.path == path: + return model + finally: + self._lock.release() + return None + + def _load_and_maybe_upgrade(self, config_path: Path) -> DictConfig: + config = OmegaConf.load(config_path) + version = config["__metadata__"].get("version") + if version == CONFIG_FILE_VERSION: + return config + + # if we get here we need to upgrade + if version == "3.0.0": + return self._migrate_format_to_3_2(config, config_path) + else: + raise Exception(f"{config_path} has unknown version: {version}") + + def _migrate_format_to_3_2(self, old_config: DictConfig, config_path: Path) -> DictConfig: + print( + f"** Doing one-time conversion of {config_path.as_posix()} to new format. 
Original will be named {config_path.as_posix() + '.orig'}" ) + + # avoid circular dependencies + from shutil import move + + from ..install import InvalidModelException, ModelInstall + + move(config_path, config_path.as_posix() + ".orig") + + new_store = self.__class__(config_path) + installer = ModelInstall(store=new_store) + + for model_key, stanza in old_config.items(): + if model_key == "__metadata__": + assert ( + stanza["version"] == "3.0.0" + ), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version" + continue + + try: + path = stanza["path"] + new_key = installer.register_path(path) + model_info = new_store.get_model(new_key) + if vae := stanza.get("vae"): + model_info.vae = vae + if model_config := stanza.get("config"): + model_info.config = model_config + model_info.description = stanza.get("description") + new_store.update_model(new_key, model_info) + except (DuplicateModelException, InvalidModelException) as e: + print(str(e)) + return OmegaConf.load(config_path) diff --git a/scripts/convert_models_config_to_3.2.py b/scripts/convert_models_config_to_3.2.py index 203f09c804..7ce6f35744 100644 --- a/scripts/convert_models_config_to_3.2.py +++ b/scripts/convert_models_config_to_3.2.py @@ -14,11 +14,8 @@ when new models are downloaded from HuggingFace or Civitae. import argparse from pathlib import Path -from omegaconf import OmegaConf - from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_manager import DuplicateModelException, InvalidModelException, ModelInstall -from invokeai.backend.model_manager.storage import get_config_store +from invokeai.backend.model_manager.storage import migrate_models_store def main(): @@ -35,34 +32,7 @@ def main(): config = InvokeAIAppConfig.get_config() config.parse_args(config_args) - old_yaml_file = OmegaConf.load(config.model_conf_path) - - store = get_config_store(args.outfile) - installer = ModelInstall(store=store) - - print(f"Writing 3.2 models configuration into {args.outfile}.") - - for model_key, stanza in old_yaml_file.items(): - if model_key == "__metadata__": - assert ( - stanza["version"] == "3.0.0" - ), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version" - continue - - try: - path = config.models_path / stanza["path"] - new_key = installer.register_path(path) - model_info = store.get_model(new_key) - if vae := stanza.get("vae"): - model_info.vae = (config.models_path / vae).as_posix() - if model_config := stanza.get("config"): - model_info.config = (config.root_path / model_config).as_posix() - model_info.description = stanza.get("description") - store.update_model(new_key, model_info) - - print(f"{model_key} => {new_key}") - except (DuplicateModelException, InvalidModelException) as e: - print(str(e)) + migrate_models_store(config) if __name__ == "__main__":
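
For readers tracing the new migration path: below is a minimal sketch (not part of the patch) of how the pieces added here fit together, mirroring the ModelLoad change in invokeai/backend/model_manager/loader.py. It uses only names introduced or referenced in this diff (ConfigFileVersionMismatchException, get_config_store, migrate_models_store, InvokeAIAppConfig); treat the surrounding setup as illustrative rather than prescriptive.

# Sketch only: exercising the version-mismatch migration path by hand.
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.storage import (
    ConfigFileVersionMismatchException,
    get_config_store,
    migrate_models_store,
)

config = InvokeAIAppConfig.get_config()
models_file = config.model_conf_path  # the same file ModelLoad consults

try:
    store = get_config_store(models_file)
except ConfigFileVersionMismatchException:
    # migrate_models_store() re-registers each model from the 3.0.0 yaml via
    # ModelInstall, keeps the old file with a .orig suffix, and moves the
    # new-format file into place, so a second call is expected to succeed.
    migrate_models_store(config)
    store = get_config_store(models_file)

The standalone scripts/convert_models_config_to_3.2.py entry point now delegates to this same migrate_models_store() helper, as shown in the hunk above.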