[feature] Add probe for SDXL controlnet models (#5382)

* add probe for SDXL controlnet models

* Update invokeai/backend/model_management/model_probe.py

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* Update invokeai/backend/model_manager/probe.py

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
This commit is contained in:
Lincoln Stein 2024-03-21 10:49:45 -04:00 committed by GitHub
parent a5771f6120
commit d4d0fea078
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -227,8 +227,8 @@ class ModelProbe(object):
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
return ModelType.LoRA
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
return ModelType.LoRA
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
return ModelType.Lora
elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}):
return ModelType.ControlNet
elif key in {"emb_params", "string_to_param"}:
return ModelType.TextualInversion
@ -508,15 +508,22 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
checkpoint = self.checkpoint
for key_name in (
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"controlnet_mid_block.bias",
"input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
):
if key_name not in checkpoint:
continue
if checkpoint[key_name].shape[-1] == 768:
width = checkpoint[key_name].shape[-1]
if width == 768:
return BaseModelType.StableDiffusion1
elif checkpoint[key_name].shape[-1] == 1024:
elif width == 1024:
return BaseModelType.StableDiffusion2
raise InvalidModelConfigException("{self.model_path}: Unable to determine base type")
elif width == 2048:
return BaseModelType.StableDiffusionXL
elif width == 1280:
return BaseModelType.StableDiffusionXL
raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type")
class IPAdapterCheckpointProbe(CheckpointProbeBase):