mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[feature] Add probe for SDXL controlnet models (#5382)
* add probe for SDXL controlnet models * Update invokeai/backend/model_management/model_probe.py Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> * Update invokeai/backend/model_manager/probe.py Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> --------- Co-authored-by: Lincoln Stein <lstein@gmail.com> Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
This commit is contained in:
parent
a5771f6120
commit
d4d0fea078
@ -227,8 +227,8 @@ class ModelProbe(object):
|
||||
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
||||
return ModelType.LoRA
|
||||
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
||||
return ModelType.LoRA
|
||||
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
||||
return ModelType.Lora
|
||||
elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}):
|
||||
return ModelType.ControlNet
|
||||
elif key in {"emb_params", "string_to_param"}:
|
||||
return ModelType.TextualInversion
|
||||
@ -508,15 +508,22 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
|
||||
checkpoint = self.checkpoint
|
||||
for key_name in (
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
||||
"controlnet_mid_block.bias",
|
||||
"input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
||||
"down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
|
||||
):
|
||||
if key_name not in checkpoint:
|
||||
continue
|
||||
if checkpoint[key_name].shape[-1] == 768:
|
||||
width = checkpoint[key_name].shape[-1]
|
||||
if width == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif checkpoint[key_name].shape[-1] == 1024:
|
||||
elif width == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
raise InvalidModelConfigException("{self.model_path}: Unable to determine base type")
|
||||
elif width == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif width == 1280:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type")
|
||||
|
||||
|
||||
class IPAdapterCheckpointProbe(CheckpointProbeBase):
|
||||
|
Loading…
Reference in New Issue
Block a user