# Probe LoRAs that do not have the text encoder (#4181)

## What type of PR is this? (check all applicable)

- [X] Bug Fix

## Have you discussed this change with the InvokeAI team?
- [X] No - minor fix

      
## Have you updated all relevant documentation?
- [X] Yes

## Description

It turns out that some LoRAs do not include the text encoder model, and
this was causing the code that determines the model base type during
model import to reject them as having an unknown base model. This PR
enables these cases to be detected correctly.
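
A minimal sketch of the failure mode with synthetic tensors (the UNet key and shapes below are illustrative; the probe keys are the ones the old code checked):

```python
import torch

# Hypothetical UNet-only LoRA in kohya naming: it has no "lora_te_*" keys at all.
unet_only_lora = {
    "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn2_to_k.lora_down.weight": torch.zeros(8, 768),
}

# The old probe consulted only text-encoder keys such as these:
old_probe_keys = [
    "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight",
    "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight",
    "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a",
]
assert not any(k in unet_only_lora for k in old_probe_keys)  # fell through -> "Unknown LoRA type"
```

The fix below also measures the UNet's `attn2_to_k`/`attn2_to_v` cross-attention keys, whose `lora_down` weight carries the token vector length (here 768, i.e. SD-1) in its second dimension.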
Commit 8e7eae6cc7 — Kent Keirsey, 2023-08-10 17:50:20 -04:00 (committed via GitHub).
4 changed files with 147 additions and 28 deletions

**`invokeai/backend/model_management/model_probe.py`**

```diff
@@ -17,6 +17,7 @@ from .models import (
     SilenceWarnings,
     InvalidModelException,
 )
+from .util import lora_token_vector_length
 from .models.base import read_checkpoint_meta

@@ -315,38 +316,16 @@ class LoRACheckpointProbe(CheckpointProbeBase):
     def get_base_type(self) -> BaseModelType:
         checkpoint = self.checkpoint
+        token_vector_length = lora_token_vector_length(checkpoint)
 
-        # SD-2 models are very hard to probe. These probes are brittle and likely to fail in the future
-        # There are also some "SD-2 LoRAs" that have identical keys and shapes to SD-1 and will be
-        # misclassified as SD-1
-        key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
-        if key in checkpoint and checkpoint[key].shape[0] == 320:
-            return BaseModelType.StableDiffusion2
-
-        key = "lora_unet_output_blocks_5_1_transformer_blocks_1_ff_net_2.lora_up.weight"
-        if key in checkpoint:
-            return BaseModelType.StableDiffusionXL
-
-        key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
-        key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
-        key3 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
-
-        lora_token_vector_length = (
-            checkpoint[key1].shape[1]
-            if key1 in checkpoint
-            else checkpoint[key2].shape[1]
-            if key2 in checkpoint
-            else checkpoint[key3].shape[0]
-            if key3 in checkpoint
-            else None
-        )
-
-        if lora_token_vector_length == 768:
+        if token_vector_length == 768:
             return BaseModelType.StableDiffusion1
-        elif lora_token_vector_length == 1024:
+        elif token_vector_length == 1024:
             return BaseModelType.StableDiffusion2
+        elif token_vector_length == 2048:
+            return BaseModelType.StableDiffusionXL
         else:
-            raise InvalidModelException(f"Unknown LoRA type")
+            raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}")
 
 
 class TextualInversionCheckpointProbe(CheckpointProbeBase):
```
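
For reference (a summary, not code from the PR), the three lengths correspond to the width of the text embedding each base model's UNet cross-attends over:

```python
# Token vector length -> base model, and why each value is what it is:
TOKEN_LENGTH_TO_BASE = {
    768: "sd-1",   # CLIP ViT-L/14 hidden size
    1024: "sd-2",  # OpenCLIP ViT-H/14 hidden size
    2048: "sdxl",  # 768 + 1280: CLIP ViT-L and OpenCLIP ViT-bigG concatenated
}
```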

**`invokeai/backend/model_management/util.py`** (new file)

```python
# Copyright (c) 2023 The InvokeAI Development Team
"""Utilities used by the Model Manager"""


def lora_token_vector_length(checkpoint: dict) -> int:
    """
    Given a checkpoint in memory, return the lora token vector length

    :param checkpoint: The checkpoint
    """

    def _get_shape_1(key, tensor, checkpoint):
        lora_token_vector_length = None

        # check lora/locon
        if ".lora_down.weight" in key:
            lora_token_vector_length = tensor.shape[1]

        # check loha (don't worry about hada_t1/hada_t2, as they are used only in 4d shapes)
        elif ".hada_w1_b" in key or ".hada_w2_b" in key:
            lora_token_vector_length = tensor.shape[1]

        # check lokr (don't worry about lokr_t2, as it is used only in 4d shapes)
        elif ".lokr_" in key:
            _lokr_key = key.split(".")[0]

            if _lokr_key + ".lokr_w1" in checkpoint:
                _lokr_w1 = checkpoint[_lokr_key + ".lokr_w1"]
            elif _lokr_key + ".lokr_w1_b" in checkpoint:
                _lokr_w1 = checkpoint[_lokr_key + ".lokr_w1_b"]
            else:
                return lora_token_vector_length  # unknown format

            if _lokr_key + ".lokr_w2" in checkpoint:
                _lokr_w2 = checkpoint[_lokr_key + ".lokr_w2"]
            elif _lokr_key + ".lokr_w2_b" in checkpoint:
                _lokr_w2 = checkpoint[_lokr_key + ".lokr_w2_b"]
            else:
                return lora_token_vector_length  # unknown format

            lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]

        elif ".diff" in key:
            lora_token_vector_length = tensor.shape[1]

        return lora_token_vector_length

    lora_token_vector_length = None
    lora_te1_length = None
    lora_te2_length = None
    for key, tensor in checkpoint.items():
        if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
            lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
        elif key.startswith("lora_te") and "_self_attn_" in key:
            tmp_length = _get_shape_1(key, tensor, checkpoint)
            if key.startswith("lora_te_"):
                lora_token_vector_length = tmp_length
            elif key.startswith("lora_te1_"):
                lora_te1_length = tmp_length
            elif key.startswith("lora_te2_"):
                lora_te2_length = tmp_length

        if lora_te1_length is not None and lora_te2_length is not None:
            lora_token_vector_length = lora_te1_length + lora_te2_length

        if lora_token_vector_length is not None:
            break

    return lora_token_vector_length
```
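
A quick sanity check of the helper against synthetic checkpoints (shapes are illustrative, rank 8 chosen arbitrarily; the import path is inferred from the relative import in the diff above):

```python
import torch

from invokeai.backend.model_management.util import lora_token_vector_length

# UNet-only SD-1 LoRA: the length comes from an attn2 cross-attention projection.
sd1_unet_only = {
    "lora_unet_up_blocks_1_attentions_0_transformer_blocks_0_attn2_to_v.lora_down.weight": torch.zeros(8, 768),
}
assert lora_token_vector_length(sd1_unet_only) == 768

# SDXL LoRA carrying both text encoders: te1 and te2 lengths are summed.
sdxl_dual_te = {
    "lora_te1_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight": torch.zeros(8, 768),
    "lora_te2_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight": torch.zeros(8, 1280),
}
assert lora_token_vector_length(sdxl_dual_te) == 768 + 1280  # 2048 -> SDXL
```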

**New script: write a `.json` shape template from a checkpoint**

```python
#!/usr/bin/env python
"""
Read a checkpoint/safetensors file and write out a template .json file containing
its metadata for use in fast model probing.
"""

import argparse
import json
from pathlib import Path

from invokeai.backend.model_management.models.base import read_checkpoint_meta

parser = argparse.ArgumentParser(description="Create a .json template from checkpoint/safetensors model")
parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file")
parser.add_argument("--template", "--out", type=Path, help="Path to the output .json file")

opt = parser.parse_args()
ckpt = read_checkpoint_meta(opt.checkpoint)

while "state_dict" in ckpt:
    ckpt = ckpt["state_dict"]

tmpl = {}
for key, tensor in ckpt.items():
    tmpl[key] = list(tensor.shape)

try:
    with open(opt.template, "w") as f:
        json.dump(tmpl, f)
    print(f"Template written out as {opt.template}")
except Exception as e:
    print(f"An exception occurred while writing template: {str(e)}")
```

**New script: compare a checkpoint against a `.json` shape template**

```python
#!/usr/bin/env python
"""
Read a checkpoint/safetensors file and compare it to a template .json.
Returns True if their metadata match.
"""

import argparse
import json
import sys
from pathlib import Path

from invokeai.backend.model_management.models.base import read_checkpoint_meta

parser = argparse.ArgumentParser(description="Compare a checkpoint/safetensors file to a JSON metadata template.")
parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file")
parser.add_argument("--template", "--out", type=Path, help="Path to the template .json file to match against")

opt = parser.parse_args()
ckpt = read_checkpoint_meta(opt.checkpoint)

while "state_dict" in ckpt:
    ckpt = ckpt["state_dict"]

checkpoint_metadata = {}
for key, tensor in ckpt.items():
    checkpoint_metadata[key] = list(tensor.shape)

with open(opt.template, "r") as f:
    template = json.load(f)

if checkpoint_metadata == template:
    print("True")
    sys.exit(0)
else:
    print("False")
    sys.exit(-1)
```
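
To check a candidate file against a saved template (again, the script filename is an assumption): `python compare_checkpoint_template.py --checkpoint mystery.safetensors --template sd-1.json`. It prints `True` and exits with status 0 when every key and shape matches the template, otherwise prints `False` and exits nonzero, so it composes directly with shell conditionals.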