diff --git a/invokeai/backend/model_management/util.py b/invokeai/backend/model_management/util.py index 6d70107c93..f4737d9f0b 100644 --- a/invokeai/backend/model_management/util.py +++ b/invokeai/backend/model_management/util.py @@ -9,7 +9,7 @@ def lora_token_vector_length(checkpoint: dict) -> int: :param checkpoint: The checkpoint """ - def _get_shape_1(key, tensor, checkpoint): + def _get_shape_1(key: str, tensor, checkpoint) -> int: lora_token_vector_length = None if "." not in key: @@ -57,6 +57,10 @@ def lora_token_vector_length(checkpoint: dict) -> int: for key, tensor in checkpoint.items(): if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key): lora_token_vector_length = _get_shape_1(key, tensor, checkpoint) + elif key.startswith("lora_unet_") and ( + "time_emb_proj.lora_down" in key + ): # recognizes format at https://civitai.com/models/224641 + lora_token_vector_length = _get_shape_1(key, tensor, checkpoint) elif key.startswith("lora_te") and "_self_attn_" in key: tmp_length = _get_shape_1(key, tensor, checkpoint) if key.startswith("lora_te_"): diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 3ba62fa6ff..64eefb774e 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -400,6 +400,8 @@ class LoRACheckpointProbe(CheckpointProbeBase): return BaseModelType.StableDiffusion1 elif token_vector_length == 1024: return BaseModelType.StableDiffusion2 + elif token_vector_length == 1280: + return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641 elif token_vector_length == 2048: return BaseModelType.StableDiffusionXL else: