Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Probe LoRAs that do not have the text encoder (#4181)
## What type of PR is this? (check all applicable)

- [X] Bug Fix

## Have you discussed this change with the InvokeAI team?

- [X] No - minor fix

## Have you updated all relevant documentation?

- [X] Yes

## Description

It turns out that some LoRAs do not include the text encoder model, and the code that distinguishes the base model type during model import was rejecting them as having an unknown base model. This PR enables detection of these cases.
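To illustrate the failure mode, a minimal sketch (not from the PR; key names follow the kohya convention, and the list is made up): a LoRA trained only against the UNet carries no lora_te_* keys, so a probe that requires a text-encoder key cannot classify it.

# Illustrative only: a UNet-only LoRA state dict targets the UNet, never the text encoder.
unet_only_keys = [
    "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn2_to_k.lora_down.weight",
    "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn2_to_k.lora_up.weight",
]
# The old probe looked up lora_te_text_model_encoder_layers_0_* keys, which are absent here.
assert not any(key.startswith("lora_te") for key in unet_only_keys)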
Commit 8e7eae6cc7
@@ -17,6 +17,7 @@ from .models import (
     SilenceWarnings,
     InvalidModelException,
 )
+from .util import lora_token_vector_length
 from .models.base import read_checkpoint_meta
 
 
@@ -315,38 +316,16 @@ class LoRACheckpointProbe(CheckpointProbeBase):
 
     def get_base_type(self) -> BaseModelType:
         checkpoint = self.checkpoint
+        token_vector_length = lora_token_vector_length(checkpoint)
 
-        # SD-2 models are very hard to probe. These probes are brittle and likely to fail in the future
-        # There are also some "SD-2 LoRAs" that have identical keys and shapes to SD-1 and will be
-        # misclassified as SD-1
-        key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
-        if key in checkpoint and checkpoint[key].shape[0] == 320:
-            return BaseModelType.StableDiffusion2
-
-        key = "lora_unet_output_blocks_5_1_transformer_blocks_1_ff_net_2.lora_up.weight"
-        if key in checkpoint:
-            return BaseModelType.StableDiffusionXL
-
-        key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
-        key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
-        key3 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
-
-        lora_token_vector_length = (
-            checkpoint[key1].shape[1]
-            if key1 in checkpoint
-            else checkpoint[key2].shape[1]
-            if key2 in checkpoint
-            else checkpoint[key3].shape[0]
-            if key3 in checkpoint
-            else None
-        )
-
-        if lora_token_vector_length == 768:
+        if token_vector_length == 768:
             return BaseModelType.StableDiffusion1
-        elif lora_token_vector_length == 1024:
+        elif token_vector_length == 1024:
             return BaseModelType.StableDiffusion2
+        elif token_vector_length == 2048:
+            return BaseModelType.StableDiffusionXL
         else:
-            raise InvalidModelException(f"Unknown LoRA type")
+            raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}")
 
 
 class TextualInversionCheckpointProbe(CheckpointProbeBase):
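For background: the "token vector length" probed here is the cross-attention context dimension, i.e. the width of the text-encoder hidden states the UNet is conditioned on, which is why it identifies the base model even when the LoRA itself never touches the text encoder. A minimal sketch of the correspondence (this mapping table is illustrative, not InvokeAI API):

# Why each token vector length implies a base model (illustrative values):
CONTEXT_DIM_TO_BASE = {
    768: "StableDiffusion1",    # CLIP ViT-L/14 hidden size
    1024: "StableDiffusion2",   # OpenCLIP ViT-H/14 hidden size
    2048: "StableDiffusionXL",  # two text encoders concatenated: 768 + 1280
}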
invokeai/backend/model_management/util.py (new file, 69 lines)
# Copyright (c) 2023 The InvokeAI Development Team
"""Utilities used by the Model Manager"""

from typing import Optional


def lora_token_vector_length(checkpoint: dict) -> Optional[int]:
    """
    Given a checkpoint in memory, return the LoRA token vector length,
    or None if it cannot be determined.

    :param checkpoint: The checkpoint
    """

    def _get_shape_1(key, tensor, checkpoint):
        lora_token_vector_length = None

        # check lora/locon
        if ".lora_down.weight" in key:
            lora_token_vector_length = tensor.shape[1]

        # check loha (don't worry about hada_t1/hada_t2, as they are used only in 4d shapes)
        elif ".hada_w1_b" in key or ".hada_w2_b" in key:
            lora_token_vector_length = tensor.shape[1]

        # check lokr (don't worry about lokr_t2, as it is used only in 4d shapes)
        elif ".lokr_" in key:
            _lokr_key = key.split(".")[0]

            if _lokr_key + ".lokr_w1" in checkpoint:
                _lokr_w1 = checkpoint[_lokr_key + ".lokr_w1"]
            elif _lokr_key + ".lokr_w1_b" in checkpoint:
                _lokr_w1 = checkpoint[_lokr_key + ".lokr_w1_b"]
            else:
                return lora_token_vector_length  # unknown format

            if _lokr_key + ".lokr_w2" in checkpoint:
                _lokr_w2 = checkpoint[_lokr_key + ".lokr_w2"]
            elif _lokr_key + ".lokr_w2_b" in checkpoint:
                _lokr_w2 = checkpoint[_lokr_key + ".lokr_w2_b"]
            else:
                return lora_token_vector_length  # unknown format

            lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]

        elif ".diff" in key:
            lora_token_vector_length = tensor.shape[1]

        return lora_token_vector_length

    lora_token_vector_length = None
    lora_te1_length = None
    lora_te2_length = None
    for key, tensor in checkpoint.items():
        if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
            lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
        elif key.startswith("lora_te") and "_self_attn_" in key:
            tmp_length = _get_shape_1(key, tensor, checkpoint)
            if key.startswith("lora_te_"):
                lora_token_vector_length = tmp_length
            elif key.startswith("lora_te1_"):
                lora_te1_length = tmp_length
            elif key.startswith("lora_te2_"):
                lora_te2_length = tmp_length

        # SDXL uses two text encoders; their lengths sum to the context dimension
        if lora_te1_length is not None and lora_te2_length is not None:
            lora_token_vector_length = lora_te1_length + lora_te2_length

        if lora_token_vector_length is not None:
            break

    return lora_token_vector_length
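A minimal usage sketch of lora_token_vector_length with made-up tensors (key names follow the kohya/LyCORIS conventions the function matches on; ranks and shapes are illustrative):

import torch

from invokeai.backend.model_management.util import lora_token_vector_length

# UNet-only SD-1 LoRA: the length is read from a cross-attention to_k projection,
# so no text-encoder keys are required.
sd1_lora = {
    "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn2_to_k.lora_down.weight": torch.zeros(8, 768),
}
assert lora_token_vector_length(sd1_lora) == 768  # probed as StableDiffusion1

# SDXL LoRA that trains both text encoders: the per-encoder lengths are summed, 768 + 1280 = 2048.
sdxl_lora = {
    "lora_te1_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight": torch.zeros(8, 768),
    "lora_te2_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight": torch.zeros(8, 1280),
}
assert lora_token_vector_length(sdxl_lora) == 2048  # probed as StableDiffusionXL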
scripts/create_checkpoint_template.py (new executable file, 34 lines)
#!/usr/bin/env python
"""
Read a checkpoint/safetensors file and write out a template .json file containing
its metadata for use in fast model probing.
"""

import argparse
import json

from pathlib import Path

from invokeai.backend.model_management.models.base import read_checkpoint_meta

parser = argparse.ArgumentParser(description="Create a .json template from a checkpoint/safetensors model")
parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file")
parser.add_argument("--template", "--out", type=Path, help="Path to the output .json file")

opt = parser.parse_args()
ckpt = read_checkpoint_meta(opt.checkpoint)
while "state_dict" in ckpt:
    ckpt = ckpt["state_dict"]

# The template records only metadata: each tensor key mapped to its shape
tmpl = {}

for key, tensor in ckpt.items():
    tmpl[key] = list(tensor.shape)

try:
    with open(opt.template, "w") as f:
        json.dump(tmpl, f)
    print(f"Template written out as {opt.template}")
except Exception as e:
    print(f"An exception occurred while writing template: {str(e)}")
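A plausible invocation (file names are illustrative). The emitted template is plain JSON mapping each state-dict key to its tensor shape, e.g. {"model.diffusion_model.input_blocks.0.0.weight": [320, 4, 3, 3], ...} for an SD-1 checkpoint:

python scripts/create_checkpoint_template.py --checkpoint v1-5-pruned-emaonly.safetensors --template sd-1.json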
scripts/verify_checkpoint_template.py (new executable file, 37 lines)
#!/usr/bin/env python
"""
Read a checkpoint/safetensors file and compare it to a template .json.
Prints True and exits with status 0 if their metadata match.
"""

import sys
import argparse
import json

from pathlib import Path

from invokeai.backend.model_management.models.base import read_checkpoint_meta

parser = argparse.ArgumentParser(description="Compare a checkpoint/safetensors file to a JSON metadata template.")
parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file")
parser.add_argument("--template", "--out", type=Path, help="Path to the template .json file to match against")

opt = parser.parse_args()
ckpt = read_checkpoint_meta(opt.checkpoint)
while "state_dict" in ckpt:
    ckpt = ckpt["state_dict"]

# Collect the same metadata the template generator records: key -> tensor shape
checkpoint_metadata = {}

for key, tensor in ckpt.items():
    checkpoint_metadata[key] = list(tensor.shape)

with open(opt.template, "r") as f:
    template = json.load(f)

if checkpoint_metadata == template:
    print("True")
    sys.exit(0)
else:
    print("False")
    sys.exit(-1)
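A plausible invocation pairing it with the template created above (file names are illustrative); the script prints True and exits 0 on an exact shape-for-shape match, otherwise prints False and exits nonzero:

python scripts/verify_checkpoint_template.py --checkpoint some-finetune.safetensors --template sd-1.json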