# Copyright (c) 2023 Lincoln Stein and the InvokeAI Team """ Return descriptive information on Stable Diffusion models. Module for probing a Stable Diffusion model and returning its base type, model type, format and variant. """ import json import re from abc import ABC, abstractmethod from pathlib import Path from typing import Callable, Dict, Optional, Type import safetensors.torch import torch from picklescan.scanner import scan_file_path from pydantic import BaseModel from .config import BaseModelType, ModelFormat, ModelType, ModelVariantType, SchedulerPredictionType from .hash import FastModelHash from .util import lora_token_vector_length, read_checkpoint_meta class InvalidModelException(Exception): """Raised when an invalid model is encountered.""" class ModelProbeInfo(BaseModel): """Fields describing a probed model.""" model_type: ModelType base_type: BaseModelType format: ModelFormat hash: str variant_type: ModelVariantType = ModelVariantType("normal") prediction_type: Optional[SchedulerPredictionType] = SchedulerPredictionType("v_prediction") upcast_attention: Optional[bool] = False image_size: Optional[int] = None class ModelProbeBase(ABC): """Class to probe a checkpoint, safetensors or diffusers folder.""" @classmethod @abstractmethod def probe( cls, model: Path, prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, ) -> Optional[ModelProbeInfo]: """ Probe model located at path and return ModelProbeInfo object. :param model: Path to a model checkpoint or folder. :param prediction_type_helper: An optional Callable that takes the model path and returns the SchedulerPredictionType. """ pass class ProbeBase(ABC): """Base model for probing checkpoint and diffusers-style models.""" @abstractmethod def get_base_type(self) -> Optional[BaseModelType]: """Return the BaseModelType for the model.""" pass def get_variant_type(self) -> ModelVariantType: """Return the ModelVariantType for the model.""" pass def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]: """Return the SchedulerPredictionType for the model.""" pass def get_format(self) -> str: """Return the format for the model.""" pass class ModelProbe(ModelProbeBase): """Class to probe a checkpoint, safetensors or diffusers folder.""" PROBES: Dict[str, dict] = { "diffusers": {}, "checkpoint": {}, "onnx": {}, } CLASS2TYPE = { "StableDiffusionPipeline": ModelType.Main, "StableDiffusionInpaintPipeline": ModelType.Main, "StableDiffusionXLPipeline": ModelType.Main, "StableDiffusionXLImg2ImgPipeline": ModelType.Main, "AutoencoderKL": ModelType.Vae, "AutoencoderTiny": ModelType.Vae, "ControlNetModel": ModelType.ControlNet, "CLIPVisionModelWithProjection": ModelType.CLIPVision, "T2IAdapter": ModelType.T2IAdapter, } @classmethod def register_probe(cls, format: ModelFormat, model_type: ModelType, probe_class: Type[ProbeBase]): """ Register a probe subclass to use when interrogating a model. :param format: The ModelFormat of the model to be probed. :param model_type: The ModelType of the model to be probed. :param probe_class: The class of the prober (inherits from ProbeBase). """ cls.PROBES[format][model_type] = probe_class @classmethod def probe( cls, model_path: Path, prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, ) -> ModelProbeInfo: """Probe model.""" try: model_type = ( cls.get_model_type_from_folder(model_path) if model_path.is_dir() else cls.get_model_type_from_checkpoint(model_path) ) format_type = ( "onnx" if model_type == ModelType.ONNX else "diffusers" if model_path.is_dir() else "checkpoint" ) probe_class = cls.PROBES[format_type].get(model_type) if not probe_class: raise InvalidModelException(f"Unable to determine model type for {model_path}") probe = probe_class(model_path, prediction_type_helper) base_type = probe.get_base_type() variant_type = probe.get_variant_type() prediction_type = probe.get_scheduler_prediction_type() format = probe.get_format() hash = FastModelHash.hash(model_path) model_info = ModelProbeInfo( model_type=model_type, base_type=base_type, variant_type=variant_type, prediction_type=prediction_type, upcast_attention=( base_type == BaseModelType.StableDiffusion2 and prediction_type == SchedulerPredictionType.VPrediction ), format=format, hash=hash, image_size=1024 if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}) else 768 if ( base_type == BaseModelType.StableDiffusion2 and prediction_type == SchedulerPredictionType.VPrediction ) else 512, ) except Exception: raise InvalidModelException(f"Unable to determine model type for {model_path}") return model_info @classmethod def get_model_type_from_checkpoint(cls, model_path: Path) -> Optional[ModelType]: """ Scan a checkpoint model and return its ModelType. :param model_path: path to the model checkpoint/safetensors file """ if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"): return None if model_path.name == "learned_embeds.bin": return ModelType.TextualInversion ckpt = read_checkpoint_meta(model_path, scan=True) ckpt = ckpt.get("state_dict", ckpt) for key in ckpt.keys(): if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}): return ModelType.Main elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}): return ModelType.Vae elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}): return ModelType.Lora elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}): return ModelType.Lora elif any(key.startswith(v) for v in {"control_model", "input_blocks"}): return ModelType.ControlNet elif key in {"emb_params", "string_to_param"}: return ModelType.TextualInversion else: # diffusers-ti if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()): return ModelType.TextualInversion raise InvalidModelException(f"Unable to determine model type for {model_path}") @classmethod def get_model_type_from_folder(cls, folder_path: Path) -> Optional[ModelType]: """ Get the model type of a hugging-face style folder. :param folder_path: Path to model folder. """ class_name = None if (folder_path / "unet/model.onnx").exists(): return ModelType.ONNX if (folder_path / "learned_embeds.bin").exists(): return ModelType.TextualInversion if (folder_path / "pytorch_lora_weights.bin").exists(): return ModelType.Lora if (folder_path / "image_encoder.txt").exists(): return ModelType.IPAdapter i = folder_path / "model_index.json" c = folder_path / "config.json" config_path = i if i.exists() else c if c.exists() else None if config_path: with open(config_path, "r") as file: conf = json.load(file) if "_class_name" in conf: class_name = conf["_class_name"] elif "architectures" in conf: class_name = conf["architectures"][0] else: class_name = None else: error_hint = f"No model_index.json or config.json found in {folder_path}." if class_name and (type := cls.CLASS2TYPE.get(class_name)): return type else: error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]" # give up raise InvalidModelException( f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "") ) @classmethod def _scan_and_load_checkpoint(cls, model: Path) -> dict: if model.suffix.endswith((".ckpt", ".pt", ".bin")): cls._scan_model(model) return torch.load(model) else: return safetensors.torch.load_file(model) @classmethod def _scan_model(cls, model: Path): """ Scan a model for malicious code. :param model: Path to the model to be scanned Raises an Exception if unsafe code is found. """ # scan model scan_result = scan_file_path(model) if scan_result.infected_files != 0: raise InvalidModelException("The model {model_name} is potentially infected by malware. Aborting import.") # ##################################################3 # Checkpoint probing # ##################################################3 class CheckpointProbeBase(ProbeBase): """Base class for probing checkpoint-style models.""" def __init__(self, checkpoint_path: Path, helper: Optional[Callable[[Path], SchedulerPredictionType]] = None): """Initialize the CheckpointProbeBase object.""" self.checkpoint_path = checkpoint_path self.checkpoint = ModelProbe._scan_and_load_checkpoint(checkpoint_path) self.helper = helper def get_base_type(self) -> Optional[BaseModelType]: """Return the BaseModelType of a checkpoint-style model.""" pass def get_format(self) -> str: """Return the format of a checkpoint-style model.""" return "checkpoint" def get_variant_type(self) -> ModelVariantType: """Return the ModelVariantType of a checkpoint-style model.""" model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path) if model_type != ModelType.Main: return ModelVariantType.Normal state_dict = self.checkpoint.get("state_dict") or self.checkpoint in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1] if in_channels == 9: return ModelVariantType.Inpaint elif in_channels == 5: return ModelVariantType.Depth elif in_channels == 4: return ModelVariantType.Normal else: raise InvalidModelException( f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}" ) class PipelineCheckpointProbe(CheckpointProbeBase): """Probe a checkpoint-style main model.""" def get_base_type(self) -> BaseModelType: """Return the ModelBaseType for the checkpoint-style main model.""" checkpoint = self.checkpoint state_dict = self.checkpoint.get("state_dict") or checkpoint key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" if key_name in state_dict and state_dict[key_name].shape[-1] == 768: return BaseModelType.StableDiffusion1 if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: return BaseModelType.StableDiffusion2 key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight" if key_name in state_dict and state_dict[key_name].shape[-1] == 2048: return BaseModelType.StableDiffusionXL elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280: return BaseModelType.StableDiffusionXLRefiner else: raise InvalidModelException("Cannot determine base type") def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]: """Return model prediction type.""" # if there is a .yaml associated with this checkpoint, then we do not need # to probe for the prediction type as it will be ignored. if self.checkpoint_path and self.checkpoint_path.with_suffix(".yaml").exists(): return None type = self.get_base_type() if type == BaseModelType.StableDiffusion2: checkpoint = self.checkpoint state_dict = self.checkpoint.get("state_dict") or checkpoint key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: if "global_step" in checkpoint: if checkpoint["global_step"] == 220000: return SchedulerPredictionType.Epsilon elif checkpoint["global_step"] == 110000: return SchedulerPredictionType.VPrediction if self.helper and self.checkpoint_path: if helper_guess := self.helper(self.checkpoint_path): return helper_guess return SchedulerPredictionType.VPrediction # a guess for sd2 ckpts elif type == BaseModelType.StableDiffusion1: if self.helper and self.checkpoint_path: if helper_guess := self.helper(self.checkpoint_path): return helper_guess return SchedulerPredictionType.Epsilon # a reasonable guess for sd1 ckpts else: return None class VaeCheckpointProbe(CheckpointProbeBase): """Probe a Checkpoint-style VAE model.""" def get_base_type(self) -> BaseModelType: """Return the BaseModelType of the VAE model.""" # I can't find any standalone 2.X VAEs to test with! return BaseModelType.StableDiffusion1 class LoRACheckpointProbe(CheckpointProbeBase): """Probe for LoRA Checkpoint Files.""" def get_format(self) -> str: """Return the format of the LoRA.""" return "lycoris" def get_base_type(self) -> BaseModelType: """Return the BaseModelType of the LoRA.""" checkpoint = self.checkpoint token_vector_length = lora_token_vector_length(checkpoint) if token_vector_length == 768: return BaseModelType.StableDiffusion1 elif token_vector_length == 1024: return BaseModelType.StableDiffusion2 elif token_vector_length == 2048: return BaseModelType.StableDiffusionXL else: raise InvalidModelException(f"Unsupported LoRA type: {self.checkpoint_path}") class TextualInversionCheckpointProbe(CheckpointProbeBase): """TextualInversion checkpoint prober.""" def get_format(self) -> str: """Return the format of a TextualInversion emedding.""" return ModelFormat.EmbeddingFile def get_base_type(self) -> BaseModelType: """Return BaseModelType of the checkpoint model.""" checkpoint = self.checkpoint if "string_to_token" in checkpoint: token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1] elif "emb_params" in checkpoint: token_dim = checkpoint["emb_params"].shape[-1] else: token_dim = list(checkpoint.values())[0].shape[0] if token_dim == 768: return BaseModelType.StableDiffusion1 elif token_dim == 1024: return BaseModelType.StableDiffusion2 raise InvalidModelException("Unknown base model for {self.checkpoint_path}") class ControlNetCheckpointProbe(CheckpointProbeBase): """Probe checkpoint-based ControlNet models.""" def get_base_type(self) -> BaseModelType: """Return the BaseModelType of the model.""" checkpoint = self.checkpoint for key_name in ( "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", ): if key_name not in checkpoint: continue if checkpoint[key_name].shape[-1] == 768: return BaseModelType.StableDiffusion1 elif checkpoint[key_name].shape[-1] == 1024: return BaseModelType.StableDiffusion2 raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}") class IPAdapterCheckpointProbe(CheckpointProbeBase): """Probe IP adapter models.""" def get_base_type(self) -> BaseModelType: """Probe base type.""" raise NotImplementedError() class CLIPVisionCheckpointProbe(CheckpointProbeBase): """Probe ClipVision adapter models.""" def get_base_type(self) -> BaseModelType: """Probe base type.""" raise NotImplementedError() class T2IAdapterCheckpointProbe(CheckpointProbeBase): def get_base_type(self) -> BaseModelType: raise NotImplementedError() ######################################################## # classes for probing folders ####################################################### class FolderProbeBase(ProbeBase): """Class for probing folder-based models.""" def __init__(self, folder_path: Path, helper: Optional[Callable] = None): # not used """ Initialize the folder prober. :param model: Path to the model to be probed. :param helper: Callable for returning the SchedulerPredictionType (unused). """ self.folder_path = folder_path def get_variant_type(self) -> ModelVariantType: """Return the model's variant type.""" return ModelVariantType.Normal def get_format(self) -> str: """Return the model's format.""" return "diffusers" class PipelineFolderProbe(FolderProbeBase): """Probe a pipeline (main) folder.""" def get_base_type(self) -> BaseModelType: """Return the BaseModelType of a pipeline folder.""" with open(self.folder_path / "unet" / "config.json", "r") as file: unet_conf = json.load(file) if unet_conf["cross_attention_dim"] == 768: return BaseModelType.StableDiffusion1 elif unet_conf["cross_attention_dim"] == 1024: return BaseModelType.StableDiffusion2 elif unet_conf["cross_attention_dim"] == 1280: return BaseModelType.StableDiffusionXLRefiner elif unet_conf["cross_attention_dim"] == 2048: return BaseModelType.StableDiffusionXL else: raise InvalidModelException(f"Unknown base model for {self.folder_path}") def get_scheduler_prediction_type(self) -> SchedulerPredictionType: """Return the SchedulerPredictionType of a diffusers-style sd-2 model.""" with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file: scheduler_conf = json.load(file) prediction_type = scheduler_conf.get("prediction_type", "epsilon") return SchedulerPredictionType(prediction_type) def get_variant_type(self) -> ModelVariantType: """Return the ModelVariantType for diffusers-style main models.""" # This only works for pipelines! Any kind of # exception results in our returning the # "normal" variant type try: config_file = self.folder_path / "unet" / "config.json" with open(config_file, "r") as file: conf = json.load(file) in_channels = conf["in_channels"] if in_channels == 9: return ModelVariantType.Inpaint elif in_channels == 5: return ModelVariantType.Depth elif in_channels == 4: return ModelVariantType.Normal except Exception: pass return ModelVariantType.Normal class VaeFolderProbe(FolderProbeBase): """Class for probing folder-style models.""" def get_base_type(self) -> BaseModelType: """Get base type of model.""" if self._config_looks_like_sdxl(): return BaseModelType.StableDiffusionXL elif self._name_looks_like_sdxl(): # but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down # by a factor of 8), we can't necessarily tell them apart by config hyperparameters. return BaseModelType.StableDiffusionXL else: return BaseModelType.StableDiffusion1 def _config_looks_like_sdxl(self) -> bool: # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE. config_file = self.folder_path / "config.json" if not config_file.exists(): raise InvalidModelException(f"Cannot determine base type for {self.folder_path}") with open(config_file, "r") as file: config = json.load(file) return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024] def _name_looks_like_sdxl(self) -> bool: return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE)) def _guess_name(self) -> str: name = self.folder_path.name if name == "vae": name = self.folder_path.parent.name return name class TextualInversionFolderProbe(FolderProbeBase): """Probe a HuggingFace-style TextualInversion folder.""" def get_format(self) -> str: """Return the format of the TextualInversion.""" return ModelFormat.EmbeddingFolder def get_base_type(self) -> BaseModelType: """Return the ModelBaseType of the HuggingFace-style Textual Inversion Folder.""" path = self.folder_path / "learned_embeds.bin" if not path.exists(): raise InvalidModelException("This textual inversion folder does not contain a learned_embeds.bin file.") return TextualInversionCheckpointProbe(path).get_base_type() class ONNXFolderProbe(FolderProbeBase): """Probe an ONNX-format folder.""" def get_format(self) -> str: """Return the format of the folder (always "onnx").""" return "onnx" def get_base_type(self) -> BaseModelType: """Return the BaseModelType of the ONNX folder.""" return BaseModelType.StableDiffusion1 def get_variant_type(self) -> ModelVariantType: """Return the ModelVariantType of the ONNX folder.""" return ModelVariantType.Normal class ControlNetFolderProbe(FolderProbeBase): """Probe a ControlNet model folder.""" def get_base_type(self) -> BaseModelType: """Return the BaseModelType of a ControlNet model folder.""" config_file = self.folder_path / "config.json" if not config_file.exists(): raise InvalidModelException(f"Cannot determine base type for {self.folder_path}") with open(config_file, "r") as file: config = json.load(file) # no obvious way to distinguish between sd2-base and sd2-768 dimension = config["cross_attention_dim"] base_model = ( BaseModelType.StableDiffusion1 if dimension == 768 else BaseModelType.StableDiffusion2 if dimension == 1024 else BaseModelType.StableDiffusionXL if dimension == 2048 else None ) if not base_model: raise InvalidModelException(f"Unable to determine model base for {self.folder_path}") return base_model class LoRAFolderProbe(FolderProbeBase): """Probe a LoRA model folder.""" def get_base_type(self) -> BaseModelType: """Get the ModelBaseType of a LoRA model folder.""" model_file = None for suffix in ["safetensors", "bin"]: base_file = self.folder_path / f"pytorch_lora_weights.{suffix}" if base_file.exists(): model_file = base_file break if not model_file: raise InvalidModelException("Unknown LoRA format encountered") return LoRACheckpointProbe(model_file).get_base_type() class IPAdapterFolderProbe(FolderProbeBase): """Class for probing IP-Adapter models.""" def get_format(self) -> str: """Get format of ip adapter.""" return ModelFormat.InvokeAI.value def get_base_type(self) -> BaseModelType: """Get base type of ip adapter.""" model_file = self.folder_path / "ip_adapter.bin" if not model_file.exists(): raise InvalidModelException("Unknown IP-Adapter model format.") state_dict = torch.load(model_file, map_location="cpu") cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1] if cross_attention_dim == 768: return BaseModelType.StableDiffusion1 elif cross_attention_dim == 1024: return BaseModelType.StableDiffusion2 elif cross_attention_dim == 2048: return BaseModelType.StableDiffusionXL else: raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.") class CLIPVisionFolderProbe(FolderProbeBase): """Probe for folder-based CLIPVision models.""" def get_base_type(self) -> BaseModelType: """Get base type.""" return BaseModelType.Any class T2IAdapterFolderProbe(FolderProbeBase): def get_base_type(self) -> BaseModelType: config_file = self.folder_path / "config.json" if not config_file.exists(): raise InvalidModelException(f"Cannot determine base type for {self.folder_path}") with open(config_file, "r") as file: config = json.load(file) adapter_type = config.get("adapter_type", None) if adapter_type == "full_adapter_xl": return BaseModelType.StableDiffusionXL elif adapter_type == "full_adapter" or "light_adapter": # I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter. return BaseModelType.StableDiffusion1 else: raise InvalidModelException( f"Unable to determine base model for '{self.folder_path}' (adapter_type = {adapter_type})." ) ############## register probe classes ###### diffusers = ModelFormat("diffusers") checkpoint = ModelFormat("checkpoint") ModelProbe.register_probe(diffusers, ModelType.Main, PipelineFolderProbe) ModelProbe.register_probe(diffusers, ModelType.Vae, VaeFolderProbe) ModelProbe.register_probe(diffusers, ModelType.Lora, LoRAFolderProbe) ModelProbe.register_probe(diffusers, ModelType.TextualInversion, TextualInversionFolderProbe) ModelProbe.register_probe(diffusers, ModelType.ControlNet, ControlNetFolderProbe) ModelProbe.register_probe(diffusers, ModelType.IPAdapter, IPAdapterFolderProbe) ModelProbe.register_probe(diffusers, ModelType.CLIPVision, CLIPVisionFolderProbe) ModelProbe.register_probe(diffusers, ModelType.T2IAdapter, T2IAdapterFolderProbe) ModelProbe.register_probe(checkpoint, ModelType.Main, PipelineCheckpointProbe) ModelProbe.register_probe(checkpoint, ModelType.Vae, VaeCheckpointProbe) ModelProbe.register_probe(checkpoint, ModelType.Lora, LoRACheckpointProbe) ModelProbe.register_probe(checkpoint, ModelType.TextualInversion, TextualInversionCheckpointProbe) ModelProbe.register_probe(checkpoint, ModelType.ControlNet, ControlNetCheckpointProbe) ModelProbe.register_probe(checkpoint, ModelType.IPAdapter, IPAdapterCheckpointProbe) ModelProbe.register_probe(checkpoint, ModelType.CLIPVision, CLIPVisionCheckpointProbe) ModelProbe.register_probe(checkpoint, ModelType.T2IAdapter, T2IAdapterCheckpointProbe) ModelProbe.register_probe(ModelFormat("onnx"), ModelType.ONNX, ONNXFolderProbe)