mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into fix/pypi-release-frontend-build
This commit is contained in:
commit
311be8f97d
@ -9,7 +9,7 @@ def lora_token_vector_length(checkpoint: dict) -> int:
|
|||||||
:param checkpoint: The checkpoint
|
:param checkpoint: The checkpoint
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_shape_1(key, tensor, checkpoint):
|
def _get_shape_1(key: str, tensor, checkpoint) -> int:
|
||||||
lora_token_vector_length = None
|
lora_token_vector_length = None
|
||||||
|
|
||||||
if "." not in key:
|
if "." not in key:
|
||||||
@ -57,6 +57,10 @@ def lora_token_vector_length(checkpoint: dict) -> int:
|
|||||||
for key, tensor in checkpoint.items():
|
for key, tensor in checkpoint.items():
|
||||||
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
|
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
|
||||||
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
|
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
|
||||||
|
elif key.startswith("lora_unet_") and (
|
||||||
|
"time_emb_proj.lora_down" in key
|
||||||
|
): # recognizes format at https://civitai.com/models/224641
|
||||||
|
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
|
||||||
elif key.startswith("lora_te") and "_self_attn_" in key:
|
elif key.startswith("lora_te") and "_self_attn_" in key:
|
||||||
tmp_length = _get_shape_1(key, tensor, checkpoint)
|
tmp_length = _get_shape_1(key, tensor, checkpoint)
|
||||||
if key.startswith("lora_te_"):
|
if key.startswith("lora_te_"):
|
||||||
|
@ -400,6 +400,8 @@ class LoRACheckpointProbe(CheckpointProbeBase):
|
|||||||
return BaseModelType.StableDiffusion1
|
return BaseModelType.StableDiffusion1
|
||||||
elif token_vector_length == 1024:
|
elif token_vector_length == 1024:
|
||||||
return BaseModelType.StableDiffusion2
|
return BaseModelType.StableDiffusion2
|
||||||
|
elif token_vector_length == 1280:
|
||||||
|
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
|
||||||
elif token_vector_length == 2048:
|
elif token_vector_length == 2048:
|
||||||
return BaseModelType.StableDiffusionXL
|
return BaseModelType.StableDiffusionXL
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user