From f3d3316558d29bed477776876c940fe52b1118ef Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 6 Aug 2023 16:00:53 -0400 Subject: [PATCH 1/8] probe LoRAs that do not have the text encoder --- invokeai/backend/model_management/model_probe.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py index 21462cf6e6..1fcb5334df 100644 --- a/invokeai/backend/model_management/model_probe.py +++ b/invokeai/backend/model_management/model_probe.py @@ -345,6 +345,8 @@ class LoRACheckpointProbe(CheckpointProbeBase): return BaseModelType.StableDiffusion1 elif lora_token_vector_length == 1024: return BaseModelType.StableDiffusion2 + elif lora_token_vector_length is None: # variant w/o the text encoder! + return BaseModelType.StableDiffusion1 else: raise InvalidModelException(f"Unknown LoRA type") From 22f7cf063854d4cf252b31ef15d3544ba325680a Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 7 Aug 2023 16:19:57 -0400 Subject: [PATCH 2/8] add stalker's complicated but effective code for finding token vector length in LoRAs --- .../backend/model_management/model_manager.py | 2 +- .../backend/model_management/model_probe.py | 35 +----- invokeai/backend/model_management/util.py | 118 ++++++++++++++++++ 3 files changed, 125 insertions(+), 30 deletions(-) create mode 100644 invokeai/backend/model_management/util.py diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 0bad714a17..ea8865ede4 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -472,7 +472,7 @@ class ModelManager(object): if submodel_type is not None and hasattr(model_config, submodel_type): override_path = getattr(model_config, submodel_type) if override_path: - model_path = self.resolve_path(override_path) + model_path = self.resolve_model_path(override_path) model_type = submodel_type submodel_type = None model_class = MODEL_CLASSES[base_model][model_type] diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py index 1fcb5334df..4c5d48bf9c 100644 --- a/invokeai/backend/model_management/model_probe.py +++ b/invokeai/backend/model_management/model_probe.py @@ -17,6 +17,7 @@ from .models import ( SilenceWarnings, InvalidModelException, ) +from .util import lora_token_vector_length from .models.base import read_checkpoint_meta @@ -315,38 +316,14 @@ class LoRACheckpointProbe(CheckpointProbeBase): def get_base_type(self) -> BaseModelType: checkpoint = self.checkpoint + token_vector_length = lora_token_vector_length(checkpoint) - # SD-2 models are very hard to probe. These probes are brittle and likely to fail in the future - # There are also some "SD-2 LoRAs" that have identical keys and shapes to SD-1 and will be - # misclassified as SD-1 - key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight" - if key in checkpoint and checkpoint[key].shape[0] == 320: + if token_vector_length == 768: + return BaseModelType.StableDiffusion1 + elif token_vector_length == 1024: return BaseModelType.StableDiffusion2 - - key = "lora_unet_output_blocks_5_1_transformer_blocks_1_ff_net_2.lora_up.weight" - if key in checkpoint: + elif token_vector_length == 2048: # variant w/o the text encoder! 
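+            # 2048 = 768 + 1280, the combined embedding widths of SDXL's two text encoders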
             return BaseModelType.StableDiffusionXL
-
-        key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
-        key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
-        key3 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
-
-        lora_token_vector_length = (
-            checkpoint[key1].shape[1]
-            if key1 in checkpoint
-            else checkpoint[key2].shape[1]
-            if key2 in checkpoint
-            else checkpoint[key3].shape[0]
-            if key3 in checkpoint
-            else None
-        )
-
-        if lora_token_vector_length == 768:
-            return BaseModelType.StableDiffusion1
-        elif lora_token_vector_length == 1024:
-            return BaseModelType.StableDiffusion2
-        elif lora_token_vector_length is None: # variant w/o the text encoder!
-            return BaseModelType.StableDiffusion1
         else:
             raise InvalidModelException(f"Unknown LoRA type")

diff --git a/invokeai/backend/model_management/util.py b/invokeai/backend/model_management/util.py
new file mode 100644
index 0000000000..ece9c96d4c
--- /dev/null
+++ b/invokeai/backend/model_management/util.py
@@ -0,0 +1,118 @@
+# Copyright (c) 2023 The InvokeAI Development Team
+"""Utilities used by the Model Manager"""
+
+
+def lora_token_vector_length(checkpoint: dict) -> int:
+    """
+    Given a checkpoint in memory, return the LoRA token vector length, or None if it cannot be determined
+
+    :param checkpoint: The checkpoint
+    """
+
+    def _handle_unet_key(key, tensor, checkpoint):
+        lora_token_vector_length = None
+        if "_attn2_to_k." not in key and "_attn2_to_v." not in key:
+            return lora_token_vector_length
+
+        # check lora/locon
+        if ".lora_up.weight" in key:
+            lora_token_vector_length = tensor.shape[0]
+        elif ".lora_down.weight" in key:
+            lora_token_vector_length = tensor.shape[1]
+
+        # check loha (don't worry about hada_t1/hada_t2, as they are used only in 4d shapes)
+        elif ".hada_w1_b" in key or ".hada_w2_b" in key:
+            lora_token_vector_length = tensor.shape[1]
+
+        # check lokr (don't worry about lokr_t2, as it is used only in 4d shapes)
+        elif ".lokr_" in key:
+            _lokr_key = key.split(".")[0]
+
+            if _lokr_key + ".lokr_w1" in checkpoint:
+                _lokr_w1 = checkpoint[_lokr_key + ".lokr_w1"]
+            elif _lokr_key + ".lokr_w1_b" in checkpoint:
+                _lokr_w1 = checkpoint[_lokr_key + ".lokr_w1_b"]
+            else:
+                return lora_token_vector_length # unknown format
+
+            if _lokr_key + ".lokr_w2" in checkpoint:
+                _lokr_w2 = checkpoint[_lokr_key + ".lokr_w2"]
+            elif _lokr_key + ".lokr_w2_b" in checkpoint:
+                _lokr_w2 = checkpoint[_lokr_key + ".lokr_w2_b"]
+            else:
+                return lora_token_vector_length # unknown format
+
+            lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
+
+        elif ".diff" in key:
+            lora_token_vector_length = tensor.shape[1]
+
+        return lora_token_vector_length
+
+    def _handle_te_key(key, tensor, checkpoint):
+        lora_token_vector_length = None
+        if "text_model_encoder_layers_" not in key:
+            return lora_token_vector_length
+
+        # skip detection by mlp keys
+        if "_self_attn_" not in key:
+            return lora_token_vector_length
+
+        # check lora/locon
+        if ".lora_up.weight" in key:
+            lora_token_vector_length = tensor.shape[0]
+        elif ".lora_down.weight" in key:
+            lora_token_vector_length = tensor.shape[1]
+
+        # check loha (don't worry about hada_t1/hada_t2, as they are used only in 4d shapes)
+        elif ".hada_w1_a" in key or ".hada_w2_a" in key:
+            lora_token_vector_length = tensor.shape[0]
+        elif ".hada_w1_b" in key or ".hada_w2_b" in key:
+            lora_token_vector_length = tensor.shape[1]
+
+        # check lokr (don't worry about lokr_t2, as it is used only in 4d shapes)
+        elif ".lokr_" in key:
+            _lokr_key = key.split(".")[0]
+
+            if _lokr_key + ".lokr_w1" in checkpoint:
+                _lokr_w1 = checkpoint[_lokr_key + ".lokr_w1"]
+            elif _lokr_key + ".lokr_w1_b" in checkpoint:
+                _lokr_w1 = checkpoint[_lokr_key + ".lokr_w1_b"]
+            else:
+                return lora_token_vector_length # unknown format
+
+            if _lokr_key + ".lokr_w2" in checkpoint:
+                _lokr_w2 = checkpoint[_lokr_key + ".lokr_w2"]
+            elif _lokr_key + ".lokr_w2_b" in checkpoint:
+                _lokr_w2 = checkpoint[_lokr_key + ".lokr_w2_b"]
+            else:
+                return lora_token_vector_length # unknown format
+
+            lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
+
+        elif ".diff" in key:
+            lora_token_vector_length = tensor.shape[1]
+
+        return lora_token_vector_length
+
+    lora_token_vector_length = None
+    lora_te1_length = None
+    lora_te2_length = None
+    for key, tensor in checkpoint.items():
+        if key.startswith("lora_unet_"):
+            lora_token_vector_length = _handle_unet_key(key, tensor, checkpoint)
+        elif key.startswith("lora_te_"):
+            lora_token_vector_length = _handle_te_key(key, tensor, checkpoint)
+
+        elif key.startswith("lora_te1_"):
+            lora_te1_length = _handle_te_key(key, tensor, checkpoint)
+        elif key.startswith("lora_te2_"):
+            lora_te2_length = _handle_te_key(key, tensor, checkpoint)
+
+        if lora_te1_length is not None and lora_te2_length is not None:
+            lora_token_vector_length = lora_te1_length + lora_te2_length
+
+        if lora_token_vector_length is not None:
+            break
+
+    return lora_token_vector_length

From f0918edf98b29066280f8192a53f61874fb4fd0b Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Mon, 7 Aug 2023 16:38:58 -0400
Subject: [PATCH 3/8] improve error reporting on unrecognized lora models

---
 invokeai/backend/model_management/model_probe.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py
index 4c5d48bf9c..145c56c273 100644
--- a/invokeai/backend/model_management/model_probe.py
+++ b/invokeai/backend/model_management/model_probe.py
@@ -322,10 +322,10 @@ class LoRACheckpointProbe(CheckpointProbeBase):
             return BaseModelType.StableDiffusion1
         elif token_vector_length == 1024:
             return BaseModelType.StableDiffusion2
-        elif token_vector_length == 2048: # variant w/o the text encoder!
+        elif token_vector_length == 2048:
             return BaseModelType.StableDiffusionXL
         else:
-            raise InvalidModelException(f"Unknown LoRA type")
+            raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}")
 
 
 class TextualInversionCheckpointProbe(CheckpointProbeBase):

From eb70bc2ae4c12f128148147e02b6aefd45874bff Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Mon, 7 Aug 2023 21:00:47 -0400
Subject: [PATCH 4/8] add scripts to create model templates and check whether
 they match

---
 scripts/create_checkpoint_template.py | 48 +++++++++++++++++++++++++++
 1 file changed, 48 insertions(+)
 create mode 100755 scripts/create_checkpoint_template.py

diff --git a/scripts/create_checkpoint_template.py b/scripts/create_checkpoint_template.py
new file mode 100755
index 0000000000..5b8fca8b58
--- /dev/null
+++ b/scripts/create_checkpoint_template.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python
+"""
+Read a checkpoint/safetensors file and write out a template .json file containing
+its metadata for use in fast model probing.
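+
+Example invocation (the paths here are hypothetical placeholders):
+
+    python scripts/create_checkpoint_template.py \
+        --checkpoint /path/to/model.safetensors \
+        --template /path/to/template.json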
+""" + +import sys +import argparse +import json + +from pathlib import Path + +from invokeai.backend.model_management.models.base import read_checkpoint_meta + +parser = argparse.ArgumentParser(description="Create a .json template from checkpoint/safetensors model") +parser.add_argument( + "--checkpoint", + "--in", + type=Path, + help="Path to the input checkpoint/safetensors file" +) +parser.add_argument( + "--template", + "--out", + type=Path, + help="Path to the output .json file" +) + +opt = parser.parse_args() +ckpt = read_checkpoint_meta(opt.checkpoint) +while "state_dict" in ckpt: + ckpt = ckpt["state_dict"] + +tmpl = {} + +for key, tensor in ckpt.items(): + tmpl[key] = list(tensor.shape) + +try: + with open(opt.template,'w') as f: + json.dump(tmpl, f) + print(f"Template written out as {opt.template}") +except Exception as e: + print(f"An exception occurred while writing template: {str(e)}") + + + + From 4df581811ea77ac6081ebe55fc4b59fd494a140c Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 7 Aug 2023 21:01:48 -0400 Subject: [PATCH 5/8] add template verification script --- scripts/verify_checkpoint_template.py | 51 +++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100755 scripts/verify_checkpoint_template.py diff --git a/scripts/verify_checkpoint_template.py b/scripts/verify_checkpoint_template.py new file mode 100755 index 0000000000..42c7acca3a --- /dev/null +++ b/scripts/verify_checkpoint_template.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python +""" +Read a checkpoint/safetensors file and compare it to a template .json. +Returns True if their metadata match. +""" + +import sys +import argparse +import json + +from pathlib import Path + +from invokeai.backend.model_management.models.base import read_checkpoint_meta + +parser = argparse.ArgumentParser(description="Create a .json template from checkpoint/safetensors model") +parser.add_argument( + "--checkpoint", + "--in", + type=Path, + help="Path to the input checkpoint/safetensors file" +) +parser.add_argument( + "--template", + "--out", + type=Path, + help="Path to the template .json file to match against" +) + +opt = parser.parse_args() +ckpt = read_checkpoint_meta(opt.checkpoint) +while "state_dict" in ckpt: + ckpt = ckpt["state_dict"] + +checkpoint_metadata = {} + +for key, tensor in ckpt.items(): + checkpoint_metadata[key] = list(tensor.shape) + +with open(opt.template,'r') as f: + template = json.load(f) + +if checkpoint_metadata == template: + print('True') + sys.exit(0) +else: + print('False') + sys.exit(-1) + + + + From 750f09fbed98a4a8bd663f5be0cdc836d77b787d Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 7 Aug 2023 21:01:59 -0400 Subject: [PATCH 6/8] blackify --- scripts/create_checkpoint_template.py | 20 +++----------------- scripts/verify_checkpoint_template.py | 24 +++++------------------- 2 files changed, 8 insertions(+), 36 deletions(-) diff --git a/scripts/create_checkpoint_template.py b/scripts/create_checkpoint_template.py index 5b8fca8b58..7ff201c841 100755 --- a/scripts/create_checkpoint_template.py +++ b/scripts/create_checkpoint_template.py @@ -13,18 +13,8 @@ from pathlib import Path from invokeai.backend.model_management.models.base import read_checkpoint_meta parser = argparse.ArgumentParser(description="Create a .json template from checkpoint/safetensors model") -parser.add_argument( - "--checkpoint", - "--in", - type=Path, - help="Path to the input checkpoint/safetensors file" -) -parser.add_argument( - "--template", - "--out", - type=Path, - help="Path to the output 
.json file" -) +parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file") +parser.add_argument("--template", "--out", type=Path, help="Path to the output .json file") opt = parser.parse_args() ckpt = read_checkpoint_meta(opt.checkpoint) @@ -37,12 +27,8 @@ for key, tensor in ckpt.items(): tmpl[key] = list(tensor.shape) try: - with open(opt.template,'w') as f: + with open(opt.template, "w") as f: json.dump(tmpl, f) print(f"Template written out as {opt.template}") except Exception as e: print(f"An exception occurred while writing template: {str(e)}") - - - - diff --git a/scripts/verify_checkpoint_template.py b/scripts/verify_checkpoint_template.py index 42c7acca3a..68ed72037e 100755 --- a/scripts/verify_checkpoint_template.py +++ b/scripts/verify_checkpoint_template.py @@ -13,18 +13,8 @@ from pathlib import Path from invokeai.backend.model_management.models.base import read_checkpoint_meta parser = argparse.ArgumentParser(description="Create a .json template from checkpoint/safetensors model") -parser.add_argument( - "--checkpoint", - "--in", - type=Path, - help="Path to the input checkpoint/safetensors file" -) -parser.add_argument( - "--template", - "--out", - type=Path, - help="Path to the template .json file to match against" -) +parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file") +parser.add_argument("--template", "--out", type=Path, help="Path to the template .json file to match against") opt = parser.parse_args() ckpt = read_checkpoint_meta(opt.checkpoint) @@ -36,16 +26,12 @@ checkpoint_metadata = {} for key, tensor in ckpt.items(): checkpoint_metadata[key] = list(tensor.shape) -with open(opt.template,'r') as f: +with open(opt.template, "r") as f: template = json.load(f) if checkpoint_metadata == template: - print('True') + print("True") sys.exit(0) else: - print('False') + print("False") sys.exit(-1) - - - - From 2f68a1a76cdd8367f75b94749ba4dcafe4c6360d Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 9 Aug 2023 09:21:29 -0400 Subject: [PATCH 7/8] use Stalker's simplified LoRA vector-length detection code --- invokeai/backend/model_management/util.py | 73 ++++------------------- 1 file changed, 12 insertions(+), 61 deletions(-) diff --git a/invokeai/backend/model_management/util.py b/invokeai/backend/model_management/util.py index ece9c96d4c..f435ab79b6 100644 --- a/invokeai/backend/model_management/util.py +++ b/invokeai/backend/model_management/util.py @@ -9,15 +9,11 @@ def lora_token_vector_length(checkpoint: dict) -> int: :param checkpoint: The checkpoint """ - def _handle_unet_key(key, tensor, checkpoint): + def _get_shape_1(key, tensor, checkpoint): lora_token_vector_length = None - if "_attn2_to_k." not in key and "_attn2_to_v." 
not in key:
-            return lora_token_vector_length
 
         # check lora/locon
-        if ".lora_up.weight" in key:
-            lora_token_vector_length = tensor.shape[0]
-        elif ".lora_down.weight" in key:
+        if ".lora_down.weight" in key:
             lora_token_vector_length = tensor.shape[1]
 
         # check loha (don't worry about hada_t1/hada_t2, as they are used only in 4d shapes)
@@ -49,65 +45,20 @@ def lora_token_vector_length(checkpoint: dict) -> int:
 
         return lora_token_vector_length
 
-    def _handle_te_key(key, tensor, checkpoint):
-        lora_token_vector_length = None
-        if "text_model_encoder_layers_" not in key:
-            return lora_token_vector_length
-
-        # skip detection by mlp keys
-        if "_self_attn_" not in key:
-            return lora_token_vector_length
-
-        # check lora/locon
-        if ".lora_up.weight" in key:
-            lora_token_vector_length = tensor.shape[0]
-        elif ".lora_down.weight" in key:
-            lora_token_vector_length = tensor.shape[1]
-
-        # check loha (don't worry about hada_t1/hada_t2, as they are used only in 4d shapes)
-        elif ".hada_w1_a" in key or ".hada_w2_a" in key:
-            lora_token_vector_length = tensor.shape[0]
-        elif ".hada_w1_b" in key or ".hada_w2_b" in key:
-            lora_token_vector_length = tensor.shape[1]
-
-        # check lokr (don't worry about lokr_t2, as it is used only in 4d shapes)
-        elif ".lokr_" in key:
-            _lokr_key = key.split(".")[0]
-
-            if _lokr_key + ".lokr_w1" in checkpoint:
-                _lokr_w1 = checkpoint[_lokr_key + ".lokr_w1"]
-            elif _lokr_key + ".lokr_w1_b" in checkpoint:
-                _lokr_w1 = checkpoint[_lokr_key + ".lokr_w1_b"]
-            else:
-                return lora_token_vector_length # unknown format
-
-            if _lokr_key + ".lokr_w2" in checkpoint:
-                _lokr_w2 = checkpoint[_lokr_key + ".lokr_w2"]
-            elif _lokr_key + ".lokr_w2_b" in checkpoint:
-                _lokr_w2 = checkpoint[_lokr_key + ".lokr_w2_b"]
-            else:
-                return lora_token_vector_length # unknown format
-
-            lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
-
-        elif ".diff" in key:
-            lora_token_vector_length = tensor.shape[1]
-
-        return lora_token_vector_length
-
     lora_token_vector_length = None
     lora_te1_length = None
     lora_te2_length = None
     for key, tensor in checkpoint.items():
-        if key.startswith("lora_unet_"):
-            lora_token_vector_length = _handle_unet_key(key, tensor, checkpoint)
-        elif key.startswith("lora_te_"):
-            lora_token_vector_length = _handle_te_key(key, tensor, checkpoint)
-
-        elif key.startswith("lora_te1_"):
-            lora_te1_length = _handle_te_key(key, tensor, checkpoint)
-        elif key.startswith("lora_te2_"):
-            lora_te2_length = _handle_te_key(key, tensor, checkpoint)
+        if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
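+            # in cross-attention (attn2) layers, to_k/to_v take the text encoder
+            # output as input, so their input dimension is the token vector length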
+            lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
+        elif key.startswith("lora_te") and "_self_attn_" in key:
+            tmp_length = _get_shape_1(key, tensor, checkpoint)
+            if key.startswith("lora_te_"):
+                lora_token_vector_length = tmp_length
+            elif key.startswith("lora_te1_"):
+                lora_te1_length = tmp_length
+            elif key.startswith("lora_te2_"):
+                lora_te2_length = tmp_length
 
         if lora_te1_length is not None and lora_te2_length is not None:
             lora_token_vector_length = lora_te1_length + lora_te2_length

From 6c8e898f099c8bd95b15b97ab4badbffddbb8463 Mon Sep 17 00:00:00 2001
From: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com>
Date: Thu, 10 Aug 2023 16:00:33 -0400
Subject: [PATCH 8/8] Update scripts/verify_checkpoint_template.py

Co-authored-by: Eugene Brodsky
---
 scripts/verify_checkpoint_template.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/verify_checkpoint_template.py b/scripts/verify_checkpoint_template.py
index 68ed72037e..15194290f5 100755
--- a/scripts/verify_checkpoint_template.py
+++ b/scripts/verify_checkpoint_template.py
@@ -12,7 +12,7 @@ from pathlib import Path
 
 from invokeai.backend.model_management.models.base import read_checkpoint_meta
 
-parser = argparse.ArgumentParser(description="Create a .json template from checkpoint/safetensors model")
+parser = argparse.ArgumentParser(description="Compare a checkpoint/safetensors file to a JSON metadata template.")
 parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file")
 parser.add_argument("--template", "--out", type=Path, help="Path to the template .json file to match against")
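
For reference, a minimal sketch of how the pieces in this series fit together
when probing a LoRA file. This is illustrative only, not part of the patches:
the file path is a placeholder, and the 768/1024/2048 mapping mirrors
LoRACheckpointProbe.get_base_type above:

    from safetensors.torch import load_file

    from invokeai.backend.model_management.util import lora_token_vector_length

    # load the LoRA state dict; load_file returns a dict of key -> tensor
    checkpoint = load_file("/path/to/lora.safetensors")  # placeholder path

    # 768 -> SD-1, 1024 -> SD-2, 2048 -> SDXL (768 + 1280 from its two text encoders)
    length = lora_token_vector_length(checkpoint)
    base_model = {768: "sd-1", 1024: "sd-2", 2048: "sdxl"}.get(length, "unknown")
    print(f"token vector length: {length} -> base model: {base_model}")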