mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
78377469db
* Bump diffusers to 0.21.2. * Add T2IAdapterInvocation boilerplate. * Add T2I-Adapter model to model-management. * (minor) Tidy prepare_control_image(...). * Add logic to run the T2I-Adapter models at the start of the DenoiseLatentsInvocation. * Add logic for applying T2I-Adapter weights and accumulating. * Add T2IAdapter to MODEL_CLASSES map. * yarn typegen * Add model probes for T2I-Adapter models. * Add all of the frontend boilerplate required to use T2I-Adapter in the nodes editor. * Add T2IAdapterModel.convert_if_required(...). * Fix errors in T2I-Adapter input image sizing logic. * Fix bug with handling of multiple T2I-Adapters. * black / flake8 * Fix typo * yarn build * Add num_channels param to prepare_control_image(...). * Link to upstream diffusers bugfix PR that currently requires a workaround. * feat: Add Color Map Preprocessor Needed for the color T2I Adapter * feat: Add Color Map Preprocessor to Linear UI * Revert "feat: Add Color Map Preprocessor" This reverts commit a1119a00bf. * Revert "feat: Add Color Map Preprocessor to Linear UI" This reverts commit bd8a9b82d8. * Fix T2I-Adapter field rendering in workflow editor. * yarn build, yarn typegen --------- Co-authored-by: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
644 lines
26 KiB
Python
644 lines
26 KiB
Python
import json
|
|
import re
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Callable, Dict, Literal, Optional, Union
|
|
|
|
import safetensors.torch
|
|
import torch
|
|
from diffusers import ConfigMixin, ModelMixin
|
|
from picklescan.scanner import scan_file_path
|
|
|
|
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
|
|
|
|
from .models import (
|
|
BaseModelType,
|
|
InvalidModelException,
|
|
ModelType,
|
|
ModelVariantType,
|
|
SchedulerPredictionType,
|
|
SilenceWarnings,
|
|
)
|
|
from .models.base import read_checkpoint_meta
|
|
from .util import lora_token_vector_length
|
|
|
|
|
|
@dataclass
class ModelProbeInfo:
    """Result record describing a probed model (built by ModelProbe.probe)."""

    # Kind of model (main pipeline, VAE, LoRA, ControlNet, ...).
    model_type: ModelType
    # Base architecture the model belongs to.
    base_type: BaseModelType
    # Variant reported by the probe class (normal / inpaint / depth).
    variant_type: ModelVariantType
    # Scheduler prediction type reported by the probe class.
    prediction_type: SchedulerPredictionType
    # ModelProbe.probe sets this for SD-2 models that use v-prediction.
    upcast_attention: bool
    # Storage format string returned by the probe class's get_format().
    format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
    # Image size chosen by ModelProbe.probe (1024, 768 or 512).
    image_size: int
|
|
|
|
|
|
class ProbeBase:
    """Forward declaration so ModelProbe can name ProbeBase in annotations.

    The class is defined again, with its methods, further down in this module.
    """
|
|
|
|
|
|
class ModelProbe(object):
    """
    Work out what kind of model a path or in-memory object is.

    Probing is a two-step process: the model type is first guessed from the
    folder layout or the checkpoint's keys, then the probe class registered
    for that (format, model type) pair is asked for the base model, variant,
    scheduler prediction type and storage format.
    """

    # Registered probe classes, keyed by format and then by ModelType.
    # Filled in through register_probe().
    PROBES = {
        "diffusers": {},
        "checkpoint": {},
        "onnx": {},
    }

    # Maps the class name found in a diffusers-style config (or the class of
    # an in-memory model) to the corresponding ModelType.
    CLASS2TYPE = {
        "StableDiffusionPipeline": ModelType.Main,
        "StableDiffusionInpaintPipeline": ModelType.Main,
        "StableDiffusionXLPipeline": ModelType.Main,
        "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
        "StableDiffusionXLInpaintPipeline": ModelType.Main,
        "AutoencoderKL": ModelType.Vae,
        "AutoencoderTiny": ModelType.Vae,
        "ControlNetModel": ModelType.ControlNet,
        "CLIPVisionModelWithProjection": ModelType.CLIPVision,
        "T2IAdapter": ModelType.T2IAdapter,
    }

    @classmethod
    def register_probe(
        cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase
    ):
        """Register probe_class as the probe to use for model_type stored in the given format."""
        cls.PROBES[format][model_type] = probe_class

    @classmethod
    def heuristic_probe(
        cls,
        model: Union[Dict, ModelMixin, Path],
        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
    ) -> Optional[ModelProbeInfo]:
        """
        Probe either a path or an already-loaded model / state dict.

        :param model: A Path to the model, or a dict / ModelMixin / ConfigMixin instance.
        :param prediction_type_helper: Optional callable forwarded to probe().
        :raises InvalidModelException: If model is none of the accepted kinds.
        """
        if isinstance(model, Path):
            return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper)
        elif isinstance(model, (dict, ModelMixin, ConfigMixin)):
            return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
        else:
            raise InvalidModelException(f"model parameter {model} is neither a Path, nor a model")

    @classmethod
    def probe(
        cls,
        model_path: Optional[Path],
        model: Optional[Union[Dict, ModelMixin]] = None,
        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
    ) -> Optional[ModelProbeInfo]:
        """
        Probe the model at model_path and return sufficient information about it
        to place it somewhere in the models directory hierarchy. If the model is
        already loaded into memory, you may provide it as model in order to avoid
        opening it a second time. The prediction_type_helper callable is a function that receives
        the path to the model and returns the SchedulerPredictionType.

        Returns None when no probe class is registered for the detected
        format and model type. Exceptions raised while probing propagate
        to the caller.
        """
        if model_path:
            format_type = "diffusers" if model_path.is_dir() else "checkpoint"
        else:
            format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint"

        if format_type == "diffusers":
            model_type = cls.get_model_type_from_folder(model_path, model)
        else:
            model_type = cls.get_model_type_from_checkpoint(model_path, model)

        # ONNX models are stored as folders but have their own probe registry.
        if model_type == ModelType.ONNX:
            format_type = "onnx"

        probe_class = cls.PROBES[format_type].get(model_type)
        if not probe_class:
            return None

        probe = probe_class(model_path, model, prediction_type_helper)
        base_type = probe.get_base_type()
        variant_type = probe.get_variant_type()
        prediction_type = probe.get_scheduler_prediction_type()
        model_format = probe.get_format()

        is_sd2_v_prediction = (
            base_type == BaseModelType.StableDiffusion2 and prediction_type == SchedulerPredictionType.VPrediction
        )
        if base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}:
            image_size = 1024
        elif is_sd2_v_prediction:
            image_size = 768
        else:
            image_size = 512

        return ModelProbeInfo(
            model_type=model_type,
            base_type=base_type,
            variant_type=variant_type,
            prediction_type=prediction_type,
            upcast_attention=is_sd2_v_prediction,
            format=model_format,
            image_size=image_size,
        )

    @classmethod
    def get_model_type_from_checkpoint(
        cls, model_path: Optional[Path], checkpoint: Optional[dict]
    ) -> Optional[ModelType]:
        """
        Guess the model type of a single-file checkpoint from its name and keys.

        :param model_path: Path of the checkpoint; may be None when an
            in-memory checkpoint is supplied instead.
        :param checkpoint: Already-loaded checkpoint dict, or a falsy value to
            read the metadata from model_path.
        :return: The detected ModelType, or None if the file suffix is not a
            recognised checkpoint suffix.
        :raises InvalidModelException: If the type cannot be determined.
        """
        # model_path is None when heuristic_probe() is handed a state dict;
        # the file-name based shortcuts simply do not apply in that case.
        if model_path is not None:
            if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
                return None

            if model_path.name == "learned_embeds.bin":
                return ModelType.TextualInversion

        if not checkpoint and model_path is None:
            raise InvalidModelException("Unable to determine model type: no model path or checkpoint was provided")

        ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
        ckpt = ckpt.get("state_dict", ckpt)

        for key in ckpt.keys():
            if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")):
                return ModelType.Main
            elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
                return ModelType.Vae
            elif key.startswith(("lora_te_", "lora_unet_")):
                return ModelType.Lora
            elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight")):
                return ModelType.Lora
            elif key.startswith(("control_model", "input_blocks")):
                return ModelType.ControlNet
            elif key in {"emb_params", "string_to_param"}:
                return ModelType.TextualInversion

        # diffusers-ti: a small dict whose values are all tensors
        if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
            return ModelType.TextualInversion

        raise InvalidModelException(f"Unable to determine model type for {model_path}")

    @classmethod
    def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType:
        """
        Get the model type of a hugging-face style folder.

        If model is provided its class name is used; otherwise well-known
        marker files are checked first and then the "_class_name" or
        "architectures" entry of model_index.json / config.json.

        :raises InvalidModelException: If the type cannot be determined.
        """
        class_name = None
        error_hint = None
        if model:
            class_name = model.__class__.__name__
        else:
            if (folder_path / "unet/model.onnx").exists():
                return ModelType.ONNX
            if (folder_path / "learned_embeds.bin").exists():
                return ModelType.TextualInversion
            if (folder_path / "pytorch_lora_weights.bin").exists():
                return ModelType.Lora
            if (folder_path / "image_encoder.txt").exists():
                return ModelType.IPAdapter

            index_file = folder_path / "model_index.json"
            config_file = folder_path / "config.json"
            config_path = index_file if index_file.exists() else config_file if config_file.exists() else None

            if config_path:
                with open(config_path, "r", encoding="utf-8") as file:
                    conf = json.load(file)
                if "_class_name" in conf:
                    class_name = conf["_class_name"]
                elif "architectures" in conf:
                    class_name = conf["architectures"][0]
            else:
                error_hint = f"No model_index.json or config.json found in {folder_path}."

        if class_name and (model_type := cls.CLASS2TYPE.get(class_name)):
            return model_type

        # Keep the more specific "no config file" hint if one was recorded above.
        if error_hint is None:
            error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"

        # give up
        raise InvalidModelException(f"Unable to determine model type for {folder_path}; {error_hint}")

    @classmethod
    def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
        """
        Load a checkpoint file, scanning pickle-based formats first.

        Files whose suffix ends in .ckpt, .pt, .pth or .bin are passed through
        _scan_model() and then loaded with torch.load(); anything else is
        loaded with safetensors.
        """
        with SilenceWarnings():
            if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
                # NOTE: torch.load() unpickles the file; the picklescan check
                # below is the only safeguard against malicious payloads.
                cls._scan_model(model_path, model_path)
                return torch.load(model_path)
            else:
                return safetensors.torch.load_file(model_path)

    @classmethod
    def _scan_model(cls, model_name, checkpoint):
        """
        Apply picklescanner to the indicated checkpoint.

        :param model_name: Name used in the error message.
        :param checkpoint: Path of the file to scan.
        :raises InvalidModelException: If the scan reports infected files.
        """
        # scan model
        scan_result = scan_file_path(checkpoint)
        if scan_result.infected_files != 0:
            raise InvalidModelException(
                f"The model {model_name} is potentially infected by malware. Aborting import."
            )
|
|
|
|
|
|
# ##################################################3
|
|
# Checkpoint probing
|
|
# ##################################################3
|
|
class ProbeBase(object):
    """Common interface of all probe classes.

    Every method here is a placeholder that returns None; the checkpoint and
    folder probe classes in this module override the ones they support.
    """

    def get_base_type(self) -> BaseModelType:
        """Return the base model type; this default yields None."""
        return None

    def get_variant_type(self) -> ModelVariantType:
        """Return the model variant; this default yields None."""
        return None

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        """Return the scheduler prediction type; this default yields None."""
        return None

    def get_format(self) -> str:
        """Return the storage format name; this default yields None."""
        return None
|
|
|
|
|
|
class CheckpointProbeBase(ProbeBase):
    """Base class for probes that inspect a single-file checkpoint."""

    def __init__(
        self,
        checkpoint_path: Path,
        checkpoint: dict,
        helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
    ) -> None:
        """
        :param checkpoint_path: Path of the checkpoint file.
        :param checkpoint: Already-loaded checkpoint; when falsy the file at
            checkpoint_path is scanned and loaded instead.
        :param helper: Optional callable that subclasses may call with the
            checkpoint path when they cannot decide on their own.
        """
        self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
        self.checkpoint_path = checkpoint_path
        self.helper = helper

    def get_base_type(self) -> BaseModelType:
        """Placeholder; concrete checkpoint probes override this."""
        pass

    def get_format(self) -> str:
        """Return the format name used for single-file checkpoints."""
        return "checkpoint"

    def get_variant_type(self) -> ModelVariantType:
        """
        Derive the variant from the UNet's first conv layer.

        Anything that is not a main model is reported as Normal. For main
        models the second dimension of the first input block's weight
        decides: 9 -> Inpaint, 5 -> Depth, 4 -> Normal.

        :raises InvalidModelException: For any other channel count.
        """
        model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint)
        if model_type != ModelType.Main:
            return ModelVariantType.Normal
        state_dict = self.checkpoint.get("state_dict") or self.checkpoint
        in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
        if in_channels == 9:
            return ModelVariantType.Inpaint
        elif in_channels == 5:
            return ModelVariantType.Depth
        elif in_channels == 4:
            return ModelVariantType.Normal
        else:
            raise InvalidModelException(
                f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}"
            )
|
|
|
|
|
|
class PipelineCheckpointProbe(CheckpointProbeBase):
    """Probe for full Stable Diffusion pipelines stored as one checkpoint file."""

    def _state_dict(self):
        """Return the nested "state_dict" if present and truthy, else the checkpoint itself."""
        return self.checkpoint.get("state_dict") or self.checkpoint

    def _helper_guess(self):
        """Ask the user-supplied helper, if there is one and a path to give it."""
        if self.helper and self.checkpoint_path:
            return self.helper(self.checkpoint_path)
        return None

    def get_base_type(self) -> BaseModelType:
        """
        Identify the base model from the width of a cross-attention weight.

        :raises InvalidModelException: If neither probed key gives a known width.
        """
        weights = self._state_dict()

        # SD-1 and SD-2 are told apart by the last dimension of this weight.
        sd_key = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if sd_key in weights:
            width = weights[sd_key].shape[-1]
            if width == 768:
                return BaseModelType.StableDiffusion1
            if width == 1024:
                return BaseModelType.StableDiffusion2

        # SDXL and the SDXL refiner are told apart by this weight instead.
        xl_key = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
        if xl_key in weights:
            width = weights[xl_key].shape[-1]
            if width == 2048:
                return BaseModelType.StableDiffusionXL
            if width == 1280:
                return BaseModelType.StableDiffusionXLRefiner

        raise InvalidModelException("Cannot determine base type")

    def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
        """Return model prediction type."""
        # A .yaml next to the checkpoint makes the probed prediction type
        # irrelevant (it will be ignored), so skip the work.
        if self.checkpoint_path and self.checkpoint_path.with_suffix(".yaml").exists():
            return None

        base = self.get_base_type()

        if base == BaseModelType.StableDiffusion1:
            guess = self._helper_guess()
            if guess:
                return guess
            return SchedulerPredictionType.Epsilon  # a reasonable guess for sd1 ckpts

        if base != BaseModelType.StableDiffusion2:
            return None

        weights = self._state_dict()
        sd_key = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if not (sd_key in weights and weights[sd_key].shape[-1] == 1024):
            return None

        # Two specific global_step values are mapped to a prediction type directly.
        if "global_step" in self.checkpoint:
            step = self.checkpoint["global_step"]
            if step == 220000:
                return SchedulerPredictionType.Epsilon
            if step == 110000:
                return SchedulerPredictionType.VPrediction

        guess = self._helper_guess()
        if guess:
            return guess
        return SchedulerPredictionType.VPrediction  # a guess for sd2 ckpts
|
|
|
|
|
|
class VaeCheckpointProbe(CheckpointProbeBase):
    """Probe for standalone VAE checkpoints."""

    def get_base_type(self) -> BaseModelType:
        """Always report SD-1; the checkpoint contents are not inspected."""
        # No standalone 2.X VAEs were available to test with, hence the fixed answer.
        return BaseModelType.StableDiffusion1
|
|
|
|
|
|
class LoRACheckpointProbe(CheckpointProbeBase):
    """Probe for LoRA checkpoints."""

    def get_format(self) -> str:
        """Return the format name reported for LoRA checkpoints."""
        return "lycoris"

    def get_base_type(self) -> BaseModelType:
        """
        Map the LoRA's token vector length onto a base model.

        :raises InvalidModelException: If the length is not 768, 1024 or 2048.
        """
        base_by_length = {
            768: BaseModelType.StableDiffusion1,
            1024: BaseModelType.StableDiffusion2,
            2048: BaseModelType.StableDiffusionXL,
        }
        length = lora_token_vector_length(self.checkpoint)
        if length in base_by_length:
            return base_by_length[length]
        raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}")
|
|
|
|
|
|
class TextualInversionCheckpointProbe(CheckpointProbeBase):
    """Probe for textual-inversion embedding files."""

    def get_format(self) -> Optional[str]:
        """Textual inversions report no format."""
        return None

    def get_base_type(self) -> Optional[BaseModelType]:
        """
        Infer the base model from the embedding's token dimension.

        Three layouts are handled: a "string_to_param" mapping, an
        "emb_params" tensor, or a flat dict whose first value is the
        embedding tensor.

        :return: SD-1 for 768, SD-2 for 1024, otherwise None.
        """
        checkpoint = self.checkpoint
        # Test for the key that is actually read below, so a file that has
        # "string_to_param" is handled regardless of "string_to_token".
        if "string_to_param" in checkpoint:
            token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
        elif "emb_params" in checkpoint:
            token_dim = checkpoint["emb_params"].shape[-1]
        else:
            token_dim = list(checkpoint.values())[0].shape[0]
        if token_dim == 768:
            return BaseModelType.StableDiffusion1
        elif token_dim == 1024:
            return BaseModelType.StableDiffusion2
        else:
            return None
|
|
|
|
|
|
class ControlNetCheckpointProbe(CheckpointProbeBase):
    """Probe for ControlNet checkpoints."""

    def get_base_type(self) -> BaseModelType:
        """
        Infer the base model from the width of a cross-attention weight.

        Two possible key names are tried. For the first one present, a width
        of 768 means SD-1 and 1024 means SD-2; for any other width the helper
        is consulted if both a helper and a checkpoint path are available.

        :raises InvalidModelException: If no answer could be found.
        """
        checkpoint = self.checkpoint
        for key_name in (
            "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
            "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
        ):
            if key_name not in checkpoint:
                continue
            if checkpoint[key_name].shape[-1] == 768:
                return BaseModelType.StableDiffusion1
            elif checkpoint[key_name].shape[-1] == 1024:
                return BaseModelType.StableDiffusion2
            elif self.checkpoint_path and self.helper:
                return self.helper(self.checkpoint_path)
        raise InvalidModelException(f"Unable to determine base type for {self.checkpoint_path}")
|
|
|
|
|
|
class IPAdapterCheckpointProbe(CheckpointProbeBase):
    """Placeholder probe: single-file IP-Adapter checkpoints are not handled."""

    def get_base_type(self) -> BaseModelType:
        """Always raises NotImplementedError."""
        raise NotImplementedError
|
|
|
|
|
|
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
    """Placeholder probe: single-file CLIP Vision checkpoints are not handled."""

    def get_base_type(self) -> BaseModelType:
        """Always raises NotImplementedError."""
        raise NotImplementedError
|
|
|
|
|
|
class T2IAdapterCheckpointProbe(CheckpointProbeBase):
    """Placeholder probe: single-file T2I-Adapter checkpoints are not handled."""

    def get_base_type(self) -> BaseModelType:
        """Always raises NotImplementedError."""
        raise NotImplementedError
|
|
|
|
|
|
########################################################
|
|
# classes for probing folders
|
|
#######################################################
|
|
class FolderProbeBase(ProbeBase):
    """Base class for probes that inspect a diffusers-style model folder."""

    def __init__(self, folder_path: Path, model: ModelMixin = None, helper: Callable = None):  # not used
        # The helper argument is accepted so all probes share one constructor
        # signature, but it is not stored.
        self.folder_path = folder_path
        self.model = model

    def get_variant_type(self) -> ModelVariantType:
        """Folder probes report the Normal variant unless a subclass overrides this."""
        return ModelVariantType.Normal

    def get_format(self) -> str:
        """Return the format name used for folder models."""
        return "diffusers"
|
|
|
|
|
|
class PipelineFolderProbe(FolderProbeBase):
    """Probe for full pipelines stored as a folder, or already loaded in memory."""

    def _unet_config(self):
        """Return the UNet config of the loaded model, or read unet/config.json from the folder."""
        if self.model:
            return self.model.unet.config
        with open(self.folder_path / "unet" / "config.json", "r") as file:
            return json.load(file)

    def get_base_type(self) -> BaseModelType:
        """
        Map the UNet's cross_attention_dim onto a base model.

        :raises InvalidModelException: If the dimension is not a known one.
        """
        dim = self._unet_config()["cross_attention_dim"]
        if dim == 768:
            return BaseModelType.StableDiffusion1
        if dim == 1024:
            return BaseModelType.StableDiffusion2
        if dim == 1280:
            return BaseModelType.StableDiffusionXLRefiner
        if dim == 2048:
            return BaseModelType.StableDiffusionXL
        raise InvalidModelException(f"Unknown base model for {self.folder_path}")

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        """Translate the scheduler config's prediction_type; unknown values give None."""
        if self.model:
            scheduler_conf = self.model.scheduler.config
        else:
            scheduler_file = self.folder_path / "scheduler" / "scheduler_config.json"
            with open(scheduler_file, "r") as file:
                scheduler_conf = json.load(file)

        declared = scheduler_conf["prediction_type"]
        if declared == "v_prediction":
            return SchedulerPredictionType.VPrediction
        if declared == "epsilon":
            return SchedulerPredictionType.Epsilon
        return None

    def get_variant_type(self) -> ModelVariantType:
        """Derive the variant from the UNet's in_channels, defaulting to Normal."""
        # This only works for pipelines! Deliberately best-effort: any kind
        # of exception while reading the config results in "normal".
        try:
            in_channels = self._unet_config()["in_channels"]
            if in_channels == 9:
                return ModelVariantType.Inpaint
            if in_channels == 5:
                return ModelVariantType.Depth
        except Exception:
            pass
        # 4 channels, an unexpected count, or an unreadable config.
        return ModelVariantType.Normal
|
|
|
|
|
|
class VaeFolderProbe(FolderProbeBase):
    """Probe for VAEs stored as a diffusers folder."""

    def get_base_type(self) -> BaseModelType:
        """Report SDXL if the config or the folder name suggests it, else SD-1."""
        # SD and SDXL VAEs have the same shape (3-channel RGB to 4-channel
        # float scaled down by a factor of 8), so config hyperparameters
        # cannot always tell them apart; the name is used as a fallback.
        if self._config_looks_like_sdxl() or self._name_looks_like_sdxl():
            return BaseModelType.StableDiffusionXL
        return BaseModelType.StableDiffusion1

    def _config_looks_like_sdxl(self) -> bool:
        """
        Check config.json for the values that distinguish Stability's SD 1.x VAE from their SDXL VAE.

        :raises InvalidModelException: If the folder has no config.json.
        """
        config_file = self.folder_path / "config.json"
        if not config_file.exists():
            raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
        with open(config_file, "r") as file:
            config = json.load(file)
        has_xl_scaling = config.get("scaling_factor", 0) == 0.13025
        return has_xl_scaling and config.get("sample_size") in [512, 1024]

    def _name_looks_like_sdxl(self) -> bool:
        """True if the guessed name contains "xl" followed by a word boundary (case-insensitive)."""
        match = re.search(r"xl\b", self._guess_name(), re.IGNORECASE)
        return bool(match)

    def _guess_name(self) -> str:
        """Return the folder name, or the parent folder's name when the folder is called "vae"."""
        folder = self.folder_path
        return folder.parent.name if folder.name == "vae" else folder.name
|
|
|
|
|
|
class TextualInversionFolderProbe(FolderProbeBase):
    """Probe for textual inversions stored as a folder containing learned_embeds.bin."""

    def get_format(self) -> str:
        """Textual inversions report no format."""
        return None

    def get_base_type(self) -> BaseModelType:
        """Load learned_embeds.bin and delegate to the checkpoint probe; None if the file is absent."""
        embeds_file = self.folder_path / "learned_embeds.bin"
        if embeds_file.exists():
            loaded = ModelProbe._scan_and_load_checkpoint(embeds_file)
            return TextualInversionCheckpointProbe(None, checkpoint=loaded).get_base_type()
        return None
|
|
|
|
|
|
class ONNXFolderProbe(FolderProbeBase):
    """Probe for ONNX model folders; every answer is a fixed value."""

    def get_format(self) -> str:
        """Return the format name used for ONNX models."""
        return "onnx"

    def get_base_type(self) -> BaseModelType:
        """Always report SD-1; the folder contents are not inspected."""
        return BaseModelType.StableDiffusion1

    def get_variant_type(self) -> ModelVariantType:
        """Always report the Normal variant."""
        return ModelVariantType.Normal
|
|
|
|
|
|
class ControlNetFolderProbe(FolderProbeBase):
    """Probe for ControlNet models stored as a diffusers folder."""

    def get_base_type(self) -> BaseModelType:
        """
        Map config.json's cross_attention_dim onto a base model.

        :raises InvalidModelException: If config.json is missing or the
            dimension is not 768, 1024 or 2048.
        """
        config_file = self.folder_path / "config.json"
        if not config_file.exists():
            raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
        with open(config_file, "r") as file:
            config = json.load(file)

        # no obvious way to distinguish between sd2-base and sd2-768
        dimension = config["cross_attention_dim"]
        if dimension == 768:
            base_model = BaseModelType.StableDiffusion1
        elif dimension == 1024:
            base_model = BaseModelType.StableDiffusion2
        elif dimension == 2048:
            base_model = BaseModelType.StableDiffusionXL
        else:
            base_model = None

        if not base_model:
            raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
        return base_model
|
|
|
|
|
|
class LoRAFolderProbe(FolderProbeBase):
    """Probe for LoRAs stored as a folder containing pytorch_lora_weights.*."""

    def get_base_type(self) -> BaseModelType:
        """
        Find the weights file and delegate to the LoRA checkpoint probe.

        The .safetensors file is preferred over the .bin file.

        :raises InvalidModelException: If neither file exists.
        """
        candidates = (self.folder_path / f"pytorch_lora_weights.{suffix}" for suffix in ["safetensors", "bin"])
        model_file = next((path for path in candidates if path.exists()), None)
        if not model_file:
            raise InvalidModelException("Unknown LoRA format encountered")
        return LoRACheckpointProbe(model_file, None).get_base_type()
|
|
|
|
|
|
class IPAdapterFolderProbe(FolderProbeBase):
    """Probe for IP-Adapter models stored as a folder containing ip_adapter.bin."""

    def get_format(self) -> str:
        """Return the InvokeAI IP-Adapter format value."""
        return IPAdapterModelFormat.InvokeAI.value

    def get_base_type(self) -> BaseModelType:
        """
        Map the cross-attention width of the adapter weights onto a base model.

        :raises InvalidModelException: If ip_adapter.bin is missing or the
            width is not 768, 1024 or 2048.
        """
        model_file = self.folder_path / "ip_adapter.bin"
        if not model_file.exists():
            raise InvalidModelException("Unknown IP-Adapter model format.")

        # NOTE(review): torch.load() unpickles this file and no picklescan
        # check runs on this path - confirm that is intended.
        state_dict = torch.load(model_file, map_location="cpu")
        cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]

        base_by_dim = {
            768: BaseModelType.StableDiffusion1,
            1024: BaseModelType.StableDiffusion2,
            2048: BaseModelType.StableDiffusionXL,
        }
        if cross_attention_dim in base_by_dim:
            return base_by_dim[cross_attention_dim]
        raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")
|
|
|
|
|
|
class CLIPVisionFolderProbe(FolderProbeBase):
    """Probe for CLIP Vision models stored as a folder."""

    def get_base_type(self) -> BaseModelType:
        """Always report BaseModelType.Any; the folder is not inspected."""
        return BaseModelType.Any
|
|
|
|
|
|
class T2IAdapterFolderProbe(FolderProbeBase):
    """Probe for T2I-Adapter models stored as a diffusers folder."""

    def get_base_type(self) -> BaseModelType:
        """
        Map config.json's adapter_type onto a base model.

        :raises InvalidModelException: If config.json is missing or the
            adapter_type is not a recognised one.
        """
        config_file = self.folder_path / "config.json"
        if not config_file.exists():
            raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
        with open(config_file, "r") as file:
            config = json.load(file)

        adapter_type = config.get("adapter_type", None)
        if adapter_type == "full_adapter_xl":
            return BaseModelType.StableDiffusionXL
        # A membership test is required here: `x == "a" or "b"` is always
        # truthy and would accept every adapter_type.
        elif adapter_type in ("full_adapter", "light_adapter"):
            # I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter.
            return BaseModelType.StableDiffusion1
        else:
            raise InvalidModelException(
                f"Unable to determine base model for '{self.folder_path}' (adapter_type = {adapter_type})."
            )
|
|
)
|
|
|
|
|
|
############## register probe classes ######
|
|
# Table of (format, model type, probe class), registered in this order.
for _probe_format, _probe_model_type, _probe_class in (
    ("diffusers", ModelType.Main, PipelineFolderProbe),
    ("diffusers", ModelType.Vae, VaeFolderProbe),
    ("diffusers", ModelType.Lora, LoRAFolderProbe),
    ("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe),
    ("diffusers", ModelType.ControlNet, ControlNetFolderProbe),
    ("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe),
    ("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe),
    ("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe),
    ("checkpoint", ModelType.Main, PipelineCheckpointProbe),
    ("checkpoint", ModelType.Vae, VaeCheckpointProbe),
    ("checkpoint", ModelType.Lora, LoRACheckpointProbe),
    ("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe),
    ("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe),
    ("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe),
    ("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe),
    ("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe),
    ("onnx", ModelType.ONNX, ONNXFolderProbe),
):
    ModelProbe.register_probe(_probe_format, _probe_model_type, _probe_class)

# Do not leave the loop variables behind as module attributes.
del _probe_format, _probe_model_type, _probe_class
|