automatically convert models.yaml to new format

Lincoln Stein
2023-09-23 17:00:53 -04:00
parent ab58eb29c5
commit 6edee2d22b
13 changed files with 201 additions and 836 deletions


@@ -0,0 +1 @@
The contents of this directory are deprecated. model_manager.py is here only for reference.


@@ -1,605 +0,0 @@
import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, Literal, Optional, Union

import safetensors.torch
import torch
from diffusers import ConfigMixin, ModelMixin
from picklescan.scanner import scan_file_path

from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat

from .models import (
    BaseModelType,
    InvalidModelException,
    ModelType,
    ModelVariantType,
    SchedulerPredictionType,
    SilenceWarnings,
)
from .models.base import read_checkpoint_meta
from .util import lora_token_vector_length


@dataclass
class ModelProbeInfo(object):
    model_type: ModelType
    base_type: BaseModelType
    variant_type: ModelVariantType
    prediction_type: SchedulerPredictionType
    upcast_attention: bool
    format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
    image_size: int


class ProbeBase(object):
    """forward declaration"""

    pass


class ModelProbe(object):
    PROBES = {
        "diffusers": {},
        "checkpoint": {},
        "onnx": {},
    }

    CLASS2TYPE = {
        "StableDiffusionPipeline": ModelType.Main,
        "StableDiffusionInpaintPipeline": ModelType.Main,
        "StableDiffusionXLPipeline": ModelType.Main,
        "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
        "StableDiffusionXLInpaintPipeline": ModelType.Main,
        "AutoencoderKL": ModelType.Vae,
        "AutoencoderTiny": ModelType.Vae,
        "ControlNetModel": ModelType.ControlNet,
        "CLIPVisionModelWithProjection": ModelType.CLIPVision,
    }

    @classmethod
    def register_probe(
        cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase
    ):
        cls.PROBES[format][model_type] = probe_class

    @classmethod
    def heuristic_probe(
        cls,
        model: Union[Dict, ModelMixin, Path],
        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
    ) -> ModelProbeInfo:
        if isinstance(model, Path):
            return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper)
        elif isinstance(model, (dict, ModelMixin, ConfigMixin)):
            return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
        else:
            raise InvalidModelException(f"model parameter {model} is neither a Path nor a model")

    @classmethod
    def probe(
        cls,
        model_path: Path,
        model: Optional[Union[Dict, ModelMixin]] = None,
        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
    ) -> Optional[ModelProbeInfo]:
        """
        Probe the model at model_path and return sufficient information about it
        to place it somewhere in the models directory hierarchy. If the model is
        already loaded into memory, you may provide it as model in order to avoid
        opening it a second time. The prediction_type_helper callable is a function that receives
        the path to the model and returns the SchedulerPredictionType. It is called to distinguish
        between V2-Base and V2-768 SD models.
        """
        if model_path:
            format_type = "diffusers" if model_path.is_dir() else "checkpoint"
        else:
            format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint"

        model_info = None
        try:
            model_type = (
                cls.get_model_type_from_folder(model_path, model)
                if format_type == "diffusers"
                else cls.get_model_type_from_checkpoint(model_path, model)
            )
            format_type = "onnx" if model_type == ModelType.ONNX else format_type
            probe_class = cls.PROBES[format_type].get(model_type)
            if not probe_class:
                return None
            probe = probe_class(model_path, model, prediction_type_helper)
            base_type = probe.get_base_type()
            variant_type = probe.get_variant_type()
            prediction_type = probe.get_scheduler_prediction_type()
            format = probe.get_format()
            model_info = ModelProbeInfo(
                model_type=model_type,
                base_type=base_type,
                variant_type=variant_type,
                prediction_type=prediction_type,
                upcast_attention=(
                    base_type == BaseModelType.StableDiffusion2
                    and prediction_type == SchedulerPredictionType.VPrediction
                ),
                format=format,
                image_size=(
                    1024
                    if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
                    else (
                        768
                        if (
                            base_type == BaseModelType.StableDiffusion2
                            and prediction_type == SchedulerPredictionType.VPrediction
                        )
                        else 512
                    )
                ),
            )
        except Exception:
            raise

        return model_info

    @classmethod
    def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> Optional[ModelType]:
        if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
            return None

        if model_path.name == "learned_embeds.bin":
            return ModelType.TextualInversion

        ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
        ckpt = ckpt.get("state_dict", ckpt)

        for key in ckpt.keys():
            if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
                return ModelType.Main
            elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
                return ModelType.Vae
            elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
                return ModelType.Lora
            elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
                return ModelType.Lora
            elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
                return ModelType.ControlNet
            elif key in {"emb_params", "string_to_param"}:
                return ModelType.TextualInversion
        else:
            # diffusers-ti
            if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
                return ModelType.TextualInversion

        raise InvalidModelException(f"Unable to determine model type for {model_path}")

    @classmethod
    def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType:
        """
        Get the model type of a hugging-face style folder.
        """
        class_name = None
        error_hint = None
        if model:
            class_name = model.__class__.__name__
        else:
            if (folder_path / "unet/model.onnx").exists():
                return ModelType.ONNX
            if (folder_path / "learned_embeds.bin").exists():
                return ModelType.TextualInversion
            if (folder_path / "pytorch_lora_weights.bin").exists():
                return ModelType.Lora
            if (folder_path / "image_encoder.txt").exists():
                return ModelType.IPAdapter

            i = folder_path / "model_index.json"
            c = folder_path / "config.json"
            config_path = i if i.exists() else c if c.exists() else None

            if config_path:
                with open(config_path, "r") as file:
                    conf = json.load(file)
                if "_class_name" in conf:
                    class_name = conf["_class_name"]
                elif "architectures" in conf:
                    class_name = conf["architectures"][0]
                else:
                    class_name = None
            else:
                error_hint = f"No model_index.json or config.json found in {folder_path}."

        if class_name and (type := cls.CLASS2TYPE.get(class_name)):
            return type
        else:
            error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"

        # give up
        raise InvalidModelException(
            f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
        )

    @classmethod
    def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
        with SilenceWarnings():
            if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
                cls._scan_model(model_path, model_path)
                return torch.load(model_path)
            else:
                return safetensors.torch.load_file(model_path)

    @classmethod
    def _scan_model(cls, model_name, checkpoint):
        """
        Apply picklescanner to the indicated checkpoint and issue a warning
        and option to exit if an infected file is identified.
        """
        # scan model
        scan_result = scan_file_path(checkpoint)
        if scan_result.infected_files != 0:
            raise InvalidModelException(f"The model {model_name} is potentially infected by malware. Aborting import.")


# ##################################################
# Checkpoint probing
# ##################################################


class ProbeBase(object):
    def get_base_type(self) -> BaseModelType:
        pass

    def get_variant_type(self) -> ModelVariantType:
        pass

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        pass

    def get_format(self) -> str:
        pass


class CheckpointProbeBase(ProbeBase):
    def __init__(
        self,
        checkpoint_path: Path,
        checkpoint: dict,
        helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
    ) -> None:
        self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
        self.checkpoint_path = checkpoint_path
        self.helper = helper

    def get_base_type(self) -> BaseModelType:
        pass

    def get_format(self) -> str:
        return "checkpoint"

    def get_variant_type(self) -> ModelVariantType:
        model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint)
        if model_type != ModelType.Main:
            return ModelVariantType.Normal
        state_dict = self.checkpoint.get("state_dict") or self.checkpoint
        in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
        if in_channels == 9:
            return ModelVariantType.Inpaint
        elif in_channels == 5:
            return ModelVariantType.Depth
        elif in_channels == 4:
            return ModelVariantType.Normal
        else:
            raise InvalidModelException(
                f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}"
            )


class PipelineCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        state_dict = self.checkpoint.get("state_dict") or checkpoint
        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
            return BaseModelType.StableDiffusion1
        if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
            return BaseModelType.StableDiffusion2
        key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
            return BaseModelType.StableDiffusionXL
        elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
            return BaseModelType.StableDiffusionXLRefiner
        else:
            raise InvalidModelException("Cannot determine base type")

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        type = self.get_base_type()
        if type == BaseModelType.StableDiffusion1:
            return SchedulerPredictionType.Epsilon
        checkpoint = self.checkpoint
        state_dict = self.checkpoint.get("state_dict") or checkpoint
        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
            if "global_step" in checkpoint:
                if checkpoint["global_step"] == 220000:
                    return SchedulerPredictionType.Epsilon
                elif checkpoint["global_step"] == 110000:
                    return SchedulerPredictionType.VPrediction
        if (
            self.checkpoint_path and self.helper and not self.checkpoint_path.with_suffix(".yaml").exists()
        ):  # if a .yaml config file exists, then this step is not needed
            return self.helper(self.checkpoint_path)
        else:
            return None


class VaeCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        # I can't find any standalone 2.X VAEs to test with!
        return BaseModelType.StableDiffusion1


class LoRACheckpointProbe(CheckpointProbeBase):
    def get_format(self) -> str:
        return "lycoris"

    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        token_vector_length = lora_token_vector_length(checkpoint)
        if token_vector_length == 768:
            return BaseModelType.StableDiffusion1
        elif token_vector_length == 1024:
            return BaseModelType.StableDiffusion2
        elif token_vector_length == 2048:
            return BaseModelType.StableDiffusionXL
        else:
            raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}")


class TextualInversionCheckpointProbe(CheckpointProbeBase):
    def get_format(self) -> Optional[str]:
        return None

    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        if "string_to_token" in checkpoint:
            # classic TI .pt embeddings carry both "string_to_token" and "string_to_param"
            token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
        elif "emb_params" in checkpoint:
            token_dim = checkpoint["emb_params"].shape[-1]
        else:
            token_dim = list(checkpoint.values())[0].shape[0]
        if token_dim == 768:
            return BaseModelType.StableDiffusion1
        elif token_dim == 1024:
            return BaseModelType.StableDiffusion2
        else:
            return None


class ControlNetCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        for key_name in (
            "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
            "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
        ):
            if key_name not in checkpoint:
                continue
            if checkpoint[key_name].shape[-1] == 768:
                return BaseModelType.StableDiffusion1
            elif checkpoint[key_name].shape[-1] == 1024:
                return BaseModelType.StableDiffusion2
            elif self.checkpoint_path and self.helper:
                return self.helper(self.checkpoint_path)
        raise InvalidModelException(f"Unable to determine base type for {self.checkpoint_path}")


class IPAdapterCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()


class CLIPVisionCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()


########################################################
# classes for probing folders
#######################################################


class FolderProbeBase(ProbeBase):
    def __init__(self, folder_path: Path, model: ModelMixin = None, helper: Callable = None):  # helper not used
        self.model = model
        self.folder_path = folder_path

    def get_variant_type(self) -> ModelVariantType:
        return ModelVariantType.Normal

    def get_format(self) -> str:
        return "diffusers"


class PipelineFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        if self.model:
            unet_conf = self.model.unet.config
        else:
            with open(self.folder_path / "unet" / "config.json", "r") as file:
                unet_conf = json.load(file)
        if unet_conf["cross_attention_dim"] == 768:
            return BaseModelType.StableDiffusion1
        elif unet_conf["cross_attention_dim"] == 1024:
            return BaseModelType.StableDiffusion2
        elif unet_conf["cross_attention_dim"] == 1280:
            return BaseModelType.StableDiffusionXLRefiner
        elif unet_conf["cross_attention_dim"] == 2048:
            return BaseModelType.StableDiffusionXL
        else:
            raise InvalidModelException(f"Unknown base model for {self.folder_path}")

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        if self.model:
            scheduler_conf = self.model.scheduler.config
        else:
            with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file:
                scheduler_conf = json.load(file)
        if scheduler_conf["prediction_type"] == "v_prediction":
            return SchedulerPredictionType.VPrediction
        elif scheduler_conf["prediction_type"] == "epsilon":
            return SchedulerPredictionType.Epsilon
        else:
            return None

    def get_variant_type(self) -> ModelVariantType:
        # This only works for pipelines! Any kind of
        # exception results in our returning the
        # "normal" variant type
        try:
            if self.model:
                conf = self.model.unet.config
            else:
                config_file = self.folder_path / "unet" / "config.json"
                with open(config_file, "r") as file:
                    conf = json.load(file)
            in_channels = conf["in_channels"]
            if in_channels == 9:
                return ModelVariantType.Inpaint
            elif in_channels == 5:
                return ModelVariantType.Depth
            elif in_channels == 4:
                return ModelVariantType.Normal
        except Exception:
            pass
        return ModelVariantType.Normal


class VaeFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        if self._config_looks_like_sdxl():
            return BaseModelType.StableDiffusionXL
        elif self._name_looks_like_sdxl():
            # SD and SDXL VAEs have the same shape (3-channel RGB to 4-channel float scaled down
            # by a factor of 8), so we can't necessarily tell them apart by config hyperparameters.
            return BaseModelType.StableDiffusionXL
        else:
            return BaseModelType.StableDiffusion1

    def _config_looks_like_sdxl(self) -> bool:
        # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
        config_file = self.folder_path / "config.json"
        if not config_file.exists():
            raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
        with open(config_file, "r") as file:
            config = json.load(file)
        return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]

    def _name_looks_like_sdxl(self) -> bool:
        return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))

    def _guess_name(self) -> str:
        name = self.folder_path.name
        if name == "vae":
            name = self.folder_path.parent.name
        return name


class TextualInversionFolderProbe(FolderProbeBase):
    def get_format(self) -> Optional[str]:
        return None

    def get_base_type(self) -> BaseModelType:
        path = self.folder_path / "learned_embeds.bin"
        if not path.exists():
            return None
        checkpoint = ModelProbe._scan_and_load_checkpoint(path)
        return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()


class ONNXFolderProbe(FolderProbeBase):
    def get_format(self) -> str:
        return "onnx"

    def get_base_type(self) -> BaseModelType:
        return BaseModelType.StableDiffusion1

    def get_variant_type(self) -> ModelVariantType:
        return ModelVariantType.Normal


class ControlNetFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        config_file = self.folder_path / "config.json"
        if not config_file.exists():
            raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
        with open(config_file, "r") as file:
            config = json.load(file)
        # no obvious way to distinguish between sd2-base and sd2-768
        dimension = config["cross_attention_dim"]
        base_model = (
            BaseModelType.StableDiffusion1
            if dimension == 768
            else (
                BaseModelType.StableDiffusion2
                if dimension == 1024
                else BaseModelType.StableDiffusionXL
                if dimension == 2048
                else None
            )
        )
        if not base_model:
            raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
        return base_model


class LoRAFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        model_file = None
        for suffix in ["safetensors", "bin"]:
            base_file = self.folder_path / f"pytorch_lora_weights.{suffix}"
            if base_file.exists():
                model_file = base_file
                break
        if not model_file:
            raise InvalidModelException("Unknown LoRA format encountered")
        return LoRACheckpointProbe(model_file, None).get_base_type()


class IPAdapterFolderProbe(FolderProbeBase):
    def get_format(self) -> str:
        return IPAdapterModelFormat.InvokeAI.value

    def get_base_type(self) -> BaseModelType:
        model_file = self.folder_path / "ip_adapter.bin"
        if not model_file.exists():
            raise InvalidModelException("Unknown IP-Adapter model format.")
        state_dict = torch.load(model_file, map_location="cpu")
        cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
        if cross_attention_dim == 768:
            return BaseModelType.StableDiffusion1
        elif cross_attention_dim == 1024:
            return BaseModelType.StableDiffusion2
        elif cross_attention_dim == 2048:
            return BaseModelType.StableDiffusionXL
        else:
            raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")


class CLIPVisionFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        return BaseModelType.Any


############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)


@@ -1,108 +0,0 @@
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
"""
Abstract base class for recursive directory search for models.
"""

import os
import types
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Set

import invokeai.backend.util.logging as logger


class ModelSearch(ABC):
    def __init__(self, directories: List[Path], logger: types.ModuleType = logger):
        """
        Initialize a recursive model directory search.
        :param directories: List of directory Paths to recurse through
        :param logger: Logger to use
        """
        self.directories = directories
        self.logger = logger
        self._items_scanned = 0
        self._models_found = 0
        self._scanned_dirs = set()
        self._scanned_paths = set()
        self._pruned_paths = set()

    @abstractmethod
    def on_search_started(self):
        """
        Called before the scan starts.
        """
        pass

    @abstractmethod
    def on_model_found(self, model: Path):
        """
        Process a found model. Raise an exception if something goes wrong.
        :param model: Model to process - could be a directory or checkpoint.
        """
        pass

    @abstractmethod
    def on_search_completed(self):
        """
        Perform some activity when the scan is completed. May use the instance
        variables _items_scanned and _models_found.
        """
        pass

    def search(self):
        self.on_search_started()
        for dir in self.directories:
            self.walk_directory(dir)
        self.on_search_completed()

    def walk_directory(self, path: Path):
        for root, dirs, files in os.walk(path, followlinks=True):
            if str(Path(root).name).startswith("."):
                self._pruned_paths.add(root)
            if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
                continue

            self._items_scanned += len(dirs) + len(files)
            for d in dirs:
                path = Path(root) / d
                if path in self._scanned_paths or path.parent in self._scanned_dirs:
                    self._scanned_dirs.add(path)
                    continue
                if any(
                    [
                        (path / x).exists()
                        for x in ["config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"]
                    ]
                ):
                    try:
                        self.on_model_found(path)
                        self._models_found += 1
                        self._scanned_dirs.add(path)
                    except Exception as e:
                        self.logger.warning(f"Failed to process '{path}': {e}")

            for f in files:
                path = Path(root) / f
                if path.parent in self._scanned_dirs:
                    continue
                if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
                    try:
                        self.on_model_found(path)
                        self._models_found += 1
                    except Exception as e:
                        self.logger.warning(f"Failed to process '{path}': {e}")


class FindModels(ModelSearch):
    def on_search_started(self):
        self.models_found: Set[Path] = set()

    def on_model_found(self, model: Path):
        self.models_found.add(model)

    def on_search_completed(self):
        pass

    def list_models(self) -> List[Path]:
        self.search()
        return list(self.models_found)
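
A rough sketch of driving the search API above through its FindModels subclass; the models root below is illustrative:

from pathlib import Path

# Recurse through one or more roots and collect everything that looks like a
# model, whether a diffusers-style directory or a single checkpoint file.
finder = FindModels([Path("/opt/invokeai/models")])
for model_path in finder.list_models():
    print(model_path)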


@@ -1,75 +0,0 @@
# Copyright (c) 2023 The InvokeAI Development Team
"""Utilities used by the Model Manager"""


def lora_token_vector_length(checkpoint: dict) -> int:
    """
    Given a checkpoint in memory, return the lora token vector length.

    :param checkpoint: The checkpoint
    """

    def _get_shape_1(key, tensor, checkpoint):
        lora_token_vector_length = None

        if "." not in key:
            return lora_token_vector_length  # wrong key format
        model_key, lora_key = key.split(".", 1)

        # check lora/locon
        if lora_key == "lora_down.weight":
            lora_token_vector_length = tensor.shape[1]

        # check loha (don't worry about hada_t1/hada_t2 as they are used only in 4d shapes)
        elif lora_key in ["hada_w1_b", "hada_w2_b"]:
            lora_token_vector_length = tensor.shape[1]

        # check lokr (don't worry about lokr_t2 as it is used only in 4d shapes)
        elif "lokr_" in lora_key:
            if model_key + ".lokr_w1" in checkpoint:
                _lokr_w1 = checkpoint[model_key + ".lokr_w1"]
            elif model_key + ".lokr_w1_b" in checkpoint:
                _lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
            else:
                return lora_token_vector_length  # unknown format

            if model_key + ".lokr_w2" in checkpoint:
                _lokr_w2 = checkpoint[model_key + ".lokr_w2"]
            elif model_key + ".lokr_w2_b" in checkpoint:
                _lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
            else:
                return lora_token_vector_length  # unknown format

            lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]

        elif lora_key == "diff":
            lora_token_vector_length = tensor.shape[1]

        # ia3 can be detected only by shape[0] in text encoder
        elif lora_key == "weight" and "lora_unet_" not in model_key:
            lora_token_vector_length = tensor.shape[0]

        return lora_token_vector_length

    lora_token_vector_length = None
    lora_te1_length = None
    lora_te2_length = None
    for key, tensor in checkpoint.items():
        if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
            lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
        elif key.startswith("lora_te") and "_self_attn_" in key:
            tmp_length = _get_shape_1(key, tensor, checkpoint)
            if key.startswith("lora_te_"):
                lora_token_vector_length = tmp_length
            elif key.startswith("lora_te1_"):
                lora_te1_length = tmp_length
            elif key.startswith("lora_te2_"):
                lora_te2_length = tmp_length

        if lora_te1_length is not None and lora_te2_length is not None:
            lora_token_vector_length = lora_te1_length + lora_te2_length

        if lora_token_vector_length is not None:
            break

    return lora_token_vector_length
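
To make the dispatch above concrete, a toy example with a fabricated single-key LoRA state dict; a lora_down.weight of shape (rank, 768) maps to the SD-1.x text-encoder width:

import torch

fake_lora = {
    # the key starts with "lora_unet_" and contains "_attn2_to_k.", so
    # _get_shape_1 reads shape[1] of the down-projection: 768
    "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn2_to_k.lora_down.weight": torch.zeros(4, 768),
}
assert lora_token_vector_length(fake_lora) == 768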


@@ -17,7 +17,7 @@ from .config import BaseModelType, ModelConfigBase, ModelType, SubModelType
 from .download import DownloadEventHandler
 from .install import ModelInstall, ModelInstallBase
 from .models import MODEL_CLASSES, InvalidModelException, ModelBase
-from .storage import ModelConfigStore, get_config_store
+from .storage import ConfigFileVersionMismatchException, ModelConfigStore, get_config_store, migrate_models_store


 @dataclass
@@ -138,7 +138,12 @@ class ModelLoad(ModelLoadBase):
             models_file = config.model_conf_path
         else:
             models_file = config.root_path / "configs/models3.yaml"
-        store = get_config_store(models_file)
+        try:
+            store = get_config_store(models_file)
+        except ConfigFileVersionMismatchException:
+            migrate_models_store(config)
+            store = get_config_store(models_file)
         if not store:
             raise ValueError(f"Invalid model configuration file: {models_file}")


@@ -7,6 +7,7 @@ its base type, model type, format and variant.
 """
 import json
+import re
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Callable, Optional
@@ -493,20 +494,33 @@ class PipelineFolderProbe(FolderProbeBase):
 class VaeFolderProbe(FolderProbeBase):
-    """Probe a diffusers-style VAE model."""
-
     def get_base_type(self) -> BaseModelType:
-        """Return the BaseModelType for a diffusers-style VAE."""
+        if self._config_looks_like_sdxl():
+            return BaseModelType.StableDiffusionXL
+        elif self._name_looks_like_sdxl():
+            # SD and SDXL VAEs have the same shape (3-channel RGB to 4-channel float scaled down
+            # by a factor of 8), so we can't necessarily tell them apart by config hyperparameters.
+            return BaseModelType.StableDiffusionXL
+        else:
+            return BaseModelType.StableDiffusion1
+
+    def _config_looks_like_sdxl(self) -> bool:
+        # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
         config_file = self.folder_path / "config.json"
         if not config_file.exists():
             raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
         with open(config_file, "r") as file:
             config = json.load(file)
-        return (
-            BaseModelType.StableDiffusionXL
-            if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
-            else BaseModelType.StableDiffusion1
-        )
+        return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
+
+    def _name_looks_like_sdxl(self) -> bool:
+        return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
+
+    def _guess_name(self) -> str:
+        name = self.folder_path.name
+        if name == "vae":
+            name = self.folder_path.parent.name
+        return name


 class TextualInversionFolderProbe(FolderProbeBase):


@@ -168,7 +168,13 @@ class ModelSearch(ModelSearchBase):
             if any(
                 [
                     (path / x).exists()
-                    for x in ["config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"]
+                    for x in [
+                        "config.json",
+                        "model_index.json",
+                        "learned_embeds.bin",
+                        "pytorch_lora_weights.bin",
+                        "image_encoder.txt",
+                    ]
                 ]
             ):
                 self._scanned_dirs.add(path)


@@ -3,7 +3,13 @@ Initialization file for invokeai.backend.model_manager.storage
 """
 import pathlib

-from .base import DuplicateModelException, ModelConfigStore, UnknownModelException  # noqa F401
+from .base import (  # noqa F401
+    ConfigFileVersionMismatchException,
+    DuplicateModelException,
+    ModelConfigStore,
+    UnknownModelException,
+)
+from .migrate import migrate_models_store  # noqa F401
 from .sql import ModelConfigStoreSQL  # noqa F401
 from .yaml import ModelConfigStoreYAML  # noqa F401


@@ -4,6 +4,7 @@ Abstract base class for storing and retrieving model configuration records.
 """
 from abc import ABC, abstractmethod
+from pathlib import Path
 from typing import List, Optional, Set, Union

 from ..config import BaseModelType, ModelConfigBase, ModelType
@@ -24,6 +25,10 @@ class UnknownModelException(Exception):
     """Raised on an attempt to fetch or delete a model with a nonexistent key."""


+class ConfigFileVersionMismatchException(Exception):
+    """Raised on an attempt to open a config with an incompatible version."""
+
+
 class ModelConfigStore(ABC):
     """Abstract base class for storage and retrieval of model configs."""

@@ -99,6 +104,16 @@ class ModelConfigStore(ABC):
         """
         pass

+    @abstractmethod
+    def search_by_path(
+        self,
+        path: Union[str, Path],
+    ) -> Optional[ModelConfigBase]:
+        """
+        Return the model having the indicated path.
+        """
+        pass
+
     @abstractmethod
     def search_by_name(
         self,


@@ -0,0 +1,61 @@
# Copyright (c) 2023 The InvokeAI Development Team

import shutil
from pathlib import Path

from omegaconf import OmegaConf

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger

from .base import CONFIG_FILE_VERSION


def migrate_models_store(config: InvokeAIAppConfig):
    # avoid circular import
    from invokeai.backend.model_manager import DuplicateModelException, InvalidModelException, ModelInstall
    from invokeai.backend.model_manager.storage import get_config_store

    app_config = InvokeAIAppConfig.get_config()
    logger = InvokeAILogger.getLogger()
    old_file: Path = app_config.model_conf_path
    new_file: Path = old_file.with_name("models3_2.yaml")

    old_conf = OmegaConf.load(old_file)
    store = get_config_store(new_file)
    installer = ModelInstall(store=store)
    logger.info(f"Migrating old models file at {old_file} to new {CONFIG_FILE_VERSION} format")

    for model_key, stanza in old_conf.items():
        if model_key == "__metadata__":
            assert (
                stanza["version"] == "3.0.0"
            ), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
            continue

        base_type, model_type, model_name = model_key.split("/")
        try:
            path = app_config.models_path / stanza["path"]
            new_key = installer.register_path(path)
        except DuplicateModelException:
            # if the model is already installed, then we just update its info
            models = store.search_by_name(model_name=model_name, base_model=base_type, model_type=model_type)
            if len(models) != 1:
                continue
            new_key = models[0].key
        except Exception as excp:
            print(str(excp))
            continue  # skip models that cannot be registered

        model_info = store.get_model(new_key)
        if vae := stanza.get("vae"):
            model_info.vae = (app_config.models_path / vae).as_posix()
        if model_config := stanza.get("config"):
            model_info.config = (app_config.root_path / model_config).as_posix()
        model_info.description = stanza.get("description")
        store.update_model(new_key, model_info)

    logger.info(f"Original version of models config file saved as {str(old_file) + '.orig'}")
    shutil.move(old_file, str(old_file) + ".orig")
    shutil.move(new_file, old_file)
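
For reference, a hedged sketch of the version-3.0.0 input this migration walks: stanzas keyed as base/type/name. The model name and paths below are invented for illustration:

from omegaconf import OmegaConf

old_conf = OmegaConf.create(
    """\
__metadata__:
  version: 3.0.0
sd-1/main/stable-diffusion-v1-5:
  path: sd-1/main/stable-diffusion-v1-5
  description: Stable Diffusion version 1.5
  format: diffusers
"""
)
for model_key, stanza in old_conf.items():
    if model_key == "__metadata__":
        continue
    base_type, model_type, model_name = model_key.split("/")
    print(f"{base_type}/{model_type}/{model_name} -> {stanza['path']}")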


@@ -477,3 +477,9 @@ class ModelConfigStoreSQL(ModelConfigStore):
         finally:
             self._lock.release()
         return results
+
+    def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:
+        """
+        Return the model with the indicated path, or None.
+        """
+        raise NotImplementedError("search_by_path not implemented in storage.sql")


@@ -50,7 +50,13 @@ from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig

 from ..config import BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType
-from .base import CONFIG_FILE_VERSION, DuplicateModelException, ModelConfigStore, UnknownModelException
+from .base import (
+    CONFIG_FILE_VERSION,
+    ConfigFileVersionMismatchException,
+    DuplicateModelException,
+    ModelConfigStore,
+    UnknownModelException,
+)


 class ModelConfigStoreYAML(ModelConfigStore):
@@ -68,9 +74,8 @@ class ModelConfigStoreYAML(ModelConfigStore):
         if not self._filename.exists():
             self._initialize_yaml()
         self._config = OmegaConf.load(self._filename)
-        assert (
-            str(self.version) == CONFIG_FILE_VERSION
-        ), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
+        if str(self.version) != CONFIG_FILE_VERSION:
+            raise ConfigFileVersionMismatchException

     def _initialize_yaml(self):
         try:
@@ -239,3 +244,67 @@ class ModelConfigStoreYAML(ModelConfigStore):
         finally:
             self._lock.release()
         return results
+
+    def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:
+        """
+        Return the model with the indicated path, or None.
+        """
+        try:
+            self._lock.acquire()
+            for key, record in self._config.items():
+                if key == "__metadata__":
+                    continue
+                model = ModelConfigFactory.make_config(record, key)
+                if model.path == path:
+                    return model
+        finally:
+            self._lock.release()
+        return None
+
+    def _load_and_maybe_upgrade(self, config_path: Path) -> DictConfig:
+        config = OmegaConf.load(config_path)
+        version = config["__metadata__"].get("version")
+        if version == CONFIG_FILE_VERSION:
+            return config
+        # if we get here we need to upgrade
+        if version == "3.0.0":
+            return self._migrate_format_to_3_2(config, config_path)
+        else:
+            raise Exception(f"{config_path} has unknown version: {version}")
+
+    def _migrate_format_to_3_2(self, old_config: DictConfig, config_path: Path) -> DictConfig:
+        print(
+            f"** Doing one-time conversion of {config_path.as_posix()} to new format. Original will be named {config_path.as_posix() + '.orig'}"
+        )
+        # avoid circular dependencies
+        from shutil import move
+
+        from ..install import InvalidModelException, ModelInstall
+
+        move(config_path, config_path.as_posix() + ".orig")
+        new_store = self.__class__(config_path)
+        installer = ModelInstall(store=new_store)
+        for model_key, stanza in old_config.items():
+            if model_key == "__metadata__":
+                assert (
+                    stanza["version"] == "3.0.0"
+                ), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
+                continue
+            try:
+                path = stanza["path"]
+                new_key = installer.register_path(path)
+                model_info = new_store.get_model(new_key)
+                if vae := stanza.get("vae"):
+                    model_info.vae = vae
+                if model_config := stanza.get("config"):
+                    model_info.config = model_config.as_posix()
+                model_info.description = stanza.get("description")
+                new_store.update_model(new_key, model_info)
+            except (DuplicateModelException, InvalidModelException) as e:
+                print(str(e))
+        return OmegaConf.load(config_path)
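
A small usage sketch of the new lookup; the store path and model path are illustrative, and the constructor signature is assumed from the self.__class__(config_path) call in _migrate_format_to_3_2 above:

# Open a store on an up-to-date config file and look a model up by path.
store = ModelConfigStoreYAML("configs/models3.yaml")
model = store.search_by_path("sd-1/main/stable-diffusion-v1-5")
if model is not None:
    print(model.key, model.description)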


@@ -14,11 +14,8 @@ when new models are downloaded from HuggingFace or Civitai.
 import argparse
 from pathlib import Path

-from omegaconf import OmegaConf
-
 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.model_manager import DuplicateModelException, InvalidModelException, ModelInstall
-from invokeai.backend.model_manager.storage import get_config_store
+from invokeai.backend.model_manager.storage import migrate_models_store


 def main():
@@ -35,34 +32,7 @@ def main():
     config = InvokeAIAppConfig.get_config()
     config.parse_args(config_args)

-    old_yaml_file = OmegaConf.load(config.model_conf_path)
-    store = get_config_store(args.outfile)
-    installer = ModelInstall(store=store)
-    print(f"Writing 3.2 models configuration into {args.outfile}.")
-    for model_key, stanza in old_yaml_file.items():
-        if model_key == "__metadata__":
-            assert (
-                stanza["version"] == "3.0.0"
-            ), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
-            continue
-        try:
-            path = config.models_path / stanza["path"]
-            new_key = installer.register_path(path)
-            model_info = store.get_model(new_key)
-            if vae := stanza.get("vae"):
-                model_info.vae = (config.models_path / vae).as_posix()
-            if model_config := stanza.get("config"):
-                model_info.config = (config.root_path / model_config).as_posix()
-            model_info.description = stanza.get("description")
-            store.update_model(new_key, model_info)
-            print(f"{model_key} => {new_key}")
-        except (DuplicateModelException, InvalidModelException) as e:
-            print(str(e))
+    migrate_models_store(config)


 if __name__ == "__main__":