diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py
index 21462cf6e6..145c56c273 100644
--- a/invokeai/backend/model_management/model_probe.py
+++ b/invokeai/backend/model_management/model_probe.py
@@ -17,6 +17,7 @@ from .models import (
     SilenceWarnings,
     InvalidModelException,
 )
+from .util import lora_token_vector_length
 from .models.base import read_checkpoint_meta
@@ -315,38 +316,16 @@ class LoRACheckpointProbe(CheckpointProbeBase):
     def get_base_type(self) -> BaseModelType:
         checkpoint = self.checkpoint
+        token_vector_length = lora_token_vector_length(checkpoint)
 
-        # SD-2 models are very hard to probe. These probes are brittle and likely to fail in the future
-        # There are also some "SD-2 LoRAs" that have identical keys and shapes to SD-1 and will be
-        # misclassified as SD-1
-        key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
-        if key in checkpoint and checkpoint[key].shape[0] == 320:
-            return BaseModelType.StableDiffusion2
-
-        key = "lora_unet_output_blocks_5_1_transformer_blocks_1_ff_net_2.lora_up.weight"
-        if key in checkpoint:
-            return BaseModelType.StableDiffusionXL
-
-        key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
-        key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
-        key3 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
-
-        lora_token_vector_length = (
-            checkpoint[key1].shape[1]
-            if key1 in checkpoint
-            else checkpoint[key2].shape[1]
-            if key2 in checkpoint
-            else checkpoint[key3].shape[0]
-            if key3 in checkpoint
-            else None
-        )
-
-        if lora_token_vector_length == 768:
+        if token_vector_length == 768:
             return BaseModelType.StableDiffusion1
-        elif lora_token_vector_length == 1024:
+        elif token_vector_length == 1024:
             return BaseModelType.StableDiffusion2
+        elif token_vector_length == 2048:
+            return BaseModelType.StableDiffusionXL
         else:
-            raise InvalidModelException(f"Unknown LoRA type")
+            raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}")
 
 
 class TextualInversionCheckpointProbe(CheckpointProbeBase):
diff --git a/invokeai/backend/model_management/util.py b/invokeai/backend/model_management/util.py
new file mode 100644
index 0000000000..f435ab79b6
--- /dev/null
+++ b/invokeai/backend/model_management/util.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2023 The InvokeAI Development Team
+"""Utilities used by the Model Manager"""
+
+from typing import Optional
+
+
+def lora_token_vector_length(checkpoint: dict) -> Optional[int]:
+    """
+    Given a LoRA checkpoint in memory, return its token vector length.
+
+    :param checkpoint: the checkpoint's state dict
+    :returns: the token vector length, or None if it cannot be determined
+    """
+
+    def _get_shape_1(key, tensor, checkpoint):
+        lora_token_vector_length = None
+
+        # check lora/locon
+        if ".lora_down.weight" in key:
+            lora_token_vector_length = tensor.shape[1]
+
+        # check loha (don't worry about hada_t1/hada_t2, as they are used only in 4d shapes)
+        elif ".hada_w1_b" in key or ".hada_w2_b" in key:
+            lora_token_vector_length = tensor.shape[1]
+
+        # check lokr (don't worry about lokr_t2, as it is used only in 4d shapes)
+        elif ".lokr_" in key:
+            _lokr_key = key.split(".")[0]
+
+            if _lokr_key + ".lokr_w1" in checkpoint:
+                _lokr_w1 = checkpoint[_lokr_key + ".lokr_w1"]
+            elif _lokr_key + ".lokr_w1_b" in checkpoint:
+                _lokr_w1 = checkpoint[_lokr_key + ".lokr_w1_b"]
+            else:
+                return lora_token_vector_length  # unknown format
+
+            if _lokr_key + ".lokr_w2" in checkpoint:
+                _lokr_w2 = checkpoint[_lokr_key + ".lokr_w2"]
+            elif _lokr_key + ".lokr_w2_b" in checkpoint:
+                _lokr_w2 = checkpoint[_lokr_key + ".lokr_w2_b"]
+            else:
+                return lora_token_vector_length  # unknown format
+
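+            # A LoKr weight is the Kronecker product kron(w1, w2); each factor may
+            # itself be stored as a low-rank a/b pair, hence the *_b fallbacks above.
+            # kron() of an (m, n) and a (p, q) matrix has shape (m*p, n*q), so the
+            # token dimension is the product of the two factors' second dimensions.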
+            lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
+
+        elif ".diff" in key:
+            lora_token_vector_length = tensor.shape[1]
+
+        return lora_token_vector_length
+
+    lora_token_vector_length = None
+    lora_te1_length = None
+    lora_te2_length = None
+    for key, tensor in checkpoint.items():
+        if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
+            lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
+        elif key.startswith("lora_te") and "_self_attn_" in key:
+            tmp_length = _get_shape_1(key, tensor, checkpoint)
+            if key.startswith("lora_te_"):
+                lora_token_vector_length = tmp_length
+            elif key.startswith("lora_te1_"):
+                lora_te1_length = tmp_length
+            elif key.startswith("lora_te2_"):
+                lora_te2_length = tmp_length
+
+        # SDXL LoRAs carry two text encoders (768 + 1280 = 2048); sum their lengths
+        if lora_te1_length is not None and lora_te2_length is not None:
+            lora_token_vector_length = lora_te1_length + lora_te2_length
+
+        if lora_token_vector_length is not None:
+            break
+
+    return lora_token_vector_length
diff --git a/scripts/create_checkpoint_template.py b/scripts/create_checkpoint_template.py
new file mode 100755
index 0000000000..7ff201c841
--- /dev/null
+++ b/scripts/create_checkpoint_template.py
@@ -0,0 +1,34 @@
+#!/usr/bin/env python
+"""
+Read a checkpoint/safetensors file and write out a template .json file containing
+its metadata, for use in fast model probing.
+"""
+
+import argparse
+import json
+
+from pathlib import Path
+
+from invokeai.backend.model_management.models.base import read_checkpoint_meta
+
+parser = argparse.ArgumentParser(description="Create a .json template from a checkpoint/safetensors model")
+parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file")
+parser.add_argument("--template", "--out", type=Path, help="Path to the output .json file")
+
+opt = parser.parse_args()
+ckpt = read_checkpoint_meta(opt.checkpoint)
+while "state_dict" in ckpt:
+    ckpt = ckpt["state_dict"]
+
+tmpl = {}
+
+for key, tensor in ckpt.items():
+    tmpl[key] = list(tensor.shape)
+
+try:
+    with open(opt.template, "w") as f:
+        json.dump(tmpl, f)
+    print(f"Template written out as {opt.template}")
+except Exception as e:
+    print(f"An exception occurred while writing template: {str(e)}")
diff --git a/scripts/verify_checkpoint_template.py b/scripts/verify_checkpoint_template.py
new file mode 100755
index 0000000000..15194290f5
--- /dev/null
+++ b/scripts/verify_checkpoint_template.py
@@ -0,0 +1,37 @@
+#!/usr/bin/env python
+"""
+Read a checkpoint/safetensors file and compare it to a template .json.
+Prints "True" and exits with status 0 if the metadata match; otherwise
+prints "False" and exits with a nonzero status.
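+
+Example invocation (paths are illustrative):
+    python scripts/verify_checkpoint_template.py --checkpoint model.safetensors --template template.json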
+"""
+
+import sys
+import argparse
+import json
+
+from pathlib import Path
+
+from invokeai.backend.model_management.models.base import read_checkpoint_meta
+
+parser = argparse.ArgumentParser(description="Compare a checkpoint/safetensors file to a JSON metadata template.")
+parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file")
+parser.add_argument("--template", "--out", type=Path, help="Path to the template .json file to match against")
+
+opt = parser.parse_args()
+ckpt = read_checkpoint_meta(opt.checkpoint)
+while "state_dict" in ckpt:
+    ckpt = ckpt["state_dict"]
+
+checkpoint_metadata = {}
+
+for key, tensor in ckpt.items():
+    checkpoint_metadata[key] = list(tensor.shape)
+
+with open(opt.template, "r") as f:
+    template = json.load(f)
+
+if checkpoint_metadata == template:
+    print("True")
+    sys.exit(0)
+else:
+    print("False")
+    sys.exit(-1)