"""Probe model files and folders to derive the fields of a model configuration.

``ModelProbe.probe()`` is the entry point. It determines the model type from
the file/folder contents, picks the probe class registered for that
(format, type) pair, and asks it for the base model, variant, prediction type
and storage format.
"""

import json
import re
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union

import safetensors.torch
import torch
from picklescan.scanner import scan_file_path

from invokeai.backend.model_management.models.base import read_checkpoint_meta
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
from invokeai.backend.model_management.util import lora_token_vector_length
from invokeai.backend.util.util import SilenceWarnings

from .config import (
    AnyModelConfig,
    BaseModelType,
    InvalidModelConfigException,
    ModelConfigFactory,
    ModelFormat,
    ModelType,
    ModelVariantType,
    SchedulerPredictionType,
)
from .hash import FastModelHash

# A loaded checkpoint: a mapping from key names to arbitrary values.
CkptType = Dict[str, Any]

# Legacy config file names, keyed by base model and then variant. SD-2 entries
# have one more tier keyed by the scheduler prediction type.
LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = {
    BaseModelType.StableDiffusion1: {
        ModelVariantType.Normal: "v1-inference.yaml",
        ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
    },
    BaseModelType.StableDiffusion2: {
        ModelVariantType.Normal: {
            SchedulerPredictionType.Epsilon: "v2-inference.yaml",
            SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
        },
        ModelVariantType.Inpaint: {
            SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
            SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
        },
    },
    BaseModelType.StableDiffusionXL: {
        ModelVariantType.Normal: "sd_xl_base.yaml",
    },
    BaseModelType.StableDiffusionXLRefiner: {
        ModelVariantType.Normal: "sd_xl_refiner.yaml",
    },
}


class ProbeBase(object):
    """Base class for probes."""

    def __init__(self, model_path: Path):
        """Remember the path of the model file or folder to be probed."""
        self.model_path = model_path

    def get_base_type(self) -> BaseModelType:
        """Get model base type."""
        raise NotImplementedError

    def get_format(self) -> ModelFormat:
        """Get model file format."""
        raise NotImplementedError

    def get_variant_type(self) -> Optional[ModelVariantType]:
        """Get model variant type."""
        return None

    def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
        """Get model scheduler prediction type."""
        return None


class ModelProbe(object):
    """Registry of probe classes plus the logic that selects and runs one."""

    # Probe classes keyed first by storage format, then by model type.
    PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
        "diffusers": {},
        "checkpoint": {},
        "onnx": {},
    }

    # Class name found in a folder's JSON config -> model type.
    CLASS2TYPE = {
        "StableDiffusionPipeline": ModelType.Main,
        "StableDiffusionInpaintPipeline": ModelType.Main,
        "StableDiffusionXLPipeline": ModelType.Main,
        "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
        "StableDiffusionXLInpaintPipeline": ModelType.Main,
        "LatentConsistencyModelPipeline": ModelType.Main,
        "AutoencoderKL": ModelType.Vae,
        "AutoencoderTiny": ModelType.Vae,
        "ControlNetModel": ModelType.ControlNet,
        "CLIPVisionModelWithProjection": ModelType.CLIPVision,
        "T2IAdapter": ModelType.T2IAdapter,
    }

    @classmethod
    def register_probe(
        cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase]
    ) -> None:
        """Register ``probe_class`` as the probe for ``model_type`` stored in ``format``."""
        cls.PROBES[format][model_type] = probe_class

    @classmethod
    def heuristic_probe(
        cls,
        model_path: Path,
        fields: Optional[Dict[str, Any]] = None,
    ) -> AnyModelConfig:
        """Alias for :meth:`probe`; takes the same arguments and returns the same result."""
        return cls.probe(model_path, fields)

    @classmethod
    def probe(
        cls,
        model_path: Path,
        fields: Optional[Dict[str, Any]] = None,
    ) -> AnyModelConfig:
        """
        Probe the model at model_path and return a configuration object for it.

        :param model_path: Path to a model file, or to a folder (probed as a diffusers model).
        :param fields: Optional values that take precedence over probed ones.
            A truthy entry is kept; a missing or falsy entry is replaced by the
            probed value. The dict is updated in place when one is passed.
        :raises InvalidModelConfigException: if the model type cannot be determined
            or no probe is registered for the format/type combination.
        """
        if fields is None:
            fields = {}

        # A folder is probed as a diffusers model; anything else as a checkpoint file.
        if model_path.is_dir():
            format_type = ModelFormat.Diffusers
            model_type = cls.get_model_type_from_folder(model_path)
        else:
            format_type = ModelFormat.Checkpoint
            model_type = cls.get_model_type_from_checkpoint(model_path)
        # ONNX probes are registered under their own format key.
        if model_type == ModelType.ONNX:
            format_type = ModelFormat.Onnx

        probe_class = cls.PROBES[format_type].get(model_type)
        if not probe_class:
            raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")

        model_hash = FastModelHash.hash(model_path)
        probe = probe_class(model_path)

        fields["path"] = model_path.as_posix()
        fields["type"] = fields.get("type") or model_type
        fields["base"] = fields.get("base") or probe.get_base_type()
        fields["variant"] = fields.get("variant") or probe.get_variant_type()
        fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
        fields["name"] = fields.get("name") or cls.get_model_name(model_path)
        fields["description"] = (
            fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
        )
        fields["format"] = fields.get("format") or probe.get_format()
        fields["original_hash"] = fields.get("original_hash") or model_hash
        fields["current_hash"] = fields.get("current_hash") or model_hash

        # additional fields needed for main and controlnet models
        if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
            fields["config"] = cls._get_checkpoint_config_path(
                model_path,
                model_type=fields["type"],
                base_type=fields["base"],
                variant_type=fields["variant"],
                prediction_type=fields["prediction_type"],
            ).as_posix()

        # additional fields needed for main non-checkpoint models
        elif fields["type"] == ModelType.Main and fields["format"] in [
            ModelFormat.Onnx,
            ModelFormat.Olive,
            ModelFormat.Diffusers,
        ]:
            fields["upcast_attention"] = fields.get("upcast_attention") or (
                fields["base"] == BaseModelType.StableDiffusion2
                and fields["prediction_type"] == SchedulerPredictionType.VPrediction
            )

        return ModelConfigFactory.make_config(fields)

    @classmethod
    def get_model_name(cls, model_path: Path) -> str:
        """Return a default model name: the file stem for known weight suffixes, else the full name."""
        # ".pth" is included because get_model_type_from_checkpoint() accepts it too.
        if model_path.suffix in {".safetensors", ".bin", ".pt", ".pth", ".ckpt"}:
            return model_path.stem
        return model_path.name

    @classmethod
    def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[CkptType] = None) -> ModelType:
        """
        Determine the model type of a single-file checkpoint from its key names.

        :param model_path: Path to the checkpoint file.
        :param checkpoint: Already-loaded checkpoint; when omitted (or empty) the
            file's metadata is read from disk with scanning enabled.
        :raises InvalidModelConfigException: on an unrecognized suffix or when no
            heuristic matches.
        """
        if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
            raise InvalidModelConfigException(f"{model_path}: unrecognized suffix")

        if model_path.name == "learned_embeds.bin":
            return ModelType.TextualInversion

        ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
        # Some checkpoints nest their weights under "state_dict".
        ckpt = ckpt.get("state_dict", ckpt)

        # The first key that matches any heuristic decides the type.
        for key in ckpt.keys():
            if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")):
                return ModelType.Main
            if key.startswith(("encoder.conv_in", "decoder.conv_in")):
                return ModelType.Vae
            if key.startswith(("lora_te_", "lora_unet_")):
                return ModelType.Lora
            if key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight")):
                return ModelType.Lora
            if key.startswith(("control_model", "input_blocks")):
                return ModelType.ControlNet
            if key in {"emb_params", "string_to_param"}:
                return ModelType.TextualInversion

        # diffusers-ti: a handful of entries, all of them tensors
        if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
            return ModelType.TextualInversion

        raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")

    @classmethod
    def get_model_type_from_folder(cls, folder_path: Path) -> ModelType:
        """
        Get the model type of a hugging-face style folder.

        Marker files are checked first; otherwise the class name recorded in
        ``model_index.json`` (preferred) or ``config.json`` is looked up in
        ``CLASS2TYPE``.

        :raises InvalidModelConfigException: if no rule identifies the folder.
        """
        for suffix in ["bin", "safetensors"]:
            if (folder_path / f"learned_embeds.{suffix}").exists():
                return ModelType.TextualInversion
            if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
                return ModelType.Lora
        if (folder_path / "unet/model.onnx").exists():
            return ModelType.ONNX
        if (folder_path / "image_encoder.txt").exists():
            return ModelType.IPAdapter

        index_file = folder_path / "model_index.json"
        config_file = folder_path / "config.json"
        config_path = index_file if index_file.exists() else config_file if config_file.exists() else None

        if config_path:
            with open(config_path, "r", encoding="utf-8") as file:
                conf = json.load(file)
            class_name = None
            if "_class_name" in conf:
                class_name = conf["_class_name"]
            elif "architectures" in conf:
                class_name = conf["architectures"][0]

            if class_name and (model_type := cls.CLASS2TYPE.get(class_name)):
                return model_type
            # Set only when a config was actually read, so the "no config found"
            # hint below is not overwritten by a misleading "class None" hint.
            error_hint = (
                f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"
            )
        else:
            error_hint = f"No model_index.json or config.json found in {folder_path}."

        # give up
        raise InvalidModelConfigException(f"Unable to determine model type for {folder_path}; {error_hint}")

    @classmethod
    def _get_checkpoint_config_path(
        cls,
        model_path: Path,
        model_type: ModelType,
        base_type: BaseModelType,
        variant_type: ModelVariantType,
        prediction_type: SchedulerPredictionType,
    ) -> Path:
        """
        Return the config file to use for a checkpoint-format main or controlnet model.

        A ``.yaml`` file next to the model (same stem) wins and is returned as an
        absolute path; otherwise a legacy config file name is returned as a
        relative path.

        :raises InvalidModelConfigException: if the model type is unsupported or
            there is no legacy config for the base/variant/prediction combination.
        """
        # look for a YAML file adjacent to the model file first
        possible_conf = model_path.with_suffix(".yaml")
        if possible_conf.exists():
            return possible_conf.absolute()

        if model_type == ModelType.Main:
            try:
                config_file = LEGACY_CONFIGS[base_type][variant_type]
                if isinstance(config_file, dict):  # need another tier for sd-2.x models
                    config_file = config_file[prediction_type]
            except KeyError as e:
                # Report a missing table entry as a config problem rather than a bare KeyError.
                raise InvalidModelConfigException(
                    f"{model_path}: No legacy config for base_type={base_type}, "
                    f"variant_type={variant_type}, prediction_type={prediction_type}"
                ) from e
        elif model_type == ModelType.ControlNet:
            config_file = (
                "../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
            )
        else:
            raise InvalidModelConfigException(
                f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}"
            )
        assert isinstance(config_file, str)
        return Path(config_file)

    @classmethod
    def _scan_and_load_checkpoint(cls, model_path: Path) -> CkptType:
        """
        Load a checkpoint file and return its contents as a dict.

        Files with a pickle-based suffix are scanned with picklescan before being
        unpickled; any other suffix is loaded with the safetensors loader.

        :raises InvalidModelConfigException: if a pickle checkpoint is not a dict.
        :raises Exception: if the malware scan flags the file.
        """
        with SilenceWarnings():
            # ".pth" is accepted by get_model_type_from_checkpoint(), so it must
            # be treated as a pickle here rather than handed to safetensors.
            if model_path.suffix in (".ckpt", ".pt", ".pth", ".bin"):
                cls._scan_model(model_path.name, model_path)
                # map_location="cpu" lets GPU-saved checkpoints load on hosts
                # without that device; probing only needs keys and shapes.
                model = torch.load(model_path, map_location="cpu")
                if not isinstance(model, dict):
                    raise InvalidModelConfigException(
                        f"{model_path}: expected the checkpoint to contain a dict, got {type(model).__name__}"
                    )
                return model
            return safetensors.torch.load_file(model_path)

    @classmethod
    def _scan_model(cls, model_name: str, checkpoint: Path) -> None:
        """
        Apply picklescanner to the indicated checkpoint.

        :param model_name: Name used in the error message.
        :param checkpoint: Path of the file to scan.
        :raises Exception: if the scan reports any infected files.
        """
        # scan model
        scan_result = scan_file_path(checkpoint)
        if scan_result.infected_files != 0:
            raise Exception(f"The model {model_name} is potentially infected by malware. Aborting import.")


# ###################################################
# Checkpoint probing
# ###################################################


class CheckpointProbeBase(ProbeBase):
    """Base class for probes of single-file checkpoints; loads the file on construction."""

    def __init__(self, model_path: Path):
        """Scan and load the checkpoint at ``model_path`` into ``self.checkpoint``."""
        super().__init__(model_path)
        self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)

    def get_format(self) -> ModelFormat:
        """Return the checkpoint format."""
        return ModelFormat("checkpoint")

    def get_variant_type(self) -> ModelVariantType:
        """
        Return the variant; for main models it is derived from the first input block's channel count.

        :raises InvalidModelConfigException: if the channel count is not 4, 5 or 9.
        """
        model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
        if model_type != ModelType.Main:
            return ModelVariantType.Normal

        state_dict = self.checkpoint.get("state_dict") or self.checkpoint
        in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
        if in_channels == 9:
            return ModelVariantType.Inpaint
        if in_channels == 5:
            return ModelVariantType.Depth
        if in_channels == 4:
            return ModelVariantType.Normal
        raise InvalidModelConfigException(
            f"Cannot determine variant type (in_channels={in_channels}) at {self.model_path}"
        )


class PipelineCheckpointProbe(CheckpointProbeBase):
    """Probe for main (pipeline) models stored as a single checkpoint."""

    def get_base_type(self) -> BaseModelType:
        """
        Infer the base model from the last dimension of a cross-attention ``to_k`` weight.

        :raises InvalidModelConfigException: if no known key/width pair matches.
        """
        checkpoint = self.checkpoint
        state_dict = self.checkpoint.get("state_dict") or checkpoint

        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
            return BaseModelType.StableDiffusion1
        if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
            return BaseModelType.StableDiffusion2

        key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
            return BaseModelType.StableDiffusionXL
        if key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
            return BaseModelType.StableDiffusionXLRefiner
        raise InvalidModelConfigException("Cannot determine base type")

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        """Return model prediction type."""
        base_type = self.get_base_type()
        if base_type == BaseModelType.StableDiffusion2:
            checkpoint = self.checkpoint
            # get_base_type() reports SD2 only when the cross-attention key exists
            # with width 1024, so that check is not repeated here.
            if "global_step" in checkpoint:
                if checkpoint["global_step"] == 220000:
                    return SchedulerPredictionType.Epsilon
                if checkpoint["global_step"] == 110000:
                    return SchedulerPredictionType.VPrediction
            return SchedulerPredictionType.VPrediction  # a guess for sd2 ckpts
        # Epsilon is a reasonable guess for sd1 ckpts and the default for every other base.
        return SchedulerPredictionType.Epsilon


class VaeCheckpointProbe(CheckpointProbeBase):
    """Probe for standalone VAE checkpoints."""

    def get_base_type(self) -> BaseModelType:
        """Always report SD-1; the checkpoint contents are not inspected."""
        # I can't find any standalone 2.X VAEs to test with!
        return BaseModelType.StableDiffusion1


class LoRACheckpointProbe(CheckpointProbeBase):
    """Class for LoRA checkpoints."""

    def get_format(self) -> ModelFormat:
        """Return the lycoris format."""
        return ModelFormat("lycoris")

    def get_base_type(self) -> BaseModelType:
        """
        Infer the base model from the LoRA's token vector length.

        :raises InvalidModelConfigException: if the length is not 768, 1024 or 2048.
        """
        token_vector_length = lora_token_vector_length(self.checkpoint)

        if token_vector_length == 768:
            return BaseModelType.StableDiffusion1
        if token_vector_length == 1024:
            return BaseModelType.StableDiffusion2
        if token_vector_length == 2048:
            return BaseModelType.StableDiffusionXL
        raise InvalidModelConfigException(f"Unknown LoRA type: {self.model_path}")


class TextualInversionCheckpointProbe(CheckpointProbeBase):
    """Class for probing embeddings."""

    def get_format(self) -> ModelFormat:
        """Return the embedding-file format."""
        return ModelFormat.EmbeddingFile

    def get_base_type(self) -> BaseModelType:
        """
        Infer the base model from the embedding's token dimension.

        :raises InvalidModelConfigException: if the dimension is not 768, 1024 or 1280.
        """
        checkpoint = self.checkpoint
        # Test for the key that is actually read below.
        if "string_to_param" in checkpoint:
            token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
        elif "emb_params" in checkpoint:
            token_dim = checkpoint["emb_params"].shape[-1]
        else:
            token_dim = list(checkpoint.values())[0].shape[0]

        if token_dim == 768:
            return BaseModelType.StableDiffusion1
        if token_dim == 1024:
            return BaseModelType.StableDiffusion2
        if token_dim == 1280:
            return BaseModelType.StableDiffusionXL
        raise InvalidModelConfigException("Could not determine base type")


class ControlNetCheckpointProbe(CheckpointProbeBase):
    """Class for probing controlnets."""

    def get_base_type(self) -> BaseModelType:
        """
        Infer the base model from the last dimension of a cross-attention ``to_k`` weight.

        :raises InvalidModelConfigException: if neither candidate key yields 768 or 1024.
        """
        checkpoint = self.checkpoint
        for key_name in (
            "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
            "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
        ):
            if key_name not in checkpoint:
                continue
            if checkpoint[key_name].shape[-1] == 768:
                return BaseModelType.StableDiffusion1
            if checkpoint[key_name].shape[-1] == 1024:
                return BaseModelType.StableDiffusion2
        raise InvalidModelConfigException(f"Unable to determine base type for {self.model_path}")


class IPAdapterCheckpointProbe(CheckpointProbeBase):
    """Placeholder probe; base type detection is not implemented for this format."""

    def get_base_type(self) -> BaseModelType:
        """Not implemented."""
        raise NotImplementedError()


class CLIPVisionCheckpointProbe(CheckpointProbeBase):
    """Placeholder probe; base type detection is not implemented for this format."""

    def get_base_type(self) -> BaseModelType:
        """Not implemented."""
        raise NotImplementedError()


class T2IAdapterCheckpointProbe(CheckpointProbeBase):
    """Placeholder probe; base type detection is not implemented for this format."""

    def get_base_type(self) -> BaseModelType:
        """Not implemented."""
        raise NotImplementedError()


########################################################
# classes for probing folders
########################################################


class FolderProbeBase(ProbeBase):
    """Base class for probes of folder-based models."""

    def get_variant_type(self) -> ModelVariantType:
        """Return the normal variant by default."""
        return ModelVariantType.Normal

    def get_format(self) -> ModelFormat:
        """Return the diffusers format."""
        return ModelFormat("diffusers")


class PipelineFolderProbe(FolderProbeBase):
    """Probe for main (pipeline) models stored as a folder."""

    def get_base_type(self) -> BaseModelType:
        """
        Infer the base model from ``cross_attention_dim`` in ``unet/config.json``.

        :raises InvalidModelConfigException: if the dimension is not recognized.
        """
        with open(self.model_path / "unet" / "config.json", "r", encoding="utf-8") as file:
            unet_conf = json.load(file)
        cross_attention_dim = unet_conf["cross_attention_dim"]
        if cross_attention_dim == 768:
            return BaseModelType.StableDiffusion1
        if cross_attention_dim == 1024:
            return BaseModelType.StableDiffusion2
        if cross_attention_dim == 1280:
            return BaseModelType.StableDiffusionXLRefiner
        if cross_attention_dim == 2048:
            return BaseModelType.StableDiffusionXL
        raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        """
        Read ``prediction_type`` from ``scheduler/scheduler_config.json``.

        :raises InvalidModelConfigException: if the value is neither "v_prediction" nor "epsilon".
        """
        with open(self.model_path / "scheduler" / "scheduler_config.json", "r", encoding="utf-8") as file:
            scheduler_conf = json.load(file)
        prediction_type = scheduler_conf["prediction_type"]
        if prediction_type == "v_prediction":
            return SchedulerPredictionType.VPrediction
        if prediction_type == "epsilon":
            return SchedulerPredictionType.Epsilon
        raise InvalidModelConfigException(f"Unknown scheduler prediction type: {prediction_type}")

    def get_variant_type(self) -> ModelVariantType:
        """Return the variant implied by ``in_channels`` in ``unet/config.json``, defaulting to normal."""
        # This only works for pipelines! Any kind of
        # exception results in our returning the
        # "normal" variant type
        try:
            config_file = self.model_path / "unet" / "config.json"
            with open(config_file, "r", encoding="utf-8") as file:
                conf = json.load(file)

            in_channels = conf["in_channels"]
            if in_channels == 9:
                return ModelVariantType.Inpaint
            if in_channels == 5:
                return ModelVariantType.Depth
            if in_channels == 4:
                return ModelVariantType.Normal
        except Exception:
            # Deliberate best effort: fall through to the default below.
            pass
        return ModelVariantType.Normal


class VaeFolderProbe(FolderProbeBase):
    """Probe for VAE models stored as a folder."""

    def get_base_type(self) -> BaseModelType:
        """
        Report SDXL when the config or the folder name looks like SDXL, otherwise SD-1.

        :raises InvalidModelConfigException: if the folder has no ``config.json``.
        """
        # SD and SDXL VAEs are the same shape (3-channel RGB to 4-channel float
        # scaled down by a factor of 8), so the config hyperparameters cannot
        # always tell them apart; the name is used as a second signal.
        if self._config_looks_like_sdxl() or self._name_looks_like_sdxl():
            return BaseModelType.StableDiffusionXL
        return BaseModelType.StableDiffusion1

    def _config_looks_like_sdxl(self) -> bool:
        """Return True when ``config.json`` carries the values of Stability's SDXL VAE."""
        # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
        config_file = self.model_path / "config.json"
        if not config_file.exists():
            raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
        with open(config_file, "r", encoding="utf-8") as file:
            config = json.load(file)
        return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]

    def _name_looks_like_sdxl(self) -> bool:
        """Return True when the guessed name contains "xl" at the end of a word (case-insensitive)."""
        return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))

    def _guess_name(self) -> str:
        """Return the folder name, or the parent folder's name when the folder is called "vae"."""
        name = self.model_path.name
        if name == "vae":
            name = self.model_path.parent.name
        return name


class TextualInversionFolderProbe(FolderProbeBase):
    """Probe for textual-inversion embeddings stored as a folder."""

    def get_format(self) -> ModelFormat:
        """Return the embedding-folder format."""
        return ModelFormat.EmbeddingFolder

    def get_base_type(self) -> BaseModelType:
        """
        Delegate to the checkpoint probe on the folder's ``learned_embeds`` file.

        :raises InvalidModelConfigException: if the folder has no ``learned_embeds`` file.
        """
        # Accept the same two suffixes that get_model_type_from_folder() uses
        # to identify this kind of folder.
        for suffix in ("bin", "safetensors"):
            path = self.model_path / f"learned_embeds.{suffix}"
            if path.exists():
                return TextualInversionCheckpointProbe(path).get_base_type()
        raise InvalidModelConfigException(
            f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin'"
            " or 'learned_embeds.safetensors' file"
        )


class ONNXFolderProbe(FolderProbeBase):
    """Probe for ONNX models stored as a folder."""

    def get_format(self) -> ModelFormat:
        """Return the onnx format."""
        return ModelFormat("onnx")

    def get_base_type(self) -> BaseModelType:
        """Always report SD-1; the folder contents are not inspected."""
        return BaseModelType.StableDiffusion1

    def get_variant_type(self) -> ModelVariantType:
        """Always report the normal variant."""
        return ModelVariantType.Normal


class ControlNetFolderProbe(FolderProbeBase):
    """Probe for ControlNet models stored as a folder."""

    def get_base_type(self) -> BaseModelType:
        """
        Infer the base model from ``cross_attention_dim`` in ``config.json``.

        :raises InvalidModelConfigException: if the config is missing or the
            dimension is not 768, 1024 or 2048.
        """
        config_file = self.model_path / "config.json"
        if not config_file.exists():
            raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
        with open(config_file, "r", encoding="utf-8") as file:
            config = json.load(file)
        # no obvious way to distinguish between sd2-base and sd2-768
        dimension = config["cross_attention_dim"]
        if dimension == 768:
            return BaseModelType.StableDiffusion1
        if dimension == 1024:
            return BaseModelType.StableDiffusion2
        if dimension == 2048:
            return BaseModelType.StableDiffusionXL
        raise InvalidModelConfigException(f"Unable to determine model base for {self.model_path}")


class LoRAFolderProbe(FolderProbeBase):
    """Probe for LoRA models stored as a folder."""

    def get_base_type(self) -> BaseModelType:
        """
        Delegate to the checkpoint probe on the folder's ``pytorch_lora_weights`` file.

        The ``.safetensors`` file is preferred over ``.bin`` when both exist.

        :raises InvalidModelConfigException: if neither weights file is present.
        """
        for suffix in ["safetensors", "bin"]:
            model_file = self.model_path / f"pytorch_lora_weights.{suffix}"
            if model_file.exists():
                return LoRACheckpointProbe(model_file).get_base_type()
        raise InvalidModelConfigException("Unknown LoRA format encountered")


class IPAdapterFolderProbe(FolderProbeBase):
    """Probe for IP-Adapter models stored as a folder."""

    def get_format(self) -> IPAdapterModelFormat:
        """Return the InvokeAI IP-Adapter format value."""
        # NOTE(review): this returns the enum's ``.value`` although the
        # annotation names the enum type itself; left as-is for callers.
        return IPAdapterModelFormat.InvokeAI.value

    def get_base_type(self) -> BaseModelType:
        """
        Infer the base model from the cross-attention width stored in ``ip_adapter.bin``.

        :raises InvalidModelConfigException: if the file is missing or the width
            is not 768, 1024 or 2048.
        :raises Exception: if the malware scan flags the file.
        """
        model_file = self.model_path / "ip_adapter.bin"
        if not model_file.exists():
            raise InvalidModelConfigException("Unknown IP-Adapter model format.")

        # torch.load unpickles the file, so scan it first like every other
        # pickle-based checkpoint loaded by this module.
        ModelProbe._scan_model(model_file.name, model_file)
        state_dict = torch.load(model_file, map_location="cpu")
        cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
        if cross_attention_dim == 768:
            return BaseModelType.StableDiffusion1
        if cross_attention_dim == 1024:
            return BaseModelType.StableDiffusion2
        if cross_attention_dim == 2048:
            return BaseModelType.StableDiffusionXL
        raise InvalidModelConfigException(
            f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
        )


class CLIPVisionFolderProbe(FolderProbeBase):
    """Probe for CLIP Vision models stored as a folder."""

    def get_base_type(self) -> BaseModelType:
        """Return ``BaseModelType.Any``."""
        return BaseModelType.Any


class T2IAdapterFolderProbe(FolderProbeBase):
    """Probe for T2I-Adapter models stored as a folder."""

    def get_base_type(self) -> BaseModelType:
        """
        Infer the base model from ``adapter_type`` in ``config.json``.

        :raises InvalidModelConfigException: if the config is missing or the
            adapter type is not recognized.
        """
        config_file = self.model_path / "config.json"
        if not config_file.exists():
            raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
        with open(config_file, "r", encoding="utf-8") as file:
            config = json.load(file)

        adapter_type = config.get("adapter_type", None)
        if adapter_type == "full_adapter_xl":
            return BaseModelType.StableDiffusionXL
        if adapter_type in ("full_adapter", "light_adapter"):
            # I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter.
            return BaseModelType.StableDiffusion1
        raise InvalidModelConfigException(
            f"Unable to determine base model for '{self.model_path}' (adapter_type = {adapter_type})."
        )


############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)

ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)

ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)