[feat] Make model prober recognize yet another LoRA format (#5296)

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [X] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [X] Yes
- [ ] No, because:

## Have you updated all relevant documentation?
- [X] Yes
- [ ] No


## Description

This adds a probe for the SDXL LoRA format found in the wild at
https://civitai.com/models/224641.
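This variant stores its UNet weights under keys like `lora_unet_..._time_emb_proj.lora_down...`, which the existing prober did not recognize. A minimal sketch of the detection heuristic, grounded in the diff below (the helper name is hypothetical, not part of InvokeAI's API):

```python
# Minimal sketch of the new heuristic, assuming the LoRA checkpoint has
# already been loaded into a flat {key: tensor} state dict. The helper
# name `has_time_emb_proj_keys` is hypothetical, for illustration only.
def has_time_emb_proj_keys(checkpoint: dict) -> bool:
    """True if the state dict matches the format at https://civitai.com/models/224641."""
    return any(
        key.startswith("lora_unet_") and "time_emb_proj.lora_down" in key
        for key in checkpoint
    )
```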

## Related Tickets & Documents

<!--
For pull requests that relate to or close an issue, please include it
below.

For example, having the text "closes #1234" would connect the current
pull request to issue 1234, and when we merge the pull request, GitHub
will automatically close the issue.
-->

See the Discord message at:
https://discord.com/channels/1020123559063990373/1149510134058471514/1184982133912113182

## QA Instructions, Screenshots, Recordings

Try installing the SDXL LoRA at the URL given above.

## Merge Plan

This can be merged when approved.

## Added/updated tests?

- [ ] Yes
- [X] No: we do not yet have a comprehensive suite of models to test probing on.

## [optional] Are there any post deployment tasks we need to perform?

Commit fc150acde5 by Lincoln Stein, 2023-12-15 09:49:51 -05:00 (committed via GitHub). 2 changed files with 7 additions and 1 deletion.

```diff
@@ -9,7 +9,7 @@ def lora_token_vector_length(checkpoint: dict) -> int:
     :param checkpoint: The checkpoint
     """
 
-    def _get_shape_1(key, tensor, checkpoint):
+    def _get_shape_1(key: str, tensor, checkpoint) -> int:
         lora_token_vector_length = None
 
         if "." not in key:
@@ -57,6 +57,10 @@ def lora_token_vector_length(checkpoint: dict) -> int:
     for key, tensor in checkpoint.items():
         if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
             lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
+        elif key.startswith("lora_unet_") and (
+            "time_emb_proj.lora_down" in key
+        ):  # recognizes format at https://civitai.com/models/224641
+            lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
         elif key.startswith("lora_te") and "_self_attn_" in key:
             tmp_length = _get_shape_1(key, tensor, checkpoint)
             if key.startswith("lora_te_"):
```
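A quick way to exercise the changed code path on a downloaded copy of the LoRA; the import path and local file name below are assumptions for illustration:

```python
from safetensors.torch import load_file

# Assumed import path; adjust to wherever lora_token_vector_length lives.
from invokeai.backend.model_management.lora import lora_token_vector_length

# Hypothetical local file name for the LoRA downloaded from civitai.
checkpoint = load_file("sdxl_lora_224641.safetensors")
print(lora_token_vector_length(checkpoint))  # this variant should report 1280
```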


```diff
@@ -400,6 +400,8 @@ class LoRACheckpointProbe(CheckpointProbeBase):
             return BaseModelType.StableDiffusion1
         elif token_vector_length == 1024:
             return BaseModelType.StableDiffusion2
+        elif token_vector_length == 1280:
+            return BaseModelType.StableDiffusionXL  # recognizes format at https://civitai.com/models/224641
         elif token_vector_length == 2048:
             return BaseModelType.StableDiffusionXL
         else:
```
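For reference, the probe now maps the measured vector length to a base model roughly as follows. The dict form is illustrative (the real code is the if/elif chain above), string values stand in for the `BaseModelType` enum members, and for the `time_emb_proj` variant the 1280 comes from the UNet's time-embedding width rather than a text-encoder dimension:

```python
# Illustrative summary of the length-to-base-model mapping.
TOKEN_VECTOR_LENGTH_TO_BASE = {
    768: "sd-1",   # CLIP ViT-L context dim
    1024: "sd-2",  # OpenCLIP ViT-H context dim
    1280: "sdxl",  # UNet time_emb_proj input width in this LoRA variant
    2048: "sdxl",  # SDXL's two text encoders concatenated (768 + 1280)
}
```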