From ac3cf48d7fb6d762d8b42e2b34c0d29e2fa48c6f Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 14 Dec 2023 22:52:50 -0500 Subject: [PATCH 1/2] make probe recognize lora format at https://civitai.com/models/224641 --- invokeai/backend/model_management/util.py | 6 +++++- invokeai/backend/model_manager/probe.py | 2 ++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/invokeai/backend/model_management/util.py b/invokeai/backend/model_management/util.py index 6d70107c93..441467a4c7 100644 --- a/invokeai/backend/model_management/util.py +++ b/invokeai/backend/model_management/util.py @@ -9,7 +9,7 @@ def lora_token_vector_length(checkpoint: dict) -> int: :param checkpoint: The checkpoint """ - def _get_shape_1(key, tensor, checkpoint): + def _get_shape_1(key: str, tensor, checkpoint) -> int: lora_token_vector_length = None if "." not in key: @@ -57,6 +57,10 @@ def lora_token_vector_length(checkpoint: dict) -> int: for key, tensor in checkpoint.items(): if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key): lora_token_vector_length = _get_shape_1(key, tensor, checkpoint) + elif key.startswith("lora_unet_") and ( + "time_emb_proj.lora_down" in key + ): # recognizes format at https://civitai.com/models/224641 work + lora_token_vector_length = _get_shape_1(key, tensor, checkpoint) elif key.startswith("lora_te") and "_self_attn_" in key: tmp_length = _get_shape_1(key, tensor, checkpoint) if key.startswith("lora_te_"): diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 3ba62fa6ff..25120e2e33 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -400,6 +400,8 @@ class LoRACheckpointProbe(CheckpointProbeBase): return BaseModelType.StableDiffusion1 elif token_vector_length == 1024: return BaseModelType.StableDiffusion2 + elif token_vector_length == 1280: + return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641 work elif token_vector_length == 2048: return BaseModelType.StableDiffusionXL else: From 212dbaf9a2b3b3cf6d5bbf7204feeafacb40c56f Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 14 Dec 2023 23:04:13 -0500 Subject: [PATCH 2/2] fix comment --- invokeai/backend/model_management/util.py | 2 +- invokeai/backend/model_manager/probe.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/model_management/util.py b/invokeai/backend/model_management/util.py index 441467a4c7..f4737d9f0b 100644 --- a/invokeai/backend/model_management/util.py +++ b/invokeai/backend/model_management/util.py @@ -59,7 +59,7 @@ def lora_token_vector_length(checkpoint: dict) -> int: lora_token_vector_length = _get_shape_1(key, tensor, checkpoint) elif key.startswith("lora_unet_") and ( "time_emb_proj.lora_down" in key - ): # recognizes format at https://civitai.com/models/224641 work + ): # recognizes format at https://civitai.com/models/224641 lora_token_vector_length = _get_shape_1(key, tensor, checkpoint) elif key.startswith("lora_te") and "_self_attn_" in key: tmp_length = _get_shape_1(key, tensor, checkpoint) diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 25120e2e33..64eefb774e 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -401,7 +401,7 @@ class LoRACheckpointProbe(CheckpointProbeBase): elif token_vector_length == 1024: return BaseModelType.StableDiffusion2 elif token_vector_length == 1280: - return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641 work + return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641 elif token_vector_length == 2048: return BaseModelType.StableDiffusionXL else: