mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into fix/pypi-release-frontend-build
This commit is contained in:
commit
311be8f97d
@ -9,7 +9,7 @@ def lora_token_vector_length(checkpoint: dict) -> int:
|
||||
:param checkpoint: The checkpoint
|
||||
"""
|
||||
|
||||
def _get_shape_1(key, tensor, checkpoint):
|
||||
def _get_shape_1(key: str, tensor, checkpoint) -> int:
|
||||
lora_token_vector_length = None
|
||||
|
||||
if "." not in key:
|
||||
@ -57,6 +57,10 @@ def lora_token_vector_length(checkpoint: dict) -> int:
|
||||
for key, tensor in checkpoint.items():
|
||||
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
|
||||
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
|
||||
elif key.startswith("lora_unet_") and (
|
||||
"time_emb_proj.lora_down" in key
|
||||
): # recognizes format at https://civitai.com/models/224641
|
||||
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
|
||||
elif key.startswith("lora_te") and "_self_attn_" in key:
|
||||
tmp_length = _get_shape_1(key, tensor, checkpoint)
|
||||
if key.startswith("lora_te_"):
|
||||
|
@ -400,6 +400,8 @@ class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif token_vector_length == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif token_vector_length == 1280:
|
||||
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
|
||||
elif token_vector_length == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user