mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
automatically convert models.yaml to new format
This commit is contained in:
1
invokeai/backend/model_management/README
Normal file
1
invokeai/backend/model_management/README
Normal file
@ -0,0 +1 @@
|
||||
The contents of this directory are deprecated. model_manager.py is here only for reference.
|
@ -1,605 +0,0 @@
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Literal, Optional, Union
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
|
||||
|
||||
from .models import (
|
||||
BaseModelType,
|
||||
InvalidModelException,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SilenceWarnings,
|
||||
)
|
||||
from .models.base import read_checkpoint_meta
|
||||
from .util import lora_token_vector_length
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelProbeInfo(object):
|
||||
model_type: ModelType
|
||||
base_type: BaseModelType
|
||||
variant_type: ModelVariantType
|
||||
prediction_type: SchedulerPredictionType
|
||||
upcast_attention: bool
|
||||
format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
|
||||
image_size: int
|
||||
|
||||
|
||||
class ProbeBase(object):
|
||||
"""forward declaration"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ModelProbe(object):
|
||||
PROBES = {
|
||||
"diffusers": {},
|
||||
"checkpoint": {},
|
||||
"onnx": {},
|
||||
}
|
||||
|
||||
CLASS2TYPE = {
|
||||
"StableDiffusionPipeline": ModelType.Main,
|
||||
"StableDiffusionInpaintPipeline": ModelType.Main,
|
||||
"StableDiffusionXLPipeline": ModelType.Main,
|
||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||
"AutoencoderKL": ModelType.Vae,
|
||||
"AutoencoderTiny": ModelType.Vae,
|
||||
"ControlNetModel": ModelType.ControlNet,
|
||||
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_probe(
|
||||
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase
|
||||
):
|
||||
cls.PROBES[format][model_type] = probe_class
|
||||
|
||||
@classmethod
|
||||
def heuristic_probe(
|
||||
cls,
|
||||
model: Union[Dict, ModelMixin, Path],
|
||||
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
|
||||
) -> ModelProbeInfo:
|
||||
if isinstance(model, Path):
|
||||
return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper)
|
||||
elif isinstance(model, (dict, ModelMixin, ConfigMixin)):
|
||||
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
|
||||
else:
|
||||
raise InvalidModelException("model parameter {model} is neither a Path, nor a model")
|
||||
|
||||
@classmethod
|
||||
def probe(
|
||||
cls,
|
||||
model_path: Path,
|
||||
model: Optional[Union[Dict, ModelMixin]] = None,
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> Optional[ModelProbeInfo]:
|
||||
"""
|
||||
Probe the model at model_path and return sufficient information about it
|
||||
to place it somewhere in the models directory hierarchy. If the model is
|
||||
already loaded into memory, you may provide it as model in order to avoid
|
||||
opening it a second time. The prediction_type_helper callable is a function that receives
|
||||
the path to the model and returns the BaseModelType. It is called to distinguish
|
||||
between V2-Base and V2-768 SD models.
|
||||
"""
|
||||
if model_path:
|
||||
format_type = "diffusers" if model_path.is_dir() else "checkpoint"
|
||||
else:
|
||||
format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint"
|
||||
model_info = None
|
||||
try:
|
||||
model_type = (
|
||||
cls.get_model_type_from_folder(model_path, model)
|
||||
if format_type == "diffusers"
|
||||
else cls.get_model_type_from_checkpoint(model_path, model)
|
||||
)
|
||||
format_type = "onnx" if model_type == ModelType.ONNX else format_type
|
||||
probe_class = cls.PROBES[format_type].get(model_type)
|
||||
if not probe_class:
|
||||
return None
|
||||
probe = probe_class(model_path, model, prediction_type_helper)
|
||||
base_type = probe.get_base_type()
|
||||
variant_type = probe.get_variant_type()
|
||||
prediction_type = probe.get_scheduler_prediction_type()
|
||||
format = probe.get_format()
|
||||
model_info = ModelProbeInfo(
|
||||
model_type=model_type,
|
||||
base_type=base_type,
|
||||
variant_type=variant_type,
|
||||
prediction_type=prediction_type,
|
||||
upcast_attention=(
|
||||
base_type == BaseModelType.StableDiffusion2
|
||||
and prediction_type == SchedulerPredictionType.VPrediction
|
||||
),
|
||||
format=format,
|
||||
image_size=(
|
||||
1024
|
||||
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
|
||||
else (
|
||||
768
|
||||
if (
|
||||
base_type == BaseModelType.StableDiffusion2
|
||||
and prediction_type == SchedulerPredictionType.VPrediction
|
||||
)
|
||||
else 512
|
||||
)
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
return model_info
|
||||
|
||||
@classmethod
|
||||
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
|
||||
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
|
||||
return None
|
||||
|
||||
if model_path.name == "learned_embeds.bin":
|
||||
return ModelType.TextualInversion
|
||||
|
||||
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
|
||||
ckpt = ckpt.get("state_dict", ckpt)
|
||||
|
||||
for key in ckpt.keys():
|
||||
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
|
||||
return ModelType.Main
|
||||
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
|
||||
return ModelType.Vae
|
||||
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
||||
return ModelType.Lora
|
||||
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
||||
return ModelType.Lora
|
||||
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
||||
return ModelType.ControlNet
|
||||
elif key in {"emb_params", "string_to_param"}:
|
||||
return ModelType.TextualInversion
|
||||
|
||||
else:
|
||||
# diffusers-ti
|
||||
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
||||
return ModelType.TextualInversion
|
||||
|
||||
raise InvalidModelException(f"Unable to determine model type for {model_path}")
|
||||
|
||||
@classmethod
|
||||
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType:
|
||||
"""
|
||||
Get the model type of a hugging-face style folder.
|
||||
"""
|
||||
class_name = None
|
||||
error_hint = None
|
||||
if model:
|
||||
class_name = model.__class__.__name__
|
||||
else:
|
||||
if (folder_path / "unet/model.onnx").exists():
|
||||
return ModelType.ONNX
|
||||
if (folder_path / "learned_embeds.bin").exists():
|
||||
return ModelType.TextualInversion
|
||||
if (folder_path / "pytorch_lora_weights.bin").exists():
|
||||
return ModelType.Lora
|
||||
if (folder_path / "image_encoder.txt").exists():
|
||||
return ModelType.IPAdapter
|
||||
|
||||
i = folder_path / "model_index.json"
|
||||
c = folder_path / "config.json"
|
||||
config_path = i if i.exists() else c if c.exists() else None
|
||||
|
||||
if config_path:
|
||||
with open(config_path, "r") as file:
|
||||
conf = json.load(file)
|
||||
if "_class_name" in conf:
|
||||
class_name = conf["_class_name"]
|
||||
elif "architectures" in conf:
|
||||
class_name = conf["architectures"][0]
|
||||
else:
|
||||
class_name = None
|
||||
else:
|
||||
error_hint = f"No model_index.json or config.json found in {folder_path}."
|
||||
|
||||
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
|
||||
return type
|
||||
else:
|
||||
error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"
|
||||
|
||||
# give up
|
||||
raise InvalidModelException(
|
||||
f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
|
||||
with SilenceWarnings():
|
||||
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
|
||||
cls._scan_model(model_path, model_path)
|
||||
return torch.load(model_path)
|
||||
else:
|
||||
return safetensors.torch.load_file(model_path)
|
||||
|
||||
@classmethod
|
||||
def _scan_model(cls, model_name, checkpoint):
|
||||
"""
|
||||
Apply picklescanner to the indicated checkpoint and issue a warning
|
||||
and option to exit if an infected file is identified.
|
||||
"""
|
||||
# scan model
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
raise "The model {model_name} is potentially infected by malware. Aborting import."
|
||||
|
||||
|
||||
# ##################################################3
|
||||
# Checkpoint probing
|
||||
# ##################################################3
|
||||
class ProbeBase(object):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
pass
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
pass
|
||||
|
||||
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
||||
pass
|
||||
|
||||
def get_format(self) -> str:
|
||||
pass
|
||||
|
||||
|
||||
class CheckpointProbeBase(ProbeBase):
|
||||
def __init__(
|
||||
self, checkpoint_path: Path, checkpoint: dict, helper: Callable[[Path], SchedulerPredictionType] = None
|
||||
) -> BaseModelType:
|
||||
self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
|
||||
self.checkpoint_path = checkpoint_path
|
||||
self.helper = helper
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
pass
|
||||
|
||||
def get_format(self) -> str:
|
||||
return "checkpoint"
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint)
|
||||
if model_type != ModelType.Main:
|
||||
return ModelVariantType.Normal
|
||||
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
if in_channels == 9:
|
||||
return ModelVariantType.Inpaint
|
||||
elif in_channels == 5:
|
||||
return ModelVariantType.Depth
|
||||
elif in_channels == 4:
|
||||
return ModelVariantType.Normal
|
||||
else:
|
||||
raise InvalidModelException(
|
||||
f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}"
|
||||
)
|
||||
|
||||
|
||||
class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
|
||||
return BaseModelType.StableDiffusionXLRefiner
|
||||
else:
|
||||
raise InvalidModelException("Cannot determine base type")
|
||||
|
||||
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
||||
type = self.get_base_type()
|
||||
if type == BaseModelType.StableDiffusion1:
|
||||
return SchedulerPredictionType.Epsilon
|
||||
checkpoint = self.checkpoint
|
||||
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
||||
if "global_step" in checkpoint:
|
||||
if checkpoint["global_step"] == 220000:
|
||||
return SchedulerPredictionType.Epsilon
|
||||
elif checkpoint["global_step"] == 110000:
|
||||
return SchedulerPredictionType.VPrediction
|
||||
if (
|
||||
self.checkpoint_path and self.helper and not self.checkpoint_path.with_suffix(".yaml").exists()
|
||||
): # if a .yaml config file exists, then this step not needed
|
||||
return self.helper(self.checkpoint_path)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class VaeCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
# I can't find any standalone 2.X VAEs to test with!
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
|
||||
class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
def get_format(self) -> str:
|
||||
return "lycoris"
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
token_vector_length = lora_token_vector_length(checkpoint)
|
||||
|
||||
if token_vector_length == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif token_vector_length == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif token_vector_length == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}")
|
||||
|
||||
|
||||
class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
||||
def get_format(self) -> str:
|
||||
return None
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
if "string_to_token" in checkpoint:
|
||||
token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
|
||||
elif "emb_params" in checkpoint:
|
||||
token_dim = checkpoint["emb_params"].shape[-1]
|
||||
else:
|
||||
token_dim = list(checkpoint.values())[0].shape[0]
|
||||
if token_dim == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif token_dim == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class ControlNetCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
for key_name in (
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
||||
"input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
||||
):
|
||||
if key_name not in checkpoint:
|
||||
continue
|
||||
if checkpoint[key_name].shape[-1] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif checkpoint[key_name].shape[-1] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif self.checkpoint_path and self.helper:
|
||||
return self.helper(self.checkpoint_path)
|
||||
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
|
||||
|
||||
|
||||
class IPAdapterCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
########################################################
|
||||
# classes for probing folders
|
||||
#######################################################
|
||||
class FolderProbeBase(ProbeBase):
|
||||
def __init__(self, folder_path: Path, model: ModelMixin = None, helper: Callable = None): # not used
|
||||
self.model = model
|
||||
self.folder_path = folder_path
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
return ModelVariantType.Normal
|
||||
|
||||
def get_format(self) -> str:
|
||||
return "diffusers"
|
||||
|
||||
|
||||
class PipelineFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
if self.model:
|
||||
unet_conf = self.model.unet.config
|
||||
else:
|
||||
with open(self.folder_path / "unet" / "config.json", "r") as file:
|
||||
unet_conf = json.load(file)
|
||||
if unet_conf["cross_attention_dim"] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif unet_conf["cross_attention_dim"] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif unet_conf["cross_attention_dim"] == 1280:
|
||||
return BaseModelType.StableDiffusionXLRefiner
|
||||
elif unet_conf["cross_attention_dim"] == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelException(f"Unknown base model for {self.folder_path}")
|
||||
|
||||
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
||||
if self.model:
|
||||
scheduler_conf = self.model.scheduler.config
|
||||
else:
|
||||
with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file:
|
||||
scheduler_conf = json.load(file)
|
||||
if scheduler_conf["prediction_type"] == "v_prediction":
|
||||
return SchedulerPredictionType.VPrediction
|
||||
elif scheduler_conf["prediction_type"] == "epsilon":
|
||||
return SchedulerPredictionType.Epsilon
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
# This only works for pipelines! Any kind of
|
||||
# exception results in our returning the
|
||||
# "normal" variant type
|
||||
try:
|
||||
if self.model:
|
||||
conf = self.model.unet.config
|
||||
else:
|
||||
config_file = self.folder_path / "unet" / "config.json"
|
||||
with open(config_file, "r") as file:
|
||||
conf = json.load(file)
|
||||
|
||||
in_channels = conf["in_channels"]
|
||||
if in_channels == 9:
|
||||
return ModelVariantType.Inpaint
|
||||
elif in_channels == 5:
|
||||
return ModelVariantType.Depth
|
||||
elif in_channels == 4:
|
||||
return ModelVariantType.Normal
|
||||
except Exception:
|
||||
pass
|
||||
return ModelVariantType.Normal
|
||||
|
||||
|
||||
class VaeFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
if self._config_looks_like_sdxl():
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif self._name_looks_like_sdxl():
|
||||
# but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
|
||||
# by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
def _config_looks_like_sdxl(self) -> bool:
|
||||
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
|
||||
config_file = self.folder_path / "config.json"
|
||||
if not config_file.exists():
|
||||
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
||||
with open(config_file, "r") as file:
|
||||
config = json.load(file)
|
||||
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
|
||||
|
||||
def _name_looks_like_sdxl(self) -> bool:
|
||||
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
|
||||
|
||||
def _guess_name(self) -> str:
|
||||
name = self.folder_path.name
|
||||
if name == "vae":
|
||||
name = self.folder_path.parent.name
|
||||
return name
|
||||
|
||||
|
||||
class TextualInversionFolderProbe(FolderProbeBase):
|
||||
def get_format(self) -> str:
|
||||
return None
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
path = self.folder_path / "learned_embeds.bin"
|
||||
if not path.exists():
|
||||
return None
|
||||
checkpoint = ModelProbe._scan_and_load_checkpoint(path)
|
||||
return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()
|
||||
|
||||
|
||||
class ONNXFolderProbe(FolderProbeBase):
|
||||
def get_format(self) -> str:
|
||||
return "onnx"
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
return ModelVariantType.Normal
|
||||
|
||||
|
||||
class ControlNetFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
config_file = self.folder_path / "config.json"
|
||||
if not config_file.exists():
|
||||
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
||||
with open(config_file, "r") as file:
|
||||
config = json.load(file)
|
||||
# no obvious way to distinguish between sd2-base and sd2-768
|
||||
dimension = config["cross_attention_dim"]
|
||||
base_model = (
|
||||
BaseModelType.StableDiffusion1
|
||||
if dimension == 768
|
||||
else (
|
||||
BaseModelType.StableDiffusion2
|
||||
if dimension == 1024
|
||||
else BaseModelType.StableDiffusionXL
|
||||
if dimension == 2048
|
||||
else None
|
||||
)
|
||||
)
|
||||
if not base_model:
|
||||
raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
|
||||
return base_model
|
||||
|
||||
|
||||
class LoRAFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
model_file = None
|
||||
for suffix in ["safetensors", "bin"]:
|
||||
base_file = self.folder_path / f"pytorch_lora_weights.{suffix}"
|
||||
if base_file.exists():
|
||||
model_file = base_file
|
||||
break
|
||||
if not model_file:
|
||||
raise InvalidModelException("Unknown LoRA format encountered")
|
||||
return LoRACheckpointProbe(model_file, None).get_base_type()
|
||||
|
||||
|
||||
class IPAdapterFolderProbe(FolderProbeBase):
|
||||
def get_format(self) -> str:
|
||||
return IPAdapterModelFormat.InvokeAI.value
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
model_file = self.folder_path / "ip_adapter.bin"
|
||||
if not model_file.exists():
|
||||
raise InvalidModelException("Unknown IP-Adapter model format.")
|
||||
|
||||
state_dict = torch.load(model_file, map_location="cpu")
|
||||
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
|
||||
if cross_attention_dim == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif cross_attention_dim == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif cross_attention_dim == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")
|
||||
|
||||
|
||||
class CLIPVisionFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
return BaseModelType.Any
|
||||
|
||||
|
||||
############## register probe classes ######
|
||||
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
|
||||
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
|
||||
|
||||
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|
@ -1,108 +0,0 @@
|
||||
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
|
||||
"""
|
||||
Abstract base class for recursive directory search for models.
|
||||
"""
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List, Set, types
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
|
||||
class ModelSearch(ABC):
|
||||
def __init__(self, directories: List[Path], logger: types.ModuleType = logger):
|
||||
"""
|
||||
Initialize a recursive model directory search.
|
||||
:param directories: List of directory Paths to recurse through
|
||||
:param logger: Logger to use
|
||||
"""
|
||||
self.directories = directories
|
||||
self.logger = logger
|
||||
self._items_scanned = 0
|
||||
self._models_found = 0
|
||||
self._scanned_dirs = set()
|
||||
self._scanned_paths = set()
|
||||
self._pruned_paths = set()
|
||||
|
||||
@abstractmethod
|
||||
def on_search_started(self):
|
||||
"""
|
||||
Called before the scan starts.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_model_found(self, model: Path):
|
||||
"""
|
||||
Process a found model. Raise an exception if something goes wrong.
|
||||
:param model: Model to process - could be a directory or checkpoint.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_search_completed(self):
|
||||
"""
|
||||
Perform some activity when the scan is completed. May use instance
|
||||
variables, items_scanned and models_found
|
||||
"""
|
||||
pass
|
||||
|
||||
def search(self):
|
||||
self.on_search_started()
|
||||
for dir in self.directories:
|
||||
self.walk_directory(dir)
|
||||
self.on_search_completed()
|
||||
|
||||
def walk_directory(self, path: Path):
|
||||
for root, dirs, files in os.walk(path, followlinks=True):
|
||||
if str(Path(root).name).startswith("."):
|
||||
self._pruned_paths.add(root)
|
||||
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
|
||||
continue
|
||||
|
||||
self._items_scanned += len(dirs) + len(files)
|
||||
for d in dirs:
|
||||
path = Path(root) / d
|
||||
if path in self._scanned_paths or path.parent in self._scanned_dirs:
|
||||
self._scanned_dirs.add(path)
|
||||
continue
|
||||
if any(
|
||||
[
|
||||
(path / x).exists()
|
||||
for x in ["config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"]
|
||||
]
|
||||
):
|
||||
try:
|
||||
self.on_model_found(path)
|
||||
self._models_found += 1
|
||||
self._scanned_dirs.add(path)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to process '{path}': {e}")
|
||||
|
||||
for f in files:
|
||||
path = Path(root) / f
|
||||
if path.parent in self._scanned_dirs:
|
||||
continue
|
||||
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
|
||||
try:
|
||||
self.on_model_found(path)
|
||||
self._models_found += 1
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to process '{path}': {e}")
|
||||
|
||||
|
||||
class FindModels(ModelSearch):
|
||||
def on_search_started(self):
|
||||
self.models_found: Set[Path] = set()
|
||||
|
||||
def on_model_found(self, model: Path):
|
||||
self.models_found.add(model)
|
||||
|
||||
def on_search_completed(self):
|
||||
pass
|
||||
|
||||
def list_models(self) -> List[Path]:
|
||||
self.search()
|
||||
return list(self.models_found)
|
@ -1,75 +0,0 @@
|
||||
# Copyright (c) 2023 The InvokeAI Development Team
|
||||
"""Utilities used by the Model Manager"""
|
||||
|
||||
|
||||
def lora_token_vector_length(checkpoint: dict) -> int:
|
||||
"""
|
||||
Given a checkpoint in memory, return the lora token vector length
|
||||
|
||||
:param checkpoint: The checkpoint
|
||||
"""
|
||||
|
||||
def _get_shape_1(key, tensor, checkpoint):
|
||||
lora_token_vector_length = None
|
||||
|
||||
if "." not in key:
|
||||
return lora_token_vector_length # wrong key format
|
||||
model_key, lora_key = key.split(".", 1)
|
||||
|
||||
# check lora/locon
|
||||
if lora_key == "lora_down.weight":
|
||||
lora_token_vector_length = tensor.shape[1]
|
||||
|
||||
# check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
|
||||
elif lora_key in ["hada_w1_b", "hada_w2_b"]:
|
||||
lora_token_vector_length = tensor.shape[1]
|
||||
|
||||
# check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
|
||||
elif "lokr_" in lora_key:
|
||||
if model_key + ".lokr_w1" in checkpoint:
|
||||
_lokr_w1 = checkpoint[model_key + ".lokr_w1"]
|
||||
elif model_key + "lokr_w1_b" in checkpoint:
|
||||
_lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
|
||||
else:
|
||||
return lora_token_vector_length # unknown format
|
||||
|
||||
if model_key + ".lokr_w2" in checkpoint:
|
||||
_lokr_w2 = checkpoint[model_key + ".lokr_w2"]
|
||||
elif model_key + "lokr_w2_b" in checkpoint:
|
||||
_lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
|
||||
else:
|
||||
return lora_token_vector_length # unknown format
|
||||
|
||||
lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
|
||||
|
||||
elif lora_key == "diff":
|
||||
lora_token_vector_length = tensor.shape[1]
|
||||
|
||||
# ia3 can be detected only by shape[0] in text encoder
|
||||
elif lora_key == "weight" and "lora_unet_" not in model_key:
|
||||
lora_token_vector_length = tensor.shape[0]
|
||||
|
||||
return lora_token_vector_length
|
||||
|
||||
lora_token_vector_length = None
|
||||
lora_te1_length = None
|
||||
lora_te2_length = None
|
||||
for key, tensor in checkpoint.items():
|
||||
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
|
||||
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
|
||||
elif key.startswith("lora_te") and "_self_attn_" in key:
|
||||
tmp_length = _get_shape_1(key, tensor, checkpoint)
|
||||
if key.startswith("lora_te_"):
|
||||
lora_token_vector_length = tmp_length
|
||||
elif key.startswith("lora_te1_"):
|
||||
lora_te1_length = tmp_length
|
||||
elif key.startswith("lora_te2_"):
|
||||
lora_te2_length = tmp_length
|
||||
|
||||
if lora_te1_length is not None and lora_te2_length is not None:
|
||||
lora_token_vector_length = lora_te1_length + lora_te2_length
|
||||
|
||||
if lora_token_vector_length is not None:
|
||||
break
|
||||
|
||||
return lora_token_vector_length
|
@ -17,7 +17,7 @@ from .config import BaseModelType, ModelConfigBase, ModelType, SubModelType
|
||||
from .download import DownloadEventHandler
|
||||
from .install import ModelInstall, ModelInstallBase
|
||||
from .models import MODEL_CLASSES, InvalidModelException, ModelBase
|
||||
from .storage import ModelConfigStore, get_config_store
|
||||
from .storage import ConfigFileVersionMismatchException, ModelConfigStore, get_config_store, migrate_models_store
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -138,7 +138,12 @@ class ModelLoad(ModelLoadBase):
|
||||
models_file = config.model_conf_path
|
||||
else:
|
||||
models_file = config.root_path / "configs/models3.yaml"
|
||||
store = get_config_store(models_file)
|
||||
try:
|
||||
store = get_config_store(models_file)
|
||||
except ConfigFileVersionMismatchException:
|
||||
migrate_models_store(config)
|
||||
store = get_config_store(models_file)
|
||||
|
||||
if not store:
|
||||
raise ValueError(f"Invalid model configuration file: {models_file}")
|
||||
|
||||
|
@ -7,6 +7,7 @@ its base type, model type, format and variant.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
@ -493,20 +494,33 @@ class PipelineFolderProbe(FolderProbeBase):
|
||||
|
||||
|
||||
class VaeFolderProbe(FolderProbeBase):
|
||||
"""Probe a diffusers-style VAE model."""
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
"""Return the BaseModelType for a diffusers-style VAE."""
|
||||
if self._config_looks_like_sdxl():
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif self._name_looks_like_sdxl():
|
||||
# but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
|
||||
# by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
def _config_looks_like_sdxl(self) -> bool:
|
||||
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
|
||||
config_file = self.folder_path / "config.json"
|
||||
if not config_file.exists():
|
||||
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
||||
with open(config_file, "r") as file:
|
||||
config = json.load(file)
|
||||
return (
|
||||
BaseModelType.StableDiffusionXL
|
||||
if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
|
||||
else BaseModelType.StableDiffusion1
|
||||
)
|
||||
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
|
||||
|
||||
def _name_looks_like_sdxl(self) -> bool:
|
||||
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
|
||||
|
||||
def _guess_name(self) -> str:
|
||||
name = self.folder_path.name
|
||||
if name == "vae":
|
||||
name = self.folder_path.parent.name
|
||||
return name
|
||||
|
||||
|
||||
class TextualInversionFolderProbe(FolderProbeBase):
|
||||
|
@ -168,7 +168,13 @@ class ModelSearch(ModelSearchBase):
|
||||
if any(
|
||||
[
|
||||
(path / x).exists()
|
||||
for x in ["config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"]
|
||||
for x in [
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
"learned_embeds.bin",
|
||||
"pytorch_lora_weights.bin",
|
||||
"image_encoder.txt",
|
||||
]
|
||||
]
|
||||
):
|
||||
self._scanned_dirs.add(path)
|
||||
|
@ -3,7 +3,13 @@ Initialization file for invokeai.backend.model_manager.storage
|
||||
"""
|
||||
import pathlib
|
||||
|
||||
from .base import DuplicateModelException, ModelConfigStore, UnknownModelException # noqa F401
|
||||
from .base import ( # noqa F401
|
||||
ConfigFileVersionMismatchException,
|
||||
DuplicateModelException,
|
||||
ModelConfigStore,
|
||||
UnknownModelException,
|
||||
)
|
||||
from .migrate import migrate_models_store # noqa F401
|
||||
from .sql import ModelConfigStoreSQL # noqa F401
|
||||
from .yaml import ModelConfigStoreYAML # noqa F401
|
||||
|
||||
|
@ -4,6 +4,7 @@ Abstract base class for storing and retrieving model configuration records.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Set, Union
|
||||
|
||||
from ..config import BaseModelType, ModelConfigBase, ModelType
|
||||
@ -24,6 +25,10 @@ class UnknownModelException(Exception):
|
||||
"""Raised on an attempt to fetch or delete a model with a nonexistent key."""
|
||||
|
||||
|
||||
class ConfigFileVersionMismatchException(Exception):
|
||||
"""Raised on an attempt to open a config with an incompatible version."""
|
||||
|
||||
|
||||
class ModelConfigStore(ABC):
|
||||
"""Abstract base class for storage and retrieval of model configs."""
|
||||
|
||||
@ -99,6 +104,16 @@ class ModelConfigStore(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_by_path(
|
||||
self,
|
||||
path: Union[str, Path],
|
||||
) -> Optional[ModelConfigBase]:
|
||||
"""
|
||||
Return the model having the indicated path.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_by_name(
|
||||
self,
|
||||
|
61
invokeai/backend/model_manager/storage/migrate.py
Normal file
61
invokeai/backend/model_manager/storage/migrate.py
Normal file
@ -0,0 +1,61 @@
|
||||
# Copyright (c) 2023 The InvokeAI Development Team
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .base import CONFIG_FILE_VERSION
|
||||
|
||||
|
||||
def migrate_models_store(config: InvokeAIAppConfig):
|
||||
# avoid circular import
|
||||
from invokeai.backend.model_manager import DuplicateModelException, InvalidModelException, ModelInstall
|
||||
from invokeai.backend.model_manager.storage import get_config_store
|
||||
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
logger = InvokeAILogger.getLogger()
|
||||
old_file: Path = app_config.model_conf_path
|
||||
new_file: Path = old_file.with_name("models3_2.yaml")
|
||||
|
||||
old_conf = OmegaConf.load(old_file)
|
||||
store = get_config_store(new_file)
|
||||
installer = ModelInstall(store=store)
|
||||
logger.info(f"Migrating old models file at {old_file} to new {CONFIG_FILE_VERSION} format")
|
||||
|
||||
for model_key, stanza in old_conf.items():
|
||||
if model_key == "__metadata__":
|
||||
assert (
|
||||
stanza["version"] == "3.0.0"
|
||||
), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
|
||||
continue
|
||||
|
||||
base_type, model_type, model_name = model_key.split("/")
|
||||
|
||||
try:
|
||||
path = app_config.models_path / stanza["path"]
|
||||
new_key = installer.register_path(path)
|
||||
except DuplicateModelException:
|
||||
# if model already installed, then we just update its info
|
||||
models = store.search_by_name(model_name=model_name, base_model=base_type, model_type=model_type)
|
||||
if len(models) != 1:
|
||||
continue
|
||||
new_key = models[0].key
|
||||
except Exception as excp:
|
||||
print(str(excp))
|
||||
|
||||
model_info = store.get_model(new_key)
|
||||
if vae := stanza.get("vae"):
|
||||
model_info.vae = (app_config.models_path / vae).as_posix()
|
||||
if model_config := stanza.get("config"):
|
||||
model_info.config = (app_config.root_path / model_config).as_posix()
|
||||
model_info.description = stanza.get("description")
|
||||
store.update_model(new_key, model_info)
|
||||
store.update_model(new_key, model_info)
|
||||
|
||||
logger.info(f"Original version of models config file saved as {str(old_file) + '.orig'}")
|
||||
shutil.move(old_file, str(old_file) + ".orig")
|
||||
shutil.move(new_file, old_file)
|
@ -477,3 +477,9 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
finally:
|
||||
self._lock.release()
|
||||
return results
|
||||
|
||||
def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:
|
||||
"""
|
||||
Return the model with the indicated path, or None..
|
||||
"""
|
||||
raise NotImplementedError("search_by_path not implemented in storage.sql")
|
||||
|
@ -50,7 +50,13 @@ from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
|
||||
from ..config import BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType
|
||||
from .base import CONFIG_FILE_VERSION, DuplicateModelException, ModelConfigStore, UnknownModelException
|
||||
from .base import (
|
||||
CONFIG_FILE_VERSION,
|
||||
ConfigFileVersionMismatchException,
|
||||
DuplicateModelException,
|
||||
ModelConfigStore,
|
||||
UnknownModelException,
|
||||
)
|
||||
|
||||
|
||||
class ModelConfigStoreYAML(ModelConfigStore):
|
||||
@ -68,9 +74,8 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
if not self._filename.exists():
|
||||
self._initialize_yaml()
|
||||
self._config = OmegaConf.load(self._filename)
|
||||
assert (
|
||||
str(self.version) == CONFIG_FILE_VERSION
|
||||
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
|
||||
if str(self.version) != CONFIG_FILE_VERSION:
|
||||
raise ConfigFileVersionMismatchException
|
||||
|
||||
def _initialize_yaml(self):
|
||||
try:
|
||||
@ -239,3 +244,67 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
finally:
|
||||
self._lock.release()
|
||||
return results
|
||||
|
||||
def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:
|
||||
"""
|
||||
Return the model with the indicated path, or None..
|
||||
"""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
for key, record in self._config.items():
|
||||
if key == "__metadata__":
|
||||
continue
|
||||
model = ModelConfigFactory.make_config(record, key)
|
||||
if model.path == path:
|
||||
return model
|
||||
finally:
|
||||
self._lock.release()
|
||||
return None
|
||||
|
||||
def _load_and_maybe_upgrade(self, config_path: Path) -> DictConfig:
|
||||
config = OmegaConf.load(config_path)
|
||||
version = config["__metadata__"].get("version")
|
||||
if version == CONFIG_FILE_VERSION:
|
||||
return config
|
||||
|
||||
# if we get here we need to upgrade
|
||||
if version == "3.0.0":
|
||||
return self._migrate_format_to_3_2(config, config_path)
|
||||
else:
|
||||
raise Exception(f"{config_path} has unknown version: {version}")
|
||||
|
||||
def _migrate_format_to_3_2(self, old_config: DictConfig, config_path: Path) -> DictConfig:
|
||||
print(
|
||||
f"** Doing one-time conversion of {config_path.as_posix()} to new format. Original will be named {config_path.as_posix() + '.orig'}"
|
||||
)
|
||||
|
||||
# avoid circular dependencies
|
||||
from shutil import move
|
||||
|
||||
from ..install import InvalidModelException, ModelInstall
|
||||
|
||||
move(config_path, config_path.as_posix() + ".orig")
|
||||
|
||||
new_store = self.__class__(config_path)
|
||||
installer = ModelInstall(store=new_store)
|
||||
|
||||
for model_key, stanza in old_config.items():
|
||||
if model_key == "__metadata__":
|
||||
assert (
|
||||
stanza["version"] == "3.0.0"
|
||||
), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
|
||||
continue
|
||||
|
||||
try:
|
||||
path = stanza["path"]
|
||||
new_key = installer.register_path(path)
|
||||
model_info = new_store.get_model(new_key)
|
||||
if vae := stanza.get("vae"):
|
||||
model_info.vae = vae
|
||||
if model_config := stanza.get("config"):
|
||||
model_info.config = model_config.as_posix()
|
||||
model_info.description = stanza.get("description")
|
||||
new_store.update_model(new_key, model_info)
|
||||
return OmegaConf.load(config_path)
|
||||
except (DuplicateModelException, InvalidModelException) as e:
|
||||
print(str(e))
|
||||
|
@ -14,11 +14,8 @@ when new models are downloaded from HuggingFace or Civitae.
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager import DuplicateModelException, InvalidModelException, ModelInstall
|
||||
from invokeai.backend.model_manager.storage import get_config_store
|
||||
from invokeai.backend.model_manager.storage import migrate_models_store
|
||||
|
||||
|
||||
def main():
|
||||
@ -35,34 +32,7 @@ def main():
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args(config_args)
|
||||
old_yaml_file = OmegaConf.load(config.model_conf_path)
|
||||
|
||||
store = get_config_store(args.outfile)
|
||||
installer = ModelInstall(store=store)
|
||||
|
||||
print(f"Writing 3.2 models configuration into {args.outfile}.")
|
||||
|
||||
for model_key, stanza in old_yaml_file.items():
|
||||
if model_key == "__metadata__":
|
||||
assert (
|
||||
stanza["version"] == "3.0.0"
|
||||
), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
|
||||
continue
|
||||
|
||||
try:
|
||||
path = config.models_path / stanza["path"]
|
||||
new_key = installer.register_path(path)
|
||||
model_info = store.get_model(new_key)
|
||||
if vae := stanza.get("vae"):
|
||||
model_info.vae = (config.models_path / vae).as_posix()
|
||||
if model_config := stanza.get("config"):
|
||||
model_info.config = (config.root_path / model_config).as_posix()
|
||||
model_info.description = stanza.get("description")
|
||||
store.update_model(new_key, model_info)
|
||||
|
||||
print(f"{model_key} => {new_key}")
|
||||
except (DuplicateModelException, InvalidModelException) as e:
|
||||
print(str(e))
|
||||
migrate_models_store(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user