use Stalker's simplified LoRA vector-length detection code

This commit is contained in:
Lincoln Stein 2023-08-09 09:21:29 -04:00
parent 7d4ace962a
commit 2f68a1a76c

View File

@ -9,15 +9,11 @@ def lora_token_vector_length(checkpoint: dict) -> int:
:param checkpoint: The checkpoint
"""
def _handle_unet_key(key, tensor, checkpoint):
def _get_shape_1(key, tensor, checkpoint):
lora_token_vector_length = None
if "_attn2_to_k." not in key and "_attn2_to_v." not in key:
return lora_token_vector_length
# check lora/locon
if ".lora_up.weight" in key:
lora_token_vector_length = tensor.shape[0]
elif ".lora_down.weight" in key:
if ".lora_down.weight" in key:
lora_token_vector_length = tensor.shape[1]
# check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
@ -49,65 +45,20 @@ def lora_token_vector_length(checkpoint: dict) -> int:
return lora_token_vector_length
def _handle_te_key(key, tensor, checkpoint):
lora_token_vector_length = None
if "text_model_encoder_layers_" not in key:
return lora_token_vector_length
# skip detect by mlp
if "_self_attn_" not in key:
return lora_token_vector_length
# check lora/locon
if ".lora_up.weight" in key:
lora_token_vector_length = tensor.shape[0]
elif ".lora_down.weight" in key:
lora_token_vector_length = tensor.shape[1]
# check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
elif ".hada_w1_a" in key or ".hada_w2_a" in key:
lora_token_vector_length = tensor.shape[0]
elif ".hada_w1_b" in key or ".hada_w2_b" in key:
lora_token_vector_length = tensor.shape[1]
# check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
elif ".lokr_" in key:
_lokr_key = key.split(".")[0]
if _lokr_key + ".lokr_w1" in checkpoint:
_lokr_w1 = checkpoint[_lokr_key + ".lokr_w1"]
elif _lokr_key + "lokr_w1_b" in checkpoint:
_lokr_w1 = checkpoint[_lokr_key + ".lokr_w1_b"]
else:
return lora_token_vector_length # unknown format
if _lokr_key + ".lokr_w2" in checkpoint:
_lokr_w2 = checkpoint[_lokr_key + ".lokr_w2"]
elif _lokr_key + "lokr_w2_b" in checkpoint:
_lokr_w2 = checkpoint[_lokr_key + ".lokr_w2_b"]
else:
return lora_token_vector_length # unknown format
lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
elif ".diff" in key:
lora_token_vector_length = tensor.shape[1]
return lora_token_vector_length
lora_token_vector_length = None
lora_te1_length = None
lora_te2_length = None
for key, tensor in checkpoint.items():
if key.startswith("lora_unet_"):
lora_token_vector_length = _handle_unet_key(key, tensor, checkpoint)
elif key.startswith("lora_te_"):
lora_token_vector_length = _handle_te_key(key, tensor, checkpoint)
elif key.startswith("lora_te1_"):
lora_te1_length = _handle_te_key(key, tensor, checkpoint)
elif key.startswith("lora_te2_"):
lora_te2_length = _handle_te_key(key, tensor, checkpoint)
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
elif key.startswith("lora_te") and "_self_attn_" in key:
tmp_length = _get_shape_1(key, tensor, checkpoint)
if key.startswith("lora_te_"):
lora_token_vector_length = tmp_length
elif key.startswith("lora_te1_"):
lora_te1_length = tmp_length
elif key.startswith("lora_te2_"):
lora_te2_length = tmp_length
if lora_te1_length is not None and lora_te2_length is not None:
lora_token_vector_length = lora_te1_length + lora_te2_length