automatically convert models.yaml to new format
invokeai/backend/model_management/README (new file)
@@ -0,0 +1 @@
The contents of this directory are deprecated. model_manager.py is here only for reference.
(deleted file)
@@ -1,605 +0,0 @@
import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, Literal, Optional, Union

import safetensors.torch
import torch
from diffusers import ConfigMixin, ModelMixin
from picklescan.scanner import scan_file_path

from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat

from .models import (
    BaseModelType,
    InvalidModelException,
    ModelType,
    ModelVariantType,
    SchedulerPredictionType,
    SilenceWarnings,
)
from .models.base import read_checkpoint_meta
from .util import lora_token_vector_length

@dataclass
class ModelProbeInfo(object):
    model_type: ModelType
    base_type: BaseModelType
    variant_type: ModelVariantType
    prediction_type: SchedulerPredictionType
    upcast_attention: bool
    format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
    image_size: int


class ProbeBase(object):
    """forward declaration"""

    pass

class ModelProbe(object):
    PROBES = {
        "diffusers": {},
        "checkpoint": {},
        "onnx": {},
    }

    CLASS2TYPE = {
        "StableDiffusionPipeline": ModelType.Main,
        "StableDiffusionInpaintPipeline": ModelType.Main,
        "StableDiffusionXLPipeline": ModelType.Main,
        "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
        "StableDiffusionXLInpaintPipeline": ModelType.Main,
        "AutoencoderKL": ModelType.Vae,
        "AutoencoderTiny": ModelType.Vae,
        "ControlNetModel": ModelType.ControlNet,
        "CLIPVisionModelWithProjection": ModelType.CLIPVision,
    }

    @classmethod
    def register_probe(
        cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase
    ):
        cls.PROBES[format][model_type] = probe_class

    @classmethod
    def heuristic_probe(
        cls,
        model: Union[Dict, ModelMixin, Path],
        prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
    ) -> ModelProbeInfo:
        if isinstance(model, Path):
            return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper)
        elif isinstance(model, (dict, ModelMixin, ConfigMixin)):
            return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
        else:
            raise InvalidModelException(f"model parameter {model} is neither a Path, nor a model")

    @classmethod
    def probe(
        cls,
        model_path: Path,
        model: Optional[Union[Dict, ModelMixin]] = None,
        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
    ) -> Optional[ModelProbeInfo]:
        """
        Probe the model at model_path and return sufficient information about it
        to place it somewhere in the models directory hierarchy. If the model is
        already loaded into memory, you may provide it as model in order to avoid
        opening it a second time. The prediction_type_helper callable is a function
        that receives the path to the model and returns the SchedulerPredictionType.
        It is called to distinguish between V2-Base and V2-768 SD models.
        """
        if model_path:
            format_type = "diffusers" if model_path.is_dir() else "checkpoint"
        else:
            format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint"
        model_info = None
        try:
            model_type = (
                cls.get_model_type_from_folder(model_path, model)
                if format_type == "diffusers"
                else cls.get_model_type_from_checkpoint(model_path, model)
            )
            format_type = "onnx" if model_type == ModelType.ONNX else format_type
            probe_class = cls.PROBES[format_type].get(model_type)
            if not probe_class:
                return None
            probe = probe_class(model_path, model, prediction_type_helper)
            base_type = probe.get_base_type()
            variant_type = probe.get_variant_type()
            prediction_type = probe.get_scheduler_prediction_type()
            format = probe.get_format()
            model_info = ModelProbeInfo(
                model_type=model_type,
                base_type=base_type,
                variant_type=variant_type,
                prediction_type=prediction_type,
                upcast_attention=(
                    base_type == BaseModelType.StableDiffusion2
                    and prediction_type == SchedulerPredictionType.VPrediction
                ),
                format=format,
                image_size=(
                    1024
                    if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
                    else (
                        768
                        if (
                            base_type == BaseModelType.StableDiffusion2
                            and prediction_type == SchedulerPredictionType.VPrediction
                        )
                        else 512
                    )
                ),
            )
        except Exception:
            raise

        return model_info

    @classmethod
    def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
        if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
            return None

        if model_path.name == "learned_embeds.bin":
            return ModelType.TextualInversion

        ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
        ckpt = ckpt.get("state_dict", ckpt)

        for key in ckpt.keys():
            if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
                return ModelType.Main
            elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
                return ModelType.Vae
            elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
                return ModelType.Lora
            elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
                return ModelType.Lora
            elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
                return ModelType.ControlNet
            elif key in {"emb_params", "string_to_param"}:
                return ModelType.TextualInversion

        else:
            # diffusers-ti
            if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
                return ModelType.TextualInversion

        raise InvalidModelException(f"Unable to determine model type for {model_path}")

    @classmethod
    def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType:
        """
        Get the model type of a hugging-face style folder.
        """
        class_name = None
        error_hint = None
        if model:
            class_name = model.__class__.__name__
        else:
            if (folder_path / "unet/model.onnx").exists():
                return ModelType.ONNX
            if (folder_path / "learned_embeds.bin").exists():
                return ModelType.TextualInversion
            if (folder_path / "pytorch_lora_weights.bin").exists():
                return ModelType.Lora
            if (folder_path / "image_encoder.txt").exists():
                return ModelType.IPAdapter

            i = folder_path / "model_index.json"
            c = folder_path / "config.json"
            config_path = i if i.exists() else c if c.exists() else None

            if config_path:
                with open(config_path, "r") as file:
                    conf = json.load(file)
                if "_class_name" in conf:
                    class_name = conf["_class_name"]
                elif "architectures" in conf:
                    class_name = conf["architectures"][0]
                else:
                    class_name = None
            else:
                error_hint = f"No model_index.json or config.json found in {folder_path}."

        if class_name and (type := cls.CLASS2TYPE.get(class_name)):
            return type
        else:
            error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"

        # give up
        raise InvalidModelException(
            f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
        )

    @classmethod
    def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
        with SilenceWarnings():
            if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
                cls._scan_model(model_path, model_path)
                return torch.load(model_path)
            else:
                return safetensors.torch.load_file(model_path)

    @classmethod
    def _scan_model(cls, model_name, checkpoint):
        """
        Apply picklescanner to the indicated checkpoint and issue a warning
        and option to exit if an infected file is identified.
        """
        # scan model
        scan_result = scan_file_path(checkpoint)
        if scan_result.infected_files != 0:
            raise Exception(f"The model {model_name} is potentially infected by malware. Aborting import.")


# ##################################################
# Checkpoint probing
# ##################################################
class ProbeBase(object):
    def get_base_type(self) -> BaseModelType:
        pass

    def get_variant_type(self) -> ModelVariantType:
        pass

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        pass

    def get_format(self) -> str:
        pass


class CheckpointProbeBase(ProbeBase):
    def __init__(
        self, checkpoint_path: Path, checkpoint: dict, helper: Callable[[Path], SchedulerPredictionType] = None
    ):
        self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
        self.checkpoint_path = checkpoint_path
        self.helper = helper

    def get_base_type(self) -> BaseModelType:
        pass

    def get_format(self) -> str:
        return "checkpoint"

    def get_variant_type(self) -> ModelVariantType:
        model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint)
        if model_type != ModelType.Main:
            return ModelVariantType.Normal
        state_dict = self.checkpoint.get("state_dict") or self.checkpoint
        in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
        if in_channels == 9:
            return ModelVariantType.Inpaint
        elif in_channels == 5:
            return ModelVariantType.Depth
        elif in_channels == 4:
            return ModelVariantType.Normal
        else:
            raise InvalidModelException(
                f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}"
            )

class PipelineCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        state_dict = self.checkpoint.get("state_dict") or checkpoint
        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
            return BaseModelType.StableDiffusion1
        if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
            return BaseModelType.StableDiffusion2
        key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
            return BaseModelType.StableDiffusionXL
        elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
            return BaseModelType.StableDiffusionXLRefiner
        else:
            raise InvalidModelException("Cannot determine base type")

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        type = self.get_base_type()
        if type == BaseModelType.StableDiffusion1:
            return SchedulerPredictionType.Epsilon
        checkpoint = self.checkpoint
        state_dict = self.checkpoint.get("state_dict") or checkpoint
        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
            if "global_step" in checkpoint:
                if checkpoint["global_step"] == 220000:
                    return SchedulerPredictionType.Epsilon
                elif checkpoint["global_step"] == 110000:
                    return SchedulerPredictionType.VPrediction
            if (
                self.checkpoint_path and self.helper and not self.checkpoint_path.with_suffix(".yaml").exists()
            ):  # if a .yaml config file exists, then this step not needed
                return self.helper(self.checkpoint_path)
            else:
                return None

class VaeCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        # I can't find any standalone 2.X VAEs to test with!
        return BaseModelType.StableDiffusion1


class LoRACheckpointProbe(CheckpointProbeBase):
    def get_format(self) -> str:
        return "lycoris"

    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        token_vector_length = lora_token_vector_length(checkpoint)

        if token_vector_length == 768:
            return BaseModelType.StableDiffusion1
        elif token_vector_length == 1024:
            return BaseModelType.StableDiffusion2
        elif token_vector_length == 2048:
            return BaseModelType.StableDiffusionXL
        else:
            raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}")


class TextualInversionCheckpointProbe(CheckpointProbeBase):
    def get_format(self) -> str:
        return None

    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        if "string_to_token" in checkpoint:
            token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
        elif "emb_params" in checkpoint:
            token_dim = checkpoint["emb_params"].shape[-1]
        else:
            token_dim = list(checkpoint.values())[0].shape[0]
        if token_dim == 768:
            return BaseModelType.StableDiffusion1
        elif token_dim == 1024:
            return BaseModelType.StableDiffusion2
        else:
            return None

class ControlNetCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        for key_name in (
            "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
            "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
        ):
            if key_name not in checkpoint:
                continue
            if checkpoint[key_name].shape[-1] == 768:
                return BaseModelType.StableDiffusion1
            elif checkpoint[key_name].shape[-1] == 1024:
                return BaseModelType.StableDiffusion2
            elif self.checkpoint_path and self.helper:
                return self.helper(self.checkpoint_path)
        raise InvalidModelException(f"Unable to determine base type for {self.checkpoint_path}")


class IPAdapterCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()


class CLIPVisionCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()


########################################################
# classes for probing folders
#######################################################
class FolderProbeBase(ProbeBase):
    def __init__(self, folder_path: Path, model: ModelMixin = None, helper: Callable = None):  # not used
        self.model = model
        self.folder_path = folder_path

    def get_variant_type(self) -> ModelVariantType:
        return ModelVariantType.Normal

    def get_format(self) -> str:
        return "diffusers"

class PipelineFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        if self.model:
            unet_conf = self.model.unet.config
        else:
            with open(self.folder_path / "unet" / "config.json", "r") as file:
                unet_conf = json.load(file)
        if unet_conf["cross_attention_dim"] == 768:
            return BaseModelType.StableDiffusion1
        elif unet_conf["cross_attention_dim"] == 1024:
            return BaseModelType.StableDiffusion2
        elif unet_conf["cross_attention_dim"] == 1280:
            return BaseModelType.StableDiffusionXLRefiner
        elif unet_conf["cross_attention_dim"] == 2048:
            return BaseModelType.StableDiffusionXL
        else:
            raise InvalidModelException(f"Unknown base model for {self.folder_path}")

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        if self.model:
            scheduler_conf = self.model.scheduler.config
        else:
            with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file:
                scheduler_conf = json.load(file)
        if scheduler_conf["prediction_type"] == "v_prediction":
            return SchedulerPredictionType.VPrediction
        elif scheduler_conf["prediction_type"] == "epsilon":
            return SchedulerPredictionType.Epsilon
        else:
            return None

    def get_variant_type(self) -> ModelVariantType:
        # This only works for pipelines! Any kind of
        # exception results in our returning the
        # "normal" variant type
        try:
            if self.model:
                conf = self.model.unet.config
            else:
                config_file = self.folder_path / "unet" / "config.json"
                with open(config_file, "r") as file:
                    conf = json.load(file)

            in_channels = conf["in_channels"]
            if in_channels == 9:
                return ModelVariantType.Inpaint
            elif in_channels == 5:
                return ModelVariantType.Depth
            elif in_channels == 4:
                return ModelVariantType.Normal
        except Exception:
            pass
        return ModelVariantType.Normal

class VaeFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        if self._config_looks_like_sdxl():
            return BaseModelType.StableDiffusionXL
        elif self._name_looks_like_sdxl():
            # SD and SDXL VAEs have the same shape (3-channel RGB to 4-channel float scaled down
            # by a factor of 8), so we can't necessarily tell them apart by config hyperparameters.
            return BaseModelType.StableDiffusionXL
        else:
            return BaseModelType.StableDiffusion1

    def _config_looks_like_sdxl(self) -> bool:
        # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
        config_file = self.folder_path / "config.json"
        if not config_file.exists():
            raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
        with open(config_file, "r") as file:
            config = json.load(file)
        return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]

    def _name_looks_like_sdxl(self) -> bool:
        return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))

    def _guess_name(self) -> str:
        name = self.folder_path.name
        if name == "vae":
            name = self.folder_path.parent.name
        return name

class TextualInversionFolderProbe(FolderProbeBase):
    def get_format(self) -> str:
        return None

    def get_base_type(self) -> BaseModelType:
        path = self.folder_path / "learned_embeds.bin"
        if not path.exists():
            return None
        checkpoint = ModelProbe._scan_and_load_checkpoint(path)
        return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()


class ONNXFolderProbe(FolderProbeBase):
    def get_format(self) -> str:
        return "onnx"

    def get_base_type(self) -> BaseModelType:
        return BaseModelType.StableDiffusion1

    def get_variant_type(self) -> ModelVariantType:
        return ModelVariantType.Normal


class ControlNetFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        config_file = self.folder_path / "config.json"
        if not config_file.exists():
            raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
        with open(config_file, "r") as file:
            config = json.load(file)
        # no obvious way to distinguish between sd2-base and sd2-768
        dimension = config["cross_attention_dim"]
        base_model = (
            BaseModelType.StableDiffusion1
            if dimension == 768
            else (
                BaseModelType.StableDiffusion2
                if dimension == 1024
                else BaseModelType.StableDiffusionXL
                if dimension == 2048
                else None
            )
        )
        if not base_model:
            raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
        return base_model

class LoRAFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        model_file = None
        for suffix in ["safetensors", "bin"]:
            base_file = self.folder_path / f"pytorch_lora_weights.{suffix}"
            if base_file.exists():
                model_file = base_file
                break
        if not model_file:
            raise InvalidModelException("Unknown LoRA format encountered")
        return LoRACheckpointProbe(model_file, None).get_base_type()


class IPAdapterFolderProbe(FolderProbeBase):
    def get_format(self) -> str:
        return IPAdapterModelFormat.InvokeAI.value

    def get_base_type(self) -> BaseModelType:
        model_file = self.folder_path / "ip_adapter.bin"
        if not model_file.exists():
            raise InvalidModelException("Unknown IP-Adapter model format.")

        state_dict = torch.load(model_file, map_location="cpu")
        cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
        if cross_attention_dim == 768:
            return BaseModelType.StableDiffusion1
        elif cross_attention_dim == 1024:
            return BaseModelType.StableDiffusion2
        elif cross_attention_dim == 2048:
            return BaseModelType.StableDiffusionXL
        else:
            raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")


class CLIPVisionFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        return BaseModelType.Any

############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)

ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)

ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
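
Aside: a minimal usage sketch for the deprecated probe API above (the model path is hypothetical; ModelProbe is the class defined in the deleted file):

# Usage sketch -- illustrative only, not part of the commit.
from pathlib import Path

info = ModelProbe.heuristic_probe(Path("models/sd-1/main/my-model.safetensors"))
if info is not None:
    print(info.model_type, info.base_type, info.variant_type, info.image_size)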
(deleted file)
@@ -1,108 +0,0 @@
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
"""
Abstract base class for recursive directory search for models.
"""

import os
import types
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Set

import invokeai.backend.util.logging as logger


class ModelSearch(ABC):
    def __init__(self, directories: List[Path], logger: types.ModuleType = logger):
        """
        Initialize a recursive model directory search.
        :param directories: List of directory Paths to recurse through
        :param logger: Logger to use
        """
        self.directories = directories
        self.logger = logger
        self._items_scanned = 0
        self._models_found = 0
        self._scanned_dirs = set()
        self._scanned_paths = set()
        self._pruned_paths = set()

    @abstractmethod
    def on_search_started(self):
        """
        Called before the scan starts.
        """
        pass

    @abstractmethod
    def on_model_found(self, model: Path):
        """
        Process a found model. Raise an exception if something goes wrong.
        :param model: Model to process - could be a directory or checkpoint.
        """
        pass

    @abstractmethod
    def on_search_completed(self):
        """
        Perform some activity when the scan is completed. May use the instance
        variables _items_scanned and _models_found.
        """
        pass

    def search(self):
        self.on_search_started()
        for dir in self.directories:
            self.walk_directory(dir)
        self.on_search_completed()

    def walk_directory(self, path: Path):
        for root, dirs, files in os.walk(path, followlinks=True):
            if str(Path(root).name).startswith("."):
                self._pruned_paths.add(root)
            if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
                continue

            self._items_scanned += len(dirs) + len(files)
            for d in dirs:
                path = Path(root) / d
                if path in self._scanned_paths or path.parent in self._scanned_dirs:
                    self._scanned_dirs.add(path)
                    continue
                if any(
                    [
                        (path / x).exists()
                        for x in ["config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"]
                    ]
                ):
                    try:
                        self.on_model_found(path)
                        self._models_found += 1
                        self._scanned_dirs.add(path)
                    except Exception as e:
                        self.logger.warning(f"Failed to process '{path}': {e}")

            for f in files:
                path = Path(root) / f
                if path.parent in self._scanned_dirs:
                    continue
                if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
                    try:
                        self.on_model_found(path)
                        self._models_found += 1
                    except Exception as e:
                        self.logger.warning(f"Failed to process '{path}': {e}")


class FindModels(ModelSearch):
    def on_search_started(self):
        self.models_found: Set[Path] = set()

    def on_model_found(self, model: Path):
        self.models_found.add(model)

    def on_search_completed(self):
        pass

    def list_models(self) -> List[Path]:
        self.search()
        return list(self.models_found)
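
Aside: a short usage sketch for the search classes above (the search root is hypothetical):

# Usage sketch -- list every model file or diffusers folder under a root directory.
from pathlib import Path

finder = FindModels([Path("/opt/invokeai/models")])  # hypothetical root
for model_path in sorted(finder.list_models()):
    print(model_path)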
(deleted file)
@@ -1,75 +0,0 @@
# Copyright (c) 2023 The InvokeAI Development Team
"""Utilities used by the Model Manager"""


def lora_token_vector_length(checkpoint: dict) -> int:
    """
    Given a checkpoint in memory, return the lora token vector length.

    :param checkpoint: The checkpoint
    """

    def _get_shape_1(key, tensor, checkpoint):
        lora_token_vector_length = None

        if "." not in key:
            return lora_token_vector_length  # wrong key format
        model_key, lora_key = key.split(".", 1)

        # check lora/locon
        if lora_key == "lora_down.weight":
            lora_token_vector_length = tensor.shape[1]

        # check loha (don't worry about hada_t1/hada_t2 as they are used only in 4d shapes)
        elif lora_key in ["hada_w1_b", "hada_w2_b"]:
            lora_token_vector_length = tensor.shape[1]

        # check lokr (don't worry about lokr_t2 as it is used only in 4d shapes)
        elif "lokr_" in lora_key:
            if model_key + ".lokr_w1" in checkpoint:
                _lokr_w1 = checkpoint[model_key + ".lokr_w1"]
            elif model_key + ".lokr_w1_b" in checkpoint:
                _lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
            else:
                return lora_token_vector_length  # unknown format

            if model_key + ".lokr_w2" in checkpoint:
                _lokr_w2 = checkpoint[model_key + ".lokr_w2"]
            elif model_key + ".lokr_w2_b" in checkpoint:
                _lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
            else:
                return lora_token_vector_length  # unknown format

            lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]

        elif lora_key == "diff":
            lora_token_vector_length = tensor.shape[1]

        # ia3 can be detected only by shape[0] in text encoder
        elif lora_key == "weight" and "lora_unet_" not in model_key:
            lora_token_vector_length = tensor.shape[0]

        return lora_token_vector_length

    lora_token_vector_length = None
    lora_te1_length = None
    lora_te2_length = None
    for key, tensor in checkpoint.items():
        if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
            lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
        elif key.startswith("lora_te") and "_self_attn_" in key:
            tmp_length = _get_shape_1(key, tensor, checkpoint)
            if key.startswith("lora_te_"):
                lora_token_vector_length = tmp_length
            elif key.startswith("lora_te1_"):
                lora_te1_length = tmp_length
            elif key.startswith("lora_te2_"):
                lora_te2_length = tmp_length

        if lora_te1_length is not None and lora_te2_length is not None:
            lora_token_vector_length = lora_te1_length + lora_te2_length

        if lora_token_vector_length is not None:
            break

    return lora_token_vector_length
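
Aside: to make the shape convention concrete, a toy call with a synthetic state dict; an SD-1.x cross-attention LoRA down-projection has shape (rank, 768), so 768 is reported (the key and tensor below are made up):

# Toy illustration -- synthetic key and tensor, not a real checkpoint.
import torch

fake_checkpoint = {
    "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn2_to_k.lora_down.weight": torch.zeros(4, 768),
}
print(lora_token_vector_length(fake_checkpoint))  # 768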
(modified file)
@@ -17,7 +17,7 @@ from .config import BaseModelType, ModelConfigBase, ModelType, SubModelType
 from .download import DownloadEventHandler
 from .install import ModelInstall, ModelInstallBase
 from .models import MODEL_CLASSES, InvalidModelException, ModelBase
-from .storage import ModelConfigStore, get_config_store
+from .storage import ConfigFileVersionMismatchException, ModelConfigStore, get_config_store, migrate_models_store


 @dataclass
@@ -138,7 +138,12 @@ class ModelLoad(ModelLoadBase):
             models_file = config.model_conf_path
         else:
             models_file = config.root_path / "configs/models3.yaml"
-        store = get_config_store(models_file)
+        try:
+            store = get_config_store(models_file)
+        except ConfigFileVersionMismatchException:
+            migrate_models_store(config)
+            store = get_config_store(models_file)
 
         if not store:
             raise ValueError(f"Invalid model configuration file: {models_file}")
(modified file)
@@ -7,6 +7,7 @@ its base type, model type, format and variant.
 """
 
 import json
+import re
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Callable, Optional
@@ -493,20 +494,33 @@ class PipelineFolderProbe(FolderProbeBase):
 
 
 class VaeFolderProbe(FolderProbeBase):
-    """Probe a diffusers-style VAE model."""
-
     def get_base_type(self) -> BaseModelType:
-        """Return the BaseModelType for a diffusers-style VAE."""
+        if self._config_looks_like_sdxl():
+            return BaseModelType.StableDiffusionXL
+        elif self._name_looks_like_sdxl():
+            # SD and SDXL VAEs have the same shape (3-channel RGB to 4-channel float scaled down
+            # by a factor of 8), so we can't necessarily tell them apart by config hyperparameters.
+            return BaseModelType.StableDiffusionXL
+        else:
+            return BaseModelType.StableDiffusion1
+
+    def _config_looks_like_sdxl(self) -> bool:
+        # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
         config_file = self.folder_path / "config.json"
         if not config_file.exists():
             raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
         with open(config_file, "r") as file:
             config = json.load(file)
-        return (
-            BaseModelType.StableDiffusionXL
-            if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
-            else BaseModelType.StableDiffusion1
-        )
+        return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
+
+    def _name_looks_like_sdxl(self) -> bool:
+        return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
+
+    def _guess_name(self) -> str:
+        name = self.folder_path.name
+        if name == "vae":
+            name = self.folder_path.parent.name
+        return name
 
 
 class TextualInversionFolderProbe(FolderProbeBase):
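
Aside: the new name heuristic hinges on the regex r"xl\b"; a quick check of what it does and does not match (the sample names are illustrative):

# Quick check of the r"xl\b" heuristic used by _name_looks_like_sdxl().
import re

for name in ("sdxl-vae", "sdxl", "sd-vae-ft-mse"):
    print(name, "->", bool(re.search(r"xl\b", name, re.IGNORECASE)))
# sdxl-vae -> True, sdxl -> True, sd-vae-ft-mse -> False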
(modified file)
@@ -168,7 +168,13 @@ class ModelSearch(ModelSearchBase):
                 if any(
                     [
                         (path / x).exists()
-                        for x in ["config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"]
+                        for x in [
+                            "config.json",
+                            "model_index.json",
+                            "learned_embeds.bin",
+                            "pytorch_lora_weights.bin",
+                            "image_encoder.txt",
+                        ]
                     ]
                 ):
                     self._scanned_dirs.add(path)
(modified file)
@@ -3,7 +3,13 @@ Initialization file for invokeai.backend.model_manager.storage
 """
 import pathlib
 
-from .base import DuplicateModelException, ModelConfigStore, UnknownModelException  # noqa F401
+from .base import (  # noqa F401
+    ConfigFileVersionMismatchException,
+    DuplicateModelException,
+    ModelConfigStore,
+    UnknownModelException,
+)
+from .migrate import migrate_models_store  # noqa F401
 from .sql import ModelConfigStoreSQL  # noqa F401
 from .yaml import ModelConfigStoreYAML  # noqa F401
(modified file)
@@ -4,6 +4,7 @@ Abstract base class for storing and retrieving model configuration records.
 """
 
 from abc import ABC, abstractmethod
+from pathlib import Path
 from typing import List, Optional, Set, Union
 
 from ..config import BaseModelType, ModelConfigBase, ModelType
@@ -24,6 +25,10 @@ class UnknownModelException(Exception):
     """Raised on an attempt to fetch or delete a model with a nonexistent key."""
 
 
+class ConfigFileVersionMismatchException(Exception):
+    """Raised on an attempt to open a config with an incompatible version."""
+
+
 class ModelConfigStore(ABC):
     """Abstract base class for storage and retrieval of model configs."""
 
@@ -99,6 +104,16 @@ class ModelConfigStore(ABC):
         """
         pass
 
+    @abstractmethod
+    def search_by_path(
+        self,
+        path: Union[str, Path],
+    ) -> Optional[ModelConfigBase]:
+        """
+        Return the model having the indicated path.
+        """
+        pass
+
     @abstractmethod
     def search_by_name(
         self,
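
Aside: from the caller's side, the new abstract method is used roughly like this (the store instance and path are hypothetical, and the .key/.description attributes are assumed from their use in the migration code below):

# Call-site sketch for the new search_by_path method -- illustrative only.
model = store.search_by_path("sd-1/main/stable-diffusion-v1-5")
if model is None:
    print("no model registered at that path")
else:
    print(model.key, model.description)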
invokeai/backend/model_manager/storage/migrate.py (new file)
@@ -0,0 +1,61 @@
# Copyright (c) 2023 The InvokeAI Development Team

import shutil
from pathlib import Path

from omegaconf import OmegaConf

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger

from .base import CONFIG_FILE_VERSION


def migrate_models_store(config: InvokeAIAppConfig):
    # avoid circular import
    from invokeai.backend.model_manager import DuplicateModelException, InvalidModelException, ModelInstall
    from invokeai.backend.model_manager.storage import get_config_store

    app_config = InvokeAIAppConfig.get_config()
    logger = InvokeAILogger.getLogger()
    old_file: Path = app_config.model_conf_path
    new_file: Path = old_file.with_name("models3_2.yaml")

    old_conf = OmegaConf.load(old_file)
    store = get_config_store(new_file)
    installer = ModelInstall(store=store)
    logger.info(f"Migrating old models file at {old_file} to new {CONFIG_FILE_VERSION} format")

    for model_key, stanza in old_conf.items():
        if model_key == "__metadata__":
            assert (
                stanza["version"] == "3.0.0"
            ), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
            continue

        base_type, model_type, model_name = model_key.split("/")

        try:
            path = app_config.models_path / stanza["path"]
            new_key = installer.register_path(path)
        except DuplicateModelException:
            # if model already installed, then we just update its info
            models = store.search_by_name(model_name=model_name, base_model=base_type, model_type=model_type)
            if len(models) != 1:
                continue
            new_key = models[0].key
        except Exception as excp:
            print(str(excp))

        model_info = store.get_model(new_key)
        if vae := stanza.get("vae"):
            model_info.vae = (app_config.models_path / vae).as_posix()
        if model_config := stanza.get("config"):
            model_info.config = (app_config.root_path / model_config).as_posix()
        model_info.description = stanza.get("description")
        store.update_model(new_key, model_info)

    logger.info(f"Original version of models config file saved as {str(old_file) + '.orig'}")
    shutil.move(old_file, str(old_file) + ".orig")
    shutil.move(new_file, old_file)
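
Aside: migrate_models_store can also be invoked by hand; a sketch (root path hypothetical, and parse_args accepting an argv list is assumed):

# Sketch -- run the migration directly.
from invokeai.app.services.config import InvokeAIAppConfig

config = InvokeAIAppConfig.get_config()
config.parse_args(["--root", "/opt/invokeai"])  # hypothetical root
migrate_models_store(config)  # rewrites models.yaml in place; original saved as models.yaml.orig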
(modified file)
@@ -477,3 +477,9 @@ class ModelConfigStoreSQL(ModelConfigStore):
         finally:
             self._lock.release()
         return results
+
+    def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:
+        """
+        Return the model with the indicated path, or None.
+        """
+        raise NotImplementedError("search_by_path not implemented in storage.sql")
(modified file)
@@ -50,7 +50,13 @@ from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
 
 from ..config import BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType
-from .base import CONFIG_FILE_VERSION, DuplicateModelException, ModelConfigStore, UnknownModelException
+from .base import (
+    CONFIG_FILE_VERSION,
+    ConfigFileVersionMismatchException,
+    DuplicateModelException,
+    ModelConfigStore,
+    UnknownModelException,
+)
 
 
 class ModelConfigStoreYAML(ModelConfigStore):
@@ -68,9 +74,8 @@ class ModelConfigStoreYAML(ModelConfigStore):
         if not self._filename.exists():
             self._initialize_yaml()
         self._config = OmegaConf.load(self._filename)
-        assert (
-            str(self.version) == CONFIG_FILE_VERSION
-        ), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
+        if str(self.version) != CONFIG_FILE_VERSION:
+            raise ConfigFileVersionMismatchException
 
     def _initialize_yaml(self):
         try:
@@ -239,3 +244,67 @@ class ModelConfigStoreYAML(ModelConfigStore):
         finally:
             self._lock.release()
         return results
+
+    def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:
+        """
+        Return the model with the indicated path, or None.
+        """
+        try:
+            self._lock.acquire()
+            for key, record in self._config.items():
+                if key == "__metadata__":
+                    continue
+                model = ModelConfigFactory.make_config(record, key)
+                if model.path == path:
+                    return model
+        finally:
+            self._lock.release()
+        return None
+
+    def _load_and_maybe_upgrade(self, config_path: Path) -> DictConfig:
+        config = OmegaConf.load(config_path)
+        version = config["__metadata__"].get("version")
+        if version == CONFIG_FILE_VERSION:
+            return config
+
+        # if we get here we need to upgrade
+        if version == "3.0.0":
+            return self._migrate_format_to_3_2(config, config_path)
+        else:
+            raise Exception(f"{config_path} has unknown version: {version}")
+
+    def _migrate_format_to_3_2(self, old_config: DictConfig, config_path: Path) -> DictConfig:
+        print(
+            f"** Doing one-time conversion of {config_path.as_posix()} to new format. Original will be named {config_path.as_posix() + '.orig'}"
+        )
+
+        # avoid circular dependencies
+        from shutil import move
+
+        from ..install import InvalidModelException, ModelInstall
+
+        move(config_path, config_path.as_posix() + ".orig")
+
+        new_store = self.__class__(config_path)
+        installer = ModelInstall(store=new_store)
+
+        for model_key, stanza in old_config.items():
+            if model_key == "__metadata__":
+                assert (
+                    stanza["version"] == "3.0.0"
+                ), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
+                continue
+
+            try:
+                path = stanza["path"]
+                new_key = installer.register_path(path)
+                model_info = new_store.get_model(new_key)
+                if vae := stanza.get("vae"):
+                    model_info.vae = vae
+                if model_config := stanza.get("config"):
+                    model_info.config = model_config.as_posix()
+                model_info.description = stanza.get("description")
+                new_store.update_model(new_key, model_info)
+            except (DuplicateModelException, InvalidModelException) as e:
+                print(str(e))
+
+        return OmegaConf.load(config_path)
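
Aside: for context, the 3.0.0-format stanza shape that _migrate_format_to_3_2 consumes looks roughly like this (the model entry is illustrative):

# Illustrative 3.0.0-format models.yaml content consumed by the upgrade path.
from omegaconf import OmegaConf

old_style = OmegaConf.create(
    """
__metadata__:
  version: 3.0.0
sd-1/main/stable-diffusion-v1-5:
  path: sd-1/main/stable-diffusion-v1-5
  description: Stable Diffusion version 1.5
  format: diffusers
"""
)
for key, stanza in old_style.items():
    if key == "__metadata__":
        continue
    print(key, "->", stanza["path"])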
(modified file)
@@ -14,11 +14,8 @@ when new models are downloaded from HuggingFace or Civitae.
 import argparse
 from pathlib import Path
 
-from omegaconf import OmegaConf
-
 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.model_manager import DuplicateModelException, InvalidModelException, ModelInstall
-from invokeai.backend.model_manager.storage import get_config_store
+from invokeai.backend.model_manager.storage import migrate_models_store
 
 
 def main():
@@ -35,34 +32,7 @@ def main():
 
     config = InvokeAIAppConfig.get_config()
     config.parse_args(config_args)
-    old_yaml_file = OmegaConf.load(config.model_conf_path)
-
-    store = get_config_store(args.outfile)
-    installer = ModelInstall(store=store)
-
-    print(f"Writing 3.2 models configuration into {args.outfile}.")
-
-    for model_key, stanza in old_yaml_file.items():
-        if model_key == "__metadata__":
-            assert (
-                stanza["version"] == "3.0.0"
-            ), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
-            continue
-
-        try:
-            path = config.models_path / stanza["path"]
-            new_key = installer.register_path(path)
-            model_info = store.get_model(new_key)
-            if vae := stanza.get("vae"):
-                model_info.vae = (config.models_path / vae).as_posix()
-            if model_config := stanza.get("config"):
-                model_info.config = (config.root_path / model_config).as_posix()
-            model_info.description = stanza.get("description")
-            store.update_model(new_key, model_info)
-
-            print(f"{model_key} => {new_key}")
-        except (DuplicateModelException, InvalidModelException) as e:
-            print(str(e))
+    migrate_models_store(config)
 
 
 if __name__ == "__main__":