mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
515 lines
20 KiB
Python
515 lines
20 KiB
Python
import json
|
|
import torch
|
|
import safetensors.torch
|
|
|
|
from dataclasses import dataclass
|
|
|
|
from diffusers import ModelMixin, ConfigMixin
|
|
from pathlib import Path
|
|
from typing import Callable, Literal, Union, Dict, Optional
|
|
from picklescan.scanner import scan_file_path
|
|
|
|
from .models import (
|
|
BaseModelType,
|
|
ModelType,
|
|
ModelVariantType,
|
|
SchedulerPredictionType,
|
|
SilenceWarnings,
|
|
InvalidModelException,
|
|
)
|
|
from .models.base import read_checkpoint_meta
|
|
|
|
|
|
@dataclass
|
|
class ModelProbeInfo(object):
|
|
model_type: ModelType
|
|
base_type: BaseModelType
|
|
variant_type: ModelVariantType
|
|
prediction_type: SchedulerPredictionType
|
|
upcast_attention: bool
|
|
format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
|
|
image_size: int
|
|
|
|
|
|
class ProbeBase(object):
|
|
"""forward declaration"""
|
|
|
|
pass
|
|
|
|
|
|
class ModelProbe(object):
|
|
PROBES = {
|
|
"diffusers": {},
|
|
"checkpoint": {},
|
|
"onnx": {},
|
|
}
|
|
|
|
CLASS2TYPE = {
|
|
"StableDiffusionPipeline": ModelType.Main,
|
|
"StableDiffusionInpaintPipeline": ModelType.Main,
|
|
"StableDiffusionXLPipeline": ModelType.Main,
|
|
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
|
"AutoencoderKL": ModelType.Vae,
|
|
"ControlNetModel": ModelType.ControlNet,
|
|
}
|
|
|
|
@classmethod
|
|
def register_probe(cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase):
|
|
cls.PROBES[format][model_type] = probe_class
|
|
|
|
@classmethod
|
|
def heuristic_probe(
|
|
cls,
|
|
model: Union[Dict, ModelMixin, Path],
|
|
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
|
|
) -> ModelProbeInfo:
|
|
if isinstance(model, Path):
|
|
return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper)
|
|
elif isinstance(model, (dict, ModelMixin, ConfigMixin)):
|
|
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
|
|
else:
|
|
raise InvalidModelException("model parameter {model} is neither a Path, nor a model")
|
|
|
|
@classmethod
|
|
def probe(
|
|
cls,
|
|
model_path: Path,
|
|
model: Optional[Union[Dict, ModelMixin]] = None,
|
|
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
|
) -> ModelProbeInfo:
|
|
"""
|
|
Probe the model at model_path and return sufficient information about it
|
|
to place it somewhere in the models directory hierarchy. If the model is
|
|
already loaded into memory, you may provide it as model in order to avoid
|
|
opening it a second time. The prediction_type_helper callable is a function that receives
|
|
the path to the model and returns the BaseModelType. It is called to distinguish
|
|
between V2-Base and V2-768 SD models.
|
|
"""
|
|
if model_path:
|
|
format_type = "diffusers" if model_path.is_dir() else "checkpoint"
|
|
else:
|
|
format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint"
|
|
model_info = None
|
|
try:
|
|
model_type = (
|
|
cls.get_model_type_from_folder(model_path, model)
|
|
if format_type == "diffusers"
|
|
else cls.get_model_type_from_checkpoint(model_path, model)
|
|
)
|
|
format_type = 'onnx' if model_type == ModelType.ONNX else format_type
|
|
probe_class = cls.PROBES[format_type].get(model_type)
|
|
if not probe_class:
|
|
return None
|
|
probe = probe_class(model_path, model, prediction_type_helper)
|
|
base_type = probe.get_base_type()
|
|
variant_type = probe.get_variant_type()
|
|
prediction_type = probe.get_scheduler_prediction_type()
|
|
format = probe.get_format()
|
|
model_info = ModelProbeInfo(
|
|
model_type=model_type,
|
|
base_type=base_type,
|
|
variant_type=variant_type,
|
|
prediction_type=prediction_type,
|
|
upcast_attention=(
|
|
base_type == BaseModelType.StableDiffusion2
|
|
and prediction_type == SchedulerPredictionType.VPrediction
|
|
),
|
|
format=format,
|
|
image_size=1024
|
|
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
|
|
else 768
|
|
if (
|
|
base_type == BaseModelType.StableDiffusion2
|
|
and prediction_type == SchedulerPredictionType.VPrediction
|
|
)
|
|
else 512,
|
|
)
|
|
except Exception:
|
|
raise
|
|
|
|
return model_info
|
|
|
|
@classmethod
|
|
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
|
|
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
|
|
return None
|
|
|
|
if model_path.name == "learned_embeds.bin":
|
|
return ModelType.TextualInversion
|
|
|
|
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
|
|
ckpt = ckpt.get("state_dict", ckpt)
|
|
|
|
for key in ckpt.keys():
|
|
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
|
|
return ModelType.Main
|
|
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
|
|
return ModelType.Vae
|
|
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
|
return ModelType.Lora
|
|
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
|
return ModelType.Lora
|
|
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
|
return ModelType.ControlNet
|
|
elif key in {"emb_params", "string_to_param"}:
|
|
return ModelType.TextualInversion
|
|
|
|
else:
|
|
# diffusers-ti
|
|
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
|
return ModelType.TextualInversion
|
|
|
|
raise InvalidModelException(f"Unable to determine model type for {model_path}")
|
|
|
|
@classmethod
|
|
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType:
|
|
"""
|
|
Get the model type of a hugging-face style folder.
|
|
"""
|
|
class_name = None
|
|
if model:
|
|
class_name = model.__class__.__name__
|
|
else:
|
|
if (folder_path / 'unet/model.onnx').exists():
|
|
return ModelType.ONNX
|
|
if (folder_path / "learned_embeds.bin").exists():
|
|
return ModelType.TextualInversion
|
|
|
|
if (folder_path / "pytorch_lora_weights.bin").exists():
|
|
return ModelType.Lora
|
|
|
|
i = folder_path / "model_index.json"
|
|
c = folder_path / "config.json"
|
|
config_path = i if i.exists() else c if c.exists() else None
|
|
|
|
if config_path:
|
|
with open(config_path, "r") as file:
|
|
conf = json.load(file)
|
|
class_name = conf["_class_name"]
|
|
|
|
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
|
|
return type
|
|
|
|
# give up
|
|
raise InvalidModelException(f"Unable to determine model type for {folder_path}")
|
|
|
|
@classmethod
|
|
def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
|
|
with SilenceWarnings():
|
|
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
|
|
cls._scan_model(model_path, model_path)
|
|
return torch.load(model_path)
|
|
else:
|
|
return safetensors.torch.load_file(model_path)
|
|
|
|
@classmethod
|
|
def _scan_model(cls, model_name, checkpoint):
|
|
"""
|
|
Apply picklescanner to the indicated checkpoint and issue a warning
|
|
and option to exit if an infected file is identified.
|
|
"""
|
|
# scan model
|
|
scan_result = scan_file_path(checkpoint)
|
|
if scan_result.infected_files != 0:
|
|
raise "The model {model_name} is potentially infected by malware. Aborting import."
|
|
|
|
|
|
###################################################3
|
|
# Checkpoint probing
|
|
###################################################3
|
|
class ProbeBase(object):
|
|
def get_base_type(self) -> BaseModelType:
|
|
pass
|
|
|
|
def get_variant_type(self) -> ModelVariantType:
|
|
pass
|
|
|
|
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
|
pass
|
|
|
|
def get_format(self) -> str:
|
|
pass
|
|
|
|
|
|
class CheckpointProbeBase(ProbeBase):
|
|
def __init__(
|
|
self, checkpoint_path: Path, checkpoint: dict, helper: Callable[[Path], SchedulerPredictionType] = None
|
|
) -> BaseModelType:
|
|
self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
|
|
self.checkpoint_path = checkpoint_path
|
|
self.helper = helper
|
|
|
|
def get_base_type(self) -> BaseModelType:
|
|
pass
|
|
|
|
def get_format(self) -> str:
|
|
return "checkpoint"
|
|
|
|
def get_variant_type(self) -> ModelVariantType:
|
|
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint)
|
|
if model_type != ModelType.Main:
|
|
return ModelVariantType.Normal
|
|
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
|
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
|
if in_channels == 9:
|
|
return ModelVariantType.Inpaint
|
|
elif in_channels == 5:
|
|
return ModelVariantType.Depth
|
|
elif in_channels == 4:
|
|
return ModelVariantType.Normal
|
|
else:
|
|
raise InvalidModelException(
|
|
f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}"
|
|
)
|
|
|
|
|
|
class PipelineCheckpointProbe(CheckpointProbeBase):
|
|
def get_base_type(self) -> BaseModelType:
|
|
checkpoint = self.checkpoint
|
|
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
|
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
|
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
|
|
return BaseModelType.StableDiffusion1
|
|
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
|
return BaseModelType.StableDiffusion2
|
|
key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
|
|
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
|
|
return BaseModelType.StableDiffusionXL
|
|
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
|
|
return BaseModelType.StableDiffusionXLRefiner
|
|
else:
|
|
raise InvalidModelException("Cannot determine base type")
|
|
|
|
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
|
type = self.get_base_type()
|
|
if type == BaseModelType.StableDiffusion1:
|
|
return SchedulerPredictionType.Epsilon
|
|
checkpoint = self.checkpoint
|
|
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
|
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
|
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
|
if "global_step" in checkpoint:
|
|
if checkpoint["global_step"] == 220000:
|
|
return SchedulerPredictionType.Epsilon
|
|
elif checkpoint["global_step"] == 110000:
|
|
return SchedulerPredictionType.VPrediction
|
|
if (
|
|
self.checkpoint_path and self.helper and not self.checkpoint_path.with_suffix(".yaml").exists()
|
|
): # if a .yaml config file exists, then this step not needed
|
|
return self.helper(self.checkpoint_path)
|
|
else:
|
|
return None
|
|
|
|
|
|
class VaeCheckpointProbe(CheckpointProbeBase):
|
|
def get_base_type(self) -> BaseModelType:
|
|
# I can't find any standalone 2.X VAEs to test with!
|
|
return BaseModelType.StableDiffusion1
|
|
|
|
|
|
class LoRACheckpointProbe(CheckpointProbeBase):
|
|
def get_format(self) -> str:
|
|
return "lycoris"
|
|
|
|
def get_base_type(self) -> BaseModelType:
|
|
checkpoint = self.checkpoint
|
|
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
|
|
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
|
|
lora_token_vector_length = (
|
|
checkpoint[key1].shape[1]
|
|
if key1 in checkpoint
|
|
else checkpoint[key2].shape[0]
|
|
if key2 in checkpoint
|
|
else 768
|
|
)
|
|
if lora_token_vector_length == 768:
|
|
return BaseModelType.StableDiffusion1
|
|
elif lora_token_vector_length == 1024:
|
|
return BaseModelType.StableDiffusion2
|
|
else:
|
|
return None
|
|
|
|
|
|
class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
|
def get_format(self) -> str:
|
|
return None
|
|
|
|
def get_base_type(self) -> BaseModelType:
|
|
checkpoint = self.checkpoint
|
|
if "string_to_token" in checkpoint:
|
|
token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
|
|
elif "emb_params" in checkpoint:
|
|
token_dim = checkpoint["emb_params"].shape[-1]
|
|
else:
|
|
token_dim = list(checkpoint.values())[0].shape[0]
|
|
if token_dim == 768:
|
|
return BaseModelType.StableDiffusion1
|
|
elif token_dim == 1024:
|
|
return BaseModelType.StableDiffusion2
|
|
else:
|
|
return None
|
|
|
|
|
|
class ControlNetCheckpointProbe(CheckpointProbeBase):
|
|
def get_base_type(self) -> BaseModelType:
|
|
checkpoint = self.checkpoint
|
|
for key_name in (
|
|
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
|
"input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
|
):
|
|
if key_name not in checkpoint:
|
|
continue
|
|
if checkpoint[key_name].shape[-1] == 768:
|
|
return BaseModelType.StableDiffusion1
|
|
elif checkpoint[key_name].shape[-1] == 1024:
|
|
return BaseModelType.StableDiffusion2
|
|
elif self.checkpoint_path and self.helper:
|
|
return self.helper(self.checkpoint_path)
|
|
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
|
|
|
|
|
|
########################################################
|
|
# classes for probing folders
|
|
#######################################################
|
|
class FolderProbeBase(ProbeBase):
|
|
def __init__(self, folder_path: Path, model: ModelMixin = None, helper: Callable = None): # not used
|
|
self.model = model
|
|
self.folder_path = folder_path
|
|
|
|
def get_variant_type(self) -> ModelVariantType:
|
|
return ModelVariantType.Normal
|
|
|
|
def get_format(self) -> str:
|
|
return "diffusers"
|
|
|
|
|
|
class PipelineFolderProbe(FolderProbeBase):
|
|
def get_base_type(self) -> BaseModelType:
|
|
if self.model:
|
|
unet_conf = self.model.unet.config
|
|
else:
|
|
with open(self.folder_path / "unet" / "config.json", "r") as file:
|
|
unet_conf = json.load(file)
|
|
if unet_conf["cross_attention_dim"] == 768:
|
|
return BaseModelType.StableDiffusion1
|
|
elif unet_conf["cross_attention_dim"] == 1024:
|
|
return BaseModelType.StableDiffusion2
|
|
elif unet_conf["cross_attention_dim"] == 1280:
|
|
return BaseModelType.StableDiffusionXLRefiner
|
|
elif unet_conf["cross_attention_dim"] == 2048:
|
|
return BaseModelType.StableDiffusionXL
|
|
else:
|
|
raise InvalidModelException(f"Unknown base model for {self.folder_path}")
|
|
|
|
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
|
if self.model:
|
|
scheduler_conf = self.model.scheduler.config
|
|
else:
|
|
with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file:
|
|
scheduler_conf = json.load(file)
|
|
if scheduler_conf["prediction_type"] == "v_prediction":
|
|
return SchedulerPredictionType.VPrediction
|
|
elif scheduler_conf["prediction_type"] == "epsilon":
|
|
return SchedulerPredictionType.Epsilon
|
|
else:
|
|
return None
|
|
|
|
def get_variant_type(self) -> ModelVariantType:
|
|
# This only works for pipelines! Any kind of
|
|
# exception results in our returning the
|
|
# "normal" variant type
|
|
try:
|
|
if self.model:
|
|
conf = self.model.unet.config
|
|
else:
|
|
config_file = self.folder_path / "unet" / "config.json"
|
|
with open(config_file, "r") as file:
|
|
conf = json.load(file)
|
|
|
|
in_channels = conf["in_channels"]
|
|
if in_channels == 9:
|
|
return ModelVariantType.Inpaint
|
|
elif in_channels == 5:
|
|
return ModelVariantType.Depth
|
|
elif in_channels == 4:
|
|
return ModelVariantType.Normal
|
|
except:
|
|
pass
|
|
return ModelVariantType.Normal
|
|
|
|
|
|
class VaeFolderProbe(FolderProbeBase):
|
|
def get_base_type(self) -> BaseModelType:
|
|
config_file = self.folder_path / "config.json"
|
|
if not config_file.exists():
|
|
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
|
with open(config_file, "r") as file:
|
|
config = json.load(file)
|
|
return (
|
|
BaseModelType.StableDiffusionXL
|
|
if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
|
|
else BaseModelType.StableDiffusion1
|
|
)
|
|
|
|
|
|
class TextualInversionFolderProbe(FolderProbeBase):
|
|
def get_format(self) -> str:
|
|
return None
|
|
|
|
def get_base_type(self) -> BaseModelType:
|
|
path = self.folder_path / "learned_embeds.bin"
|
|
if not path.exists():
|
|
return None
|
|
checkpoint = ModelProbe._scan_and_load_checkpoint(path)
|
|
return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()
|
|
|
|
|
|
class ONNXFolderProbe(FolderProbeBase):
|
|
def get_format(self) -> str:
|
|
return "onnx"
|
|
|
|
def get_base_type(self) -> BaseModelType:
|
|
return BaseModelType.StableDiffusion1
|
|
|
|
def get_variant_type(self) -> ModelVariantType:
|
|
return ModelVariantType.Normal
|
|
|
|
class ControlNetFolderProbe(FolderProbeBase):
|
|
def get_base_type(self) -> BaseModelType:
|
|
config_file = self.folder_path / "config.json"
|
|
if not config_file.exists():
|
|
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
|
with open(config_file, "r") as file:
|
|
config = json.load(file)
|
|
# no obvious way to distinguish between sd2-base and sd2-768
|
|
return (
|
|
BaseModelType.StableDiffusion1 if config["cross_attention_dim"] == 768 else BaseModelType.StableDiffusion2
|
|
)
|
|
|
|
|
|
class LoRAFolderProbe(FolderProbeBase):
|
|
def get_base_type(self) -> BaseModelType:
|
|
model_file = None
|
|
for suffix in ["safetensors", "bin"]:
|
|
base_file = self.folder_path / f"pytorch_lora_weights.{suffix}"
|
|
if base_file.exists():
|
|
model_file = base_file
|
|
break
|
|
if not model_file:
|
|
raise InvalidModelException("Unknown LoRA format encountered")
|
|
return LoRACheckpointProbe(model_file, None).get_base_type()
|
|
|
|
|
|
############## register probe classes ######
|
|
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
|
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
|
|
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
|
|
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
|
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
|
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
|
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
|
|
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
|
|
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
|
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
|
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|