InvokeAI/invokeai/backend/model_management/model_probe.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

659 lines
26 KiB
Python
Raw Normal View History

import json
import re
from dataclasses import dataclass
from pathlib import Path
2023-08-18 15:18:46 +00:00
from typing import Callable, Dict, Literal, Optional, Union
import safetensors.torch
import torch
2023-08-18 15:18:46 +00:00
from diffusers import ConfigMixin, ModelMixin
from picklescan.scanner import scan_file_path
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
from .models import (
2023-07-28 13:46:44 +00:00
BaseModelType,
2023-08-18 15:18:46 +00:00
InvalidModelException,
2023-07-28 13:46:44 +00:00
ModelType,
ModelVariantType,
SchedulerPredictionType,
SilenceWarnings,
)
from .models.base import read_checkpoint_meta
from .util import lora_token_vector_length
2023-07-28 13:46:44 +00:00
@dataclass
class ModelProbeInfo:
    """Result record produced by ModelProbe.probe() for a single model."""

    # What kind of model this is and which base model family it targets.
    model_type: ModelType
    base_type: BaseModelType
    variant_type: ModelVariantType
    prediction_type: SchedulerPredictionType
    # probe() sets this to True for a StableDiffusion2 base with v-prediction.
    upcast_attention: bool
    format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
    # probe() uses 1024 for SDXL / SDXL-refiner, 768 for SD2 v-prediction, else 512.
    image_size: int
    # Optional human-readable labels; both default to None.
    name: Optional[str] = None
    description: Optional[str] = None
2023-07-28 13:46:44 +00:00
class ProbeBase:
    """Forward declaration; the full ProbeBase class is defined further down in this module."""
2023-07-28 13:46:44 +00:00
class ModelProbe(object):
    """
    Work out what kind of model a file, folder or in-memory object is.

    Probe classes are stored in PROBES, keyed first by format and then by
    ModelType, and are added through register_probe().
    """

    # format name -> {ModelType: probe class}; populated by register_probe()
    PROBES = {
        "diffusers": {},
        "checkpoint": {},
        "onnx": {},
    }

    # class name found in a model's config (or on the loaded object) -> model type
    CLASS2TYPE = {
        "StableDiffusionPipeline": ModelType.Main,
        "StableDiffusionInpaintPipeline": ModelType.Main,
        "StableDiffusionXLPipeline": ModelType.Main,
        "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
        "StableDiffusionXLInpaintPipeline": ModelType.Main,
        "LatentConsistencyModelPipeline": ModelType.Main,
        "AutoencoderKL": ModelType.Vae,
        "AutoencoderTiny": ModelType.Vae,
        "ControlNetModel": ModelType.ControlNet,
        "CLIPVisionModelWithProjection": ModelType.CLIPVision,
        "T2IAdapter": ModelType.T2IAdapter,
    }

    @classmethod
    def register_probe(
        cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase
    ):
        """Register probe_class (the class itself, not an instance) for model_type in the given format."""
        cls.PROBES[format][model_type] = probe_class

    @classmethod
    def heuristic_probe(
        cls,
        model: Union[Dict, ModelMixin, Path],
        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
    ) -> Optional[ModelProbeInfo]:
        """
        Probe either a path or an already-loaded model.

        A Path is probed from disk; a dict, ModelMixin or ConfigMixin is probed
        in memory with no path. Anything else raises InvalidModelException.
        """
        if isinstance(model, Path):
            return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper)
        if isinstance(model, (dict, ModelMixin, ConfigMixin)):
            return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
        # The original message lacked the f-prefix and printed "{model}" literally.
        raise InvalidModelException(f"model parameter {model} is neither a Path, nor a model")

    @classmethod
    def probe(
        cls,
        model_path: Optional[Path],
        model: Optional[Union[Dict, ModelMixin]] = None,
        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
    ) -> Optional[ModelProbeInfo]:
        """
        Probe the model at model_path and return sufficient information about it
        to place it somewhere in the models directory hierarchy. If the model is
        already loaded into memory, you may provide it as model in order to avoid
        opening it a second time. The prediction_type_helper callable is a function that receives
        the path to the model and returns the SchedulerPredictionType.

        Returns None when no probe class is registered for the detected
        format and model type. Exceptions raised by the probes propagate.
        """
        if model_path:
            format_type = "diffusers" if model_path.is_dir() else "checkpoint"
        else:
            format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint"

        if format_type == "diffusers":
            model_type = cls.get_model_type_from_folder(model_path, model)
        else:
            model_type = cls.get_model_type_from_checkpoint(model_path, model)
        if model_type == ModelType.ONNX:
            format_type = "onnx"

        probe_class = cls.PROBES[format_type].get(model_type)
        if not probe_class:
            return None

        probe = probe_class(model_path, model, prediction_type_helper)
        base_type = probe.get_base_type()
        variant_type = probe.get_variant_type()
        prediction_type = probe.get_scheduler_prediction_type()
        model_format = probe.get_format()

        # In-memory probes have no path, hence no name to derive.
        name = cls.get_model_name(model_path) if model_path else None
        description = f"{base_type.value} {model_type.value} model"
        if name:
            description = f"{description} {name}"

        is_sd2_vpred = (
            base_type == BaseModelType.StableDiffusion2 and prediction_type == SchedulerPredictionType.VPrediction
        )
        if base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}:
            image_size = 1024
        elif is_sd2_vpred:
            image_size = 768
        else:
            image_size = 512

        return ModelProbeInfo(
            model_type=model_type,
            base_type=base_type,
            variant_type=variant_type,
            prediction_type=prediction_type,
            name=name,
            description=description,
            upcast_attention=is_sd2_vpred,
            format=model_format,
            image_size=image_size,
        )

    @classmethod
    def get_model_name(cls, model_path: Path) -> str:
        """Return the file stem for known checkpoint suffixes, otherwise the last path component."""
        # ".pth" is included because get_model_type_from_checkpoint accepts it as a checkpoint suffix.
        if model_path.suffix in {".safetensors", ".bin", ".pt", ".pth", ".ckpt"}:
            return model_path.stem
        return model_path.name

    @classmethod
    def get_model_type_from_checkpoint(cls, model_path: Optional[Path], checkpoint: dict) -> Optional[ModelType]:
        """
        Guess the model type of a single-file checkpoint from its key names.

        Returns None when model_path has an unsupported suffix. Raises
        InvalidModelException when no known key pattern matches. model_path
        may be None when an in-memory checkpoint is supplied.
        """
        if model_path is not None:
            if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
                return None
            if model_path.name == "learned_embeds.bin":
                return ModelType.TextualInversion

        ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
        ckpt = ckpt.get("state_dict", ckpt)

        # The first key matching any pattern decides the type.
        for key in ckpt.keys():
            if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")):
                return ModelType.Main
            elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
                return ModelType.Vae
            elif key.startswith(("lora_te_", "lora_unet_")):
                return ModelType.Lora
            elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight")):
                return ModelType.Lora
            elif key.startswith(("control_model", "input_blocks")):
                return ModelType.ControlNet
            elif key in {"emb_params", "string_to_param"}:
                return ModelType.TextualInversion

        # diffusers-ti: a handful of entries, all of them tensors
        if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
            return ModelType.TextualInversion

        raise InvalidModelException(f"Unable to determine model type for {model_path}")

    @classmethod
    def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType:
        """
        Get the model type of a hugging-face style folder.

        If model is given, its class name is looked up in CLASS2TYPE. Otherwise
        the folder is checked for well-known marker files, then for a class name
        in model_index.json or config.json. Raises InvalidModelException when the
        type cannot be determined.
        """
        class_name = None
        error_hint = None
        if model:
            class_name = model.__class__.__name__
        else:
            for suffix in ["bin", "safetensors"]:
                if (folder_path / f"learned_embeds.{suffix}").exists():
                    return ModelType.TextualInversion
                if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
                    return ModelType.Lora
            if (folder_path / "unet/model.onnx").exists():
                return ModelType.ONNX
            if (folder_path / "image_encoder.txt").exists():
                return ModelType.IPAdapter

            index_file = folder_path / "model_index.json"
            config_file = folder_path / "config.json"
            config_path = index_file if index_file.exists() else config_file if config_file.exists() else None

            if config_path:
                with open(config_path, "r") as file:
                    conf = json.load(file)
                if "_class_name" in conf:
                    class_name = conf["_class_name"]
                elif "architectures" in conf:
                    class_name = conf["architectures"][0]
                else:
                    class_name = None
            else:
                error_hint = f"No model_index.json or config.json found in {folder_path}."

        if class_name and (model_type := cls.CLASS2TYPE.get(class_name)):
            return model_type
        # Keep the more specific "no config file" hint if one was already set;
        # previously it was always overwritten by the message below.
        if error_hint is None:
            error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"

        # give up
        raise InvalidModelException(
            f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
        )

    @classmethod
    def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
        """
        Load a checkpoint file, scanning pickle-based formats first.

        Files ending in .ckpt, .pt, .pth or .bin are scanned and then loaded
        with torch.load; everything else is loaded as safetensors.
        """
        with SilenceWarnings():
            # ".pth" is accepted by get_model_type_from_checkpoint, so it must not fall through to safetensors.
            if model_path.suffix in (".ckpt", ".pt", ".pth", ".bin"):
                cls._scan_model(model_path, model_path)
                # Load onto the CPU so probing does not depend on the device the file was saved from.
                return torch.load(model_path, map_location="cpu")
            return safetensors.torch.load_file(model_path)

    @classmethod
    def _scan_model(cls, model_name, checkpoint):
        """
        Apply picklescanner to the indicated checkpoint and raise an
        exception if an infected file is identified.
        """
        # scan model
        scan_result = scan_file_path(checkpoint)
        if scan_result.infected_files != 0:
            # The original message lacked the f-prefix and printed "{model_name}" literally.
            raise Exception(f"The model {model_name} is potentially infected by malware. Aborting import.")
2023-07-28 13:46:44 +00:00
2023-08-17 22:45:25 +00:00
# ##################################################
# Checkpoint probing
# ##################################################
class ProbeBase:
    """Common probe interface; every method here is a stub that returns None."""

    def get_base_type(self) -> BaseModelType:
        """Return the base model type; stub in this class."""
        return None

    def get_variant_type(self) -> ModelVariantType:
        """Return the model variant; stub in this class."""
        return None

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        """Return the scheduler prediction type; stub in this class."""
        return None

    def get_format(self) -> str:
        """Return the model format name; stub in this class."""
        return None
2023-07-28 13:46:44 +00:00
class CheckpointProbeBase(ProbeBase):
    """Base class for probes that inspect a single-file checkpoint (a dict of weights)."""

    def __init__(
        self,
        checkpoint_path: Path,
        checkpoint: dict,
        helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
    ) -> None:
        """
        :param checkpoint_path: location of the checkpoint file.
        :param checkpoint: already-loaded checkpoint; when falsy the file at
            checkpoint_path is scanned and loaded instead.
        :param helper: optional callable that receives the checkpoint path;
            subclasses consult it when the weights alone are inconclusive.
        """
        self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
        self.checkpoint_path = checkpoint_path
        self.helper = helper

    def get_base_type(self) -> BaseModelType:
        """Stub; subclasses supply the real implementation."""
        pass

    def get_format(self) -> str:
        """Return the format name, "checkpoint"."""
        return "checkpoint"

    def get_variant_type(self) -> ModelVariantType:
        """
        Return the variant of the checkpoint.

        Anything that is not a main model is reported as Normal. For main
        models the variant is chosen from the second dimension of the first
        UNet input-block weight: 9 -> Inpaint, 5 -> Depth, 4 -> Normal.
        Raises InvalidModelException for any other value.
        """
        model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint)
        if model_type != ModelType.Main:
            return ModelVariantType.Normal

        state_dict = self.checkpoint.get("state_dict") or self.checkpoint
        in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
        if in_channels == 9:
            return ModelVariantType.Inpaint
        elif in_channels == 5:
            return ModelVariantType.Depth
        elif in_channels == 4:
            return ModelVariantType.Normal
        else:
            raise InvalidModelException(
                f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}"
            )
class PipelineCheckpointProbe(CheckpointProbeBase):
    """Probe for single-file main (pipeline) checkpoints."""

    def get_base_type(self) -> BaseModelType:
        """Pick the base model from the last dimension of a cross-attention to_k weight."""
        weights = self.checkpoint.get("state_dict") or self.checkpoint

        sd_key = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if sd_key in weights:
            width = weights[sd_key].shape[-1]
            if width == 768:
                return BaseModelType.StableDiffusion1
            if width == 1024:
                return BaseModelType.StableDiffusion2

        xl_key = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
        if xl_key in weights:
            width = weights[xl_key].shape[-1]
            if width == 2048:
                return BaseModelType.StableDiffusionXL
            if width == 1280:
                return BaseModelType.StableDiffusionXLRefiner

        raise InvalidModelException("Cannot determine base type")

    def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
        """Return model prediction type."""
        # A .yaml file next to the checkpoint makes the probed prediction
        # type irrelevant (it will be ignored), so skip the work.
        path = self.checkpoint_path
        if path and path.with_suffix(".yaml").exists():
            return None

        base = self.get_base_type()
        if base == BaseModelType.StableDiffusion2:
            fallback = SchedulerPredictionType.VPrediction  # a guess for sd2 ckpts
            ckpt = self.checkpoint
            weights = ckpt.get("state_dict") or ckpt
            sd_key = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
            if sd_key in weights and weights[sd_key].shape[-1] == 1024 and "global_step" in ckpt:
                step = ckpt["global_step"]
                if step == 220000:
                    return SchedulerPredictionType.Epsilon
                if step == 110000:
                    return SchedulerPredictionType.VPrediction
        elif base == BaseModelType.StableDiffusion1:
            fallback = SchedulerPredictionType.Epsilon  # a reasonable guess for sd1 ckpts
        else:
            return None

        # Let the caller-supplied helper decide before falling back to the guess.
        if self.helper and self.checkpoint_path:
            if helper_guess := self.helper(self.checkpoint_path):
                return helper_guess
        return fallback
2023-07-28 13:46:44 +00:00
class VaeCheckpointProbe(CheckpointProbeBase):
    """Probe for single-file VAE checkpoints."""

    def get_base_type(self) -> BaseModelType:
        """Always report StableDiffusion1."""
        # I can't find any standalone 2.X VAEs to test with!
        return BaseModelType.StableDiffusion1
2023-07-28 13:46:44 +00:00
class LoRACheckpointProbe(CheckpointProbeBase):
    """Probe for single-file LoRA checkpoints."""

    def get_format(self) -> str:
        """Return the format name, "lycoris"."""
        return "lycoris"

    def get_base_type(self) -> BaseModelType:
        """Map the LoRA's token vector length onto a base model, or raise InvalidModelException."""
        length = lora_token_vector_length(self.checkpoint)
        known_lengths = (
            (768, BaseModelType.StableDiffusion1),
            (1024, BaseModelType.StableDiffusion2),
            (2048, BaseModelType.StableDiffusionXL),
        )
        for expected, base in known_lengths:
            if length == expected:
                return base
        raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}")
2023-07-28 13:46:44 +00:00
class TextualInversionCheckpointProbe(CheckpointProbeBase):
    """Probe for textual-inversion embedding checkpoints."""

    def get_format(self) -> Optional[str]:
        """Textual inversions report no format."""
        return None

    def get_base_type(self) -> Optional[BaseModelType]:
        """
        Determine the base model from the embedding's token dimension.

        Returns StableDiffusion1 for 768, StableDiffusion2 for 1024 and None
        for any other dimension.
        """
        checkpoint = self.checkpoint

        # Test for the key that is actually read. The original tested
        # "string_to_token" but indexed "string_to_param".
        if "string_to_param" in checkpoint:
            token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
        elif "emb_params" in checkpoint:
            token_dim = checkpoint["emb_params"].shape[-1]
        else:
            # Otherwise read the first dimension of the first stored value.
            token_dim = list(checkpoint.values())[0].shape[0]

        if token_dim == 768:
            return BaseModelType.StableDiffusion1
        elif token_dim == 1024:
            return BaseModelType.StableDiffusion2
        else:
            return None
2023-07-28 13:46:44 +00:00
class ControlNetCheckpointProbe(CheckpointProbeBase):
    """Probe for single-file ControlNet checkpoints."""

    def get_base_type(self) -> BaseModelType:
        """
        Determine the base model from a cross-attention to_k weight.

        The first candidate key present in the checkpoint decides: 768 ->
        StableDiffusion1, 1024 -> StableDiffusion2. For any other size the
        helper is asked, if both a path and a helper are available. Raises
        InvalidModelException when nothing matches.
        """
        checkpoint = self.checkpoint
        for key_name in (
            "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
            "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
        ):
            if key_name not in checkpoint:
                continue
            width = checkpoint[key_name].shape[-1]
            if width == 768:
                return BaseModelType.StableDiffusion1
            elif width == 1024:
                return BaseModelType.StableDiffusion2
            elif self.checkpoint_path and self.helper:
                return self.helper(self.checkpoint_path)
        # The original message lacked the f-prefix and printed the placeholder literally.
        raise InvalidModelException(f"Unable to determine base type for {self.checkpoint_path}")
2023-07-28 13:46:44 +00:00
class IPAdapterCheckpointProbe(CheckpointProbeBase):
    """Placeholder probe: IP-Adapter checkpoints are not supported."""

    def get_base_type(self) -> BaseModelType:
        """Always raises NotImplementedError."""
        raise NotImplementedError()
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
    """Placeholder probe: CLIP Vision checkpoints are not supported."""

    def get_base_type(self) -> BaseModelType:
        """Always raises NotImplementedError."""
        raise NotImplementedError()
class T2IAdapterCheckpointProbe(CheckpointProbeBase):
    """Placeholder probe: T2I-Adapter checkpoints are not supported."""

    def get_base_type(self) -> BaseModelType:
        """Always raises NotImplementedError."""
        raise NotImplementedError()
########################################################
# classes for probing folders
#######################################################
class FolderProbeBase(ProbeBase):
    """Base class for probes that inspect a model stored as a folder."""

    def __init__(self, folder_path: Path, model: ModelMixin = None, helper: Callable = None):  # not used
        """Remember the folder and the optional in-memory model; helper is accepted but not used."""
        self.folder_path = folder_path
        self.model = model

    def get_variant_type(self) -> ModelVariantType:
        """Folder models are reported as the Normal variant unless a subclass overrides this."""
        return ModelVariantType.Normal

    def get_format(self) -> str:
        """Return the format name, "diffusers"."""
        return "diffusers"
class PipelineFolderProbe(FolderProbeBase):
    """Probe for diffusers pipeline folders, using the loaded model or its JSON config files."""

    def get_base_type(self) -> BaseModelType:
        """Map the UNet's cross_attention_dim onto a base model, or raise InvalidModelException."""
        if self.model:
            unet_conf = self.model.unet.config
        else:
            with open(self.folder_path / "unet" / "config.json", "r") as file:
                unet_conf = json.load(file)

        known_dims = (
            (768, BaseModelType.StableDiffusion1),
            (1024, BaseModelType.StableDiffusion2),
            (1280, BaseModelType.StableDiffusionXLRefiner),
            (2048, BaseModelType.StableDiffusionXL),
        )
        dim = unet_conf["cross_attention_dim"]
        for expected, base in known_dims:
            if dim == expected:
                return base
        raise InvalidModelException(f"Unknown base model for {self.folder_path}")

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        """Translate the scheduler config's prediction_type; unknown values yield None."""
        if self.model:
            scheduler_conf = self.model.scheduler.config
        else:
            with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file:
                scheduler_conf = json.load(file)

        kind = scheduler_conf["prediction_type"]
        if kind == "v_prediction":
            return SchedulerPredictionType.VPrediction
        if kind == "epsilon":
            return SchedulerPredictionType.Epsilon
        return None

    def get_variant_type(self) -> ModelVariantType:
        """Derive the variant from the UNet's in_channels; falls back to Normal on any problem."""
        # This only works for pipelines! Any kind of
        # exception results in our returning the
        # "normal" variant type
        try:
            if self.model:
                conf = self.model.unet.config
            else:
                with open(self.folder_path / "unet" / "config.json", "r") as file:
                    conf = json.load(file)

            channels = conf["in_channels"]
            for expected, variant in (
                (9, ModelVariantType.Inpaint),
                (5, ModelVariantType.Depth),
                (4, ModelVariantType.Normal),
            ):
                if channels == expected:
                    return variant
        except Exception:
            pass
        return ModelVariantType.Normal
2023-07-28 13:46:44 +00:00
class VaeFolderProbe(FolderProbeBase):
    """Probe for VAE folders."""

    def get_base_type(self) -> BaseModelType:
        """Report SDXL when either the config or the folder name suggests it, otherwise SD1."""
        # SD and SDXL VAEs are the same shape (3-channel RGB to 4-channel float scaled down
        # by a factor of 8), so the config hyperparameters cannot always tell them apart;
        # the name check is the fallback.
        if self._config_looks_like_sdxl() or self._name_looks_like_sdxl():
            return BaseModelType.StableDiffusionXL
        return BaseModelType.StableDiffusion1

    def _config_looks_like_sdxl(self) -> bool:
        """Check config.json for the values that distinguish Stability's SDXL VAE from their SD 1.x VAE."""
        config_file = self.folder_path / "config.json"
        if not config_file.exists():
            raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
        with open(config_file, "r") as file:
            config = json.load(file)
        scaling_matches = config.get("scaling_factor", 0) == 0.13025
        return scaling_matches and config.get("sample_size") in [512, 1024]

    def _name_looks_like_sdxl(self) -> bool:
        """True when the guessed name contains "xl" at a word end (case-insensitive)."""
        match = re.search(r"xl\b", self._guess_name(), re.IGNORECASE)
        return bool(match)

    def _guess_name(self) -> str:
        """Use the folder name, or the parent's name when the folder is literally called "vae"."""
        folder = self.folder_path
        return folder.parent.name if folder.name == "vae" else folder.name
2023-07-28 13:46:44 +00:00
class TextualInversionFolderProbe(FolderProbeBase):
    """Probe for textual-inversion folders that hold a learned_embeds file."""

    def get_format(self) -> Optional[str]:
        """Textual inversions report no format."""
        return None

    def get_base_type(self) -> Optional[BaseModelType]:
        """
        Load the learned embeddings and delegate to TextualInversionCheckpointProbe.

        Both learned_embeds.bin and learned_embeds.safetensors are considered,
        matching the file names ModelProbe.get_model_type_from_folder uses to
        detect this model type. Returns None when neither file exists.
        """
        for suffix in ("bin", "safetensors"):
            path = self.folder_path / f"learned_embeds.{suffix}"
            if path.exists():
                checkpoint = ModelProbe._scan_and_load_checkpoint(path)
                return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()
        return None
2023-07-28 20:54:03 +00:00
class ONNXFolderProbe(FolderProbeBase):
    """Probe for ONNX model folders; every answer is a fixed value."""

    def get_format(self) -> str:
        """Return the format name, "onnx"."""
        return "onnx"

    def get_base_type(self) -> BaseModelType:
        """Always report StableDiffusion1."""
        return BaseModelType.StableDiffusion1

    def get_variant_type(self) -> ModelVariantType:
        """Always report the Normal variant."""
        return ModelVariantType.Normal
class ControlNetFolderProbe(FolderProbeBase):
    """Probe for ControlNet folders, based on config.json."""

    def get_base_type(self) -> BaseModelType:
        """Map cross_attention_dim from config.json onto a base model, or raise InvalidModelException."""
        config_file = self.folder_path / "config.json"
        if not config_file.exists():
            raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
        with open(config_file, "r") as file:
            config = json.load(file)

        # no obvious way to distinguish between sd2-base and sd2-768
        dimension = config["cross_attention_dim"]
        if dimension == 768:
            return BaseModelType.StableDiffusion1
        if dimension == 1024:
            return BaseModelType.StableDiffusion2
        if dimension == 2048:
            return BaseModelType.StableDiffusionXL
        raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
2023-07-28 13:46:44 +00:00
class LoRAFolderProbe(FolderProbeBase):
    """Probe for LoRA folders that hold a pytorch_lora_weights file."""

    def get_base_type(self) -> BaseModelType:
        """Find the weights file (safetensors preferred over bin) and delegate to LoRACheckpointProbe."""
        candidates = (self.folder_path / f"pytorch_lora_weights.{suffix}" for suffix in ["safetensors", "bin"])
        weights_file = next((path for path in candidates if path.exists()), None)
        if not weights_file:
            raise InvalidModelException("Unknown LoRA format encountered")
        return LoRACheckpointProbe(weights_file, None).get_base_type()
class IPAdapterFolderProbe(FolderProbeBase):
    """Probe for IP-Adapter folders that contain an ip_adapter.bin file."""

    def get_format(self) -> str:
        """Return the value of IPAdapterModelFormat.InvokeAI."""
        return IPAdapterModelFormat.InvokeAI.value

    def get_base_type(self) -> BaseModelType:
        """Map the width of the adapter's first to_k_ip weight onto a base model."""
        weights_path = self.folder_path / "ip_adapter.bin"
        if not weights_path.exists():
            raise InvalidModelException("Unknown IP-Adapter model format.")

        # NOTE(review): this torch.load unpickles the file directly; it is not preceded by a
        # picklescan pass the way ModelProbe._scan_and_load_checkpoint is. Confirm that is intended.
        weights = torch.load(weights_path, map_location="cpu")
        cross_attention_dim = weights["ip_adapter"]["1.to_k_ip.weight"].shape[-1]

        known_dims = (
            (768, BaseModelType.StableDiffusion1),
            (1024, BaseModelType.StableDiffusion2),
            (2048, BaseModelType.StableDiffusionXL),
        )
        for expected, base in known_dims:
            if cross_attention_dim == expected:
                return base
        raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")
class CLIPVisionFolderProbe(FolderProbeBase):
    """Probe for CLIP Vision model folders."""

    def get_base_type(self) -> BaseModelType:
        """Always report BaseModelType.Any."""
        return BaseModelType.Any
class T2IAdapterFolderProbe(FolderProbeBase):
    """Probe for T2I-Adapter folders, based on the adapter_type entry in config.json."""

    def get_base_type(self) -> BaseModelType:
        """
        Determine the base model from config.json's adapter_type.

        "full_adapter_xl" -> StableDiffusionXL; "full_adapter" or
        "light_adapter" -> StableDiffusion1. Raises InvalidModelException when
        config.json is missing or adapter_type is anything else.
        """
        config_file = self.folder_path / "config.json"
        if not config_file.exists():
            raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
        with open(config_file, "r") as file:
            config = json.load(file)

        adapter_type = config.get("adapter_type", None)
        if adapter_type == "full_adapter_xl":
            return BaseModelType.StableDiffusionXL
        # The original `adapter_type == "full_adapter" or "light_adapter"` was always
        # true (a non-empty string is truthy), so the error below could never be reached.
        elif adapter_type in ("full_adapter", "light_adapter"):
            # I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter.
            return BaseModelType.StableDiffusion1
        else:
            raise InvalidModelException(
                f"Unable to determine base model for '{self.folder_path}' (adapter_type = {adapter_type})."
            )
############## register probe classes ######
2023-07-28 13:46:44 +00:00
# (format, model type, probe class) triples, registered in this exact order.
for _probe_format, _probe_model_type, _probe_class in (
    ("diffusers", ModelType.Main, PipelineFolderProbe),
    ("diffusers", ModelType.Vae, VaeFolderProbe),
    ("diffusers", ModelType.Lora, LoRAFolderProbe),
    ("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe),
    ("diffusers", ModelType.ControlNet, ControlNetFolderProbe),
    ("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe),
    ("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe),
    ("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe),
    ("checkpoint", ModelType.Main, PipelineCheckpointProbe),
    ("checkpoint", ModelType.Vae, VaeCheckpointProbe),
    ("checkpoint", ModelType.Lora, LoRACheckpointProbe),
    ("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe),
    ("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe),
    ("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe),
    ("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe),
    ("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe),
    ("onnx", ModelType.ONNX, ONNXFolderProbe),
):
    ModelProbe.register_probe(_probe_format, _probe_model_type, _probe_class)
# Do not leave the loop variables behind in the module namespace.
del _probe_format, _probe_model_type, _probe_class