diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 2f18f1a8a6..ce68b1e902 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -221,17 +221,17 @@ class ModelProbe(object): ckpt = ckpt.get("state_dict", ckpt) for key in [str(k) for k in ckpt.keys()]: - if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}): + if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")): return ModelType.Main - elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}): + elif key.startswith(("encoder.conv_in", "decoder.conv_in")): return ModelType.VAE - elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}): + elif key.startswith(("lora_te_", "lora_unet_")): return ModelType.LoRA - elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}): + elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight")): return ModelType.LoRA - elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}): + elif key.startswith(("controlnet", "control_model", "input_blocks")): return ModelType.ControlNet - elif any(key.startswith(v) for v in {"image_proj.", "ip_adapter."}): + elif key.startswith(("image_proj.", "ip_adapter.")): return ModelType.IPAdapter elif key in {"emb_params", "string_to_param"}: return ModelType.TextualInversion