mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[feat] Make model prober recognize yet another LoRA format (#5296)
## What type of PR is this? (check all applicable) - [ ] Refactor - [X] Feature - [ ] Bug Fix - [ ] Optimization - [ ] Documentation Update - [ ] Community Node Submission ## Have you discussed this change with the InvokeAI team? - [X] Yes - [ ] No, because: ## Have you updated all relevant documentation? - [X] Yes - [ ] No ## Description This adds a probe for the SDXL LoRA format found in the wild at https://civitai.com/models/224641. ## Related Tickets & Documents <!-- For pull requests that relate or close an issue, please include them below. For example having the text: "closes #1234" would connect the current pull request to issue 1234. And when we merge the pull request, Github will automatically close the issue. --> See discord message at: https://discord.com/channels/1020123559063990373/1149510134058471514/1184982133912113182 ## QA Instructions, Screenshots, Recordings Try installing the SDXL LoRA at the URL given above. ## Merge Plan This can be merged when approved. ## Added/updated tests? - [ ] Yes - [X] No : we do not yet have a comprehensive suite of models to test probing on. ## [optional] Are there any post deployment tasks we need to perform?
This commit is contained in:
commit
fc150acde5
@ -9,7 +9,7 @@ def lora_token_vector_length(checkpoint: dict) -> int:
|
|||||||
:param checkpoint: The checkpoint
|
:param checkpoint: The checkpoint
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_shape_1(key, tensor, checkpoint):
|
def _get_shape_1(key: str, tensor, checkpoint) -> int:
|
||||||
lora_token_vector_length = None
|
lora_token_vector_length = None
|
||||||
|
|
||||||
if "." not in key:
|
if "." not in key:
|
||||||
@ -57,6 +57,10 @@ def lora_token_vector_length(checkpoint: dict) -> int:
|
|||||||
for key, tensor in checkpoint.items():
|
for key, tensor in checkpoint.items():
|
||||||
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
|
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
|
||||||
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
|
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
|
||||||
|
elif key.startswith("lora_unet_") and (
|
||||||
|
"time_emb_proj.lora_down" in key
|
||||||
|
): # recognizes format at https://civitai.com/models/224641
|
||||||
|
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
|
||||||
elif key.startswith("lora_te") and "_self_attn_" in key:
|
elif key.startswith("lora_te") and "_self_attn_" in key:
|
||||||
tmp_length = _get_shape_1(key, tensor, checkpoint)
|
tmp_length = _get_shape_1(key, tensor, checkpoint)
|
||||||
if key.startswith("lora_te_"):
|
if key.startswith("lora_te_"):
|
||||||
|
@ -400,6 +400,8 @@ class LoRACheckpointProbe(CheckpointProbeBase):
|
|||||||
return BaseModelType.StableDiffusion1
|
return BaseModelType.StableDiffusion1
|
||||||
elif token_vector_length == 1024:
|
elif token_vector_length == 1024:
|
||||||
return BaseModelType.StableDiffusion2
|
return BaseModelType.StableDiffusion2
|
||||||
|
elif token_vector_length == 1280:
|
||||||
|
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
|
||||||
elif token_vector_length == 2048:
|
elif token_vector_length == 2048:
|
||||||
return BaseModelType.StableDiffusionXL
|
return BaseModelType.StableDiffusionXL
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user