mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[feature] Add probe for SDXL controlnet models (#5382)
* add probe for SDXL controlnet models * Update invokeai/backend/model_management/model_probe.py Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> * Update invokeai/backend/model_manager/probe.py Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> --------- Co-authored-by: Lincoln Stein <lstein@gmail.com> Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
This commit is contained in:
parent
a5771f6120
commit
d4d0fea078
@ -227,8 +227,8 @@ class ModelProbe(object):
|
|||||||
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
||||||
return ModelType.LoRA
|
return ModelType.LoRA
|
||||||
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
||||||
return ModelType.LoRA
|
return ModelType.Lora
|
||||||
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}):
|
||||||
return ModelType.ControlNet
|
return ModelType.ControlNet
|
||||||
elif key in {"emb_params", "string_to_param"}:
|
elif key in {"emb_params", "string_to_param"}:
|
||||||
return ModelType.TextualInversion
|
return ModelType.TextualInversion
|
||||||
@ -508,15 +508,22 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
|
|||||||
checkpoint = self.checkpoint
|
checkpoint = self.checkpoint
|
||||||
for key_name in (
|
for key_name in (
|
||||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
||||||
|
"controlnet_mid_block.bias",
|
||||||
"input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
"input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
||||||
|
"down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
|
||||||
):
|
):
|
||||||
if key_name not in checkpoint:
|
if key_name not in checkpoint:
|
||||||
continue
|
continue
|
||||||
if checkpoint[key_name].shape[-1] == 768:
|
width = checkpoint[key_name].shape[-1]
|
||||||
|
if width == 768:
|
||||||
return BaseModelType.StableDiffusion1
|
return BaseModelType.StableDiffusion1
|
||||||
elif checkpoint[key_name].shape[-1] == 1024:
|
elif width == 1024:
|
||||||
return BaseModelType.StableDiffusion2
|
return BaseModelType.StableDiffusion2
|
||||||
raise InvalidModelConfigException("{self.model_path}: Unable to determine base type")
|
elif width == 2048:
|
||||||
|
return BaseModelType.StableDiffusionXL
|
||||||
|
elif width == 1280:
|
||||||
|
return BaseModelType.StableDiffusionXL
|
||||||
|
raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type")
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterCheckpointProbe(CheckpointProbeBase):
|
class IPAdapterCheckpointProbe(CheckpointProbeBase):
|
||||||
|
Loading…
Reference in New Issue
Block a user