Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit 22f7cf0638 (parent f3d3316558): add stalker's complicated but effective code for finding token vector length in LoRAs
@@ -472,7 +472,7 @@ class ModelManager(object):
         if submodel_type is not None and hasattr(model_config, submodel_type):
             override_path = getattr(model_config, submodel_type)
             if override_path:
-                model_path = self.resolve_path(override_path)
+                model_path = self.resolve_model_path(override_path)
                 model_type = submodel_type
                 submodel_type = None
         model_class = MODEL_CLASSES[base_model][model_type]
@@ -17,6 +17,7 @@ from .models import (
     SilenceWarnings,
     InvalidModelException,
 )
+from .util import lora_token_vector_length
 from .models.base import read_checkpoint_meta
 
 
@@ -315,38 +316,14 @@ class LoRACheckpointProbe(CheckpointProbeBase):
 
     def get_base_type(self) -> BaseModelType:
         checkpoint = self.checkpoint
-
-        # SD-2 models are very hard to probe. These probes are brittle and likely to fail in the future
-        # There are also some "SD-2 LoRAs" that have identical keys and shapes to SD-1 and will be
-        # misclassified as SD-1
-        key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
-        if key in checkpoint and checkpoint[key].shape[0] == 320:
+        token_vector_length = lora_token_vector_length(checkpoint)
+
+        if token_vector_length == 768:
+            return BaseModelType.StableDiffusion1
+        elif token_vector_length == 1024:
             return BaseModelType.StableDiffusion2
-
-        key = "lora_unet_output_blocks_5_1_transformer_blocks_1_ff_net_2.lora_up.weight"
-        if key in checkpoint:
+        elif token_vector_length == 2048:  # variant w/o the text encoder!
             return BaseModelType.StableDiffusionXL
-
-        key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
-        key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
-        key3 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
-
-        lora_token_vector_length = (
-            checkpoint[key1].shape[1]
-            if key1 in checkpoint
-            else checkpoint[key2].shape[1]
-            if key2 in checkpoint
-            else checkpoint[key3].shape[0]
-            if key3 in checkpoint
-            else None
-        )
-
-        if lora_token_vector_length == 768:
-            return BaseModelType.StableDiffusion1
-        elif lora_token_vector_length == 1024:
-            return BaseModelType.StableDiffusion2
-        elif lora_token_vector_length is None:  # variant w/o the text encoder!
-            return BaseModelType.StableDiffusion1
         else:
             raise InvalidModelException(f"Unknown LoRA type")
 
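The numbers the new probe keys on come from the text-encoder/cross-attention widths of each base model: 768 for SD-1 (CLIP ViT-L), 1024 for SD-2 (OpenCLIP ViT-H), and 2048 for SDXL (its two text encoders contribute 768 + 1280, which is also the UNet cross-attention context width). A minimal sketch of that mapping; the dictionary and function names below are illustrative assumptions, not part of the commit or the InvokeAI API:

# Illustrative sketch -- names here are assumptions, not InvokeAI API.
TOKEN_VECTOR_LENGTH_TO_BASE = {
    768: "StableDiffusion1",    # CLIP ViT-L text-encoder width (SD-1.x)
    1024: "StableDiffusion2",   # OpenCLIP ViT-H text-encoder width (SD-2.x)
    2048: "StableDiffusionXL",  # combined width of SDXL's two text encoders (768 + 1280)
}


def guess_base_type(token_vector_length: int) -> str:
    """Map a detected token vector length to a base-model name, mirroring get_base_type()."""
    try:
        return TOKEN_VECTOR_LENGTH_TO_BASE[token_vector_length]
    except KeyError:
        raise ValueError(f"Unknown LoRA type (token vector length: {token_vector_length})")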
invokeai/backend/model_management/util.py (new file, 118 lines)
@@ -0,0 +1,118 @@
# Copyright (c) 2023 The InvokeAI Development Team
"""Utilities used by the Model Manager"""


def lora_token_vector_length(checkpoint: dict) -> int:
    """
    Given a checkpoint in memory, return the lora token vector length

    :param checkpoint: The checkpoint
    """

    def _handle_unet_key(key, tensor, checkpoint):
        lora_token_vector_length = None
        if "_attn2_to_k." not in key and "_attn2_to_v." not in key:
            return lora_token_vector_length

        # check lora/locon
        if ".lora_up.weight" in key:
            lora_token_vector_length = tensor.shape[0]
        elif ".lora_down.weight" in key:
            lora_token_vector_length = tensor.shape[1]

        # check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
        elif ".hada_w1_b" in key or ".hada_w2_b" in key:
            lora_token_vector_length = tensor.shape[1]

        # check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
        elif ".lokr_" in key:
            _lokr_key = key.split(".")[0]

            if _lokr_key + ".lokr_w1" in checkpoint:
                _lokr_w1 = checkpoint[_lokr_key + ".lokr_w1"]
            elif _lokr_key + ".lokr_w1_b" in checkpoint:
                _lokr_w1 = checkpoint[_lokr_key + ".lokr_w1_b"]
            else:
                return lora_token_vector_length  # unknown format

            if _lokr_key + ".lokr_w2" in checkpoint:
                _lokr_w2 = checkpoint[_lokr_key + ".lokr_w2"]
            elif _lokr_key + ".lokr_w2_b" in checkpoint:
                _lokr_w2 = checkpoint[_lokr_key + ".lokr_w2_b"]
            else:
                return lora_token_vector_length  # unknown format

            lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]

        elif ".diff" in key:
            lora_token_vector_length = tensor.shape[1]

        return lora_token_vector_length

    def _handle_te_key(key, tensor, checkpoint):
        lora_token_vector_length = None
        if "text_model_encoder_layers_" not in key:
            return lora_token_vector_length

        # skip detect by mlp
        if "_self_attn_" not in key:
            return lora_token_vector_length

        # check lora/locon
        if ".lora_up.weight" in key:
            lora_token_vector_length = tensor.shape[0]
        elif ".lora_down.weight" in key:
            lora_token_vector_length = tensor.shape[1]

        # check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
        elif ".hada_w1_a" in key or ".hada_w2_a" in key:
            lora_token_vector_length = tensor.shape[0]
        elif ".hada_w1_b" in key or ".hada_w2_b" in key:
            lora_token_vector_length = tensor.shape[1]

        # check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
        elif ".lokr_" in key:
            _lokr_key = key.split(".")[0]

            if _lokr_key + ".lokr_w1" in checkpoint:
                _lokr_w1 = checkpoint[_lokr_key + ".lokr_w1"]
            elif _lokr_key + ".lokr_w1_b" in checkpoint:
                _lokr_w1 = checkpoint[_lokr_key + ".lokr_w1_b"]
            else:
                return lora_token_vector_length  # unknown format

            if _lokr_key + ".lokr_w2" in checkpoint:
                _lokr_w2 = checkpoint[_lokr_key + ".lokr_w2"]
            elif _lokr_key + ".lokr_w2_b" in checkpoint:
                _lokr_w2 = checkpoint[_lokr_key + ".lokr_w2_b"]
            else:
                return lora_token_vector_length  # unknown format

            lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]

        elif ".diff" in key:
            lora_token_vector_length = tensor.shape[1]

        return lora_token_vector_length

    lora_token_vector_length = None
    lora_te1_length = None
    lora_te2_length = None
    for key, tensor in checkpoint.items():
        if key.startswith("lora_unet_"):
            lora_token_vector_length = _handle_unet_key(key, tensor, checkpoint)
        elif key.startswith("lora_te_"):
            lora_token_vector_length = _handle_te_key(key, tensor, checkpoint)

        elif key.startswith("lora_te1_"):
            lora_te1_length = _handle_te_key(key, tensor, checkpoint)
        elif key.startswith("lora_te2_"):
            lora_te2_length = _handle_te_key(key, tensor, checkpoint)

        if lora_te1_length is not None and lora_te2_length is not None:
            lora_token_vector_length = lora_te1_length + lora_te2_length

        if lora_token_vector_length is not None:
            break

    return lora_token_vector_length
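A brief usage sketch of the new helper; the checkpoint path and the safetensors loading step are assumptions for illustration, while lora_token_vector_length and its module path come from this commit:

# Usage sketch -- "my_lora.safetensors" is a placeholder path.
from safetensors.torch import load_file

from invokeai.backend.model_management.util import lora_token_vector_length

checkpoint = load_file("my_lora.safetensors")  # dict of str -> torch.Tensor
length = lora_token_vector_length(checkpoint)

if length == 768:
    print("Looks like an SD-1 LoRA")
elif length == 1024:
    print("Looks like an SD-2 LoRA")
elif length == 2048:
    print("Looks like an SDXL LoRA")
else:
    print(f"Could not classify LoRA (token vector length: {length})")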