diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py
index d06ab5c023..e539f5c45a 100644
--- a/invokeai/backend/model_manager/probe.py
+++ b/invokeai/backend/model_manager/probe.py
@@ -1,4 +1,5 @@
 import json
+import os
 import re
 from pathlib import Path
 from typing import Any, Dict, Literal, Optional, Union
@@ -648,12 +649,13 @@ class TextualInversionFolderProbe(FolderProbeBase):
         return ModelFormat.EmbeddingFolder
 
     def get_base_type(self) -> BaseModelType:
-        path = self.model_path / "learned_embeds.bin"
-        if not path.exists():
+        files = os.scandir(self.model_path)
+        files = [Path(f.path) for f in files if f.is_file() and f.name.endswith((".ckpt", ".pt", ".pth", ".bin", ".safetensors"))]
+        if len(files) != 1:
             raise InvalidModelConfigException(
-                f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file"
+                f"Unable to determine base type for {self.model_path}: expected exactly one valid model file, found {[f.name for f in files]}."
             )
-        return TextualInversionCheckpointProbe(path).get_base_type()
+        return TextualInversionCheckpointProbe(files.pop()).get_base_type()
 
 
 class ONNXFolderProbe(PipelineFolderProbe):
@@ -702,11 +704,13 @@ class ControlNetFolderProbe(FolderProbeBase):
 class LoRAFolderProbe(FolderProbeBase):
     def get_base_type(self) -> BaseModelType:
         model_file = None
-        for suffix in ["safetensors", "bin"]:
-            base_file = self.model_path / f"pytorch_lora_weights.{suffix}"
-            if base_file.exists():
-                model_file = base_file
-                break
+        files = os.scandir(self.model_path)
+        files = [Path(f.path) for f in files if f.is_file() and f.name.endswith((".bin", ".safetensors"))]
+        if len(files) != 1:
+            raise InvalidModelConfigException(
+                f"Unable to determine base type for {self.model_path}: expected exactly one valid model file, found {[f.name for f in files]}."
+            )
+        model_file = files.pop()
         if not model_file:
             raise InvalidModelConfigException("Unknown LoRA format encountered")
         return LoRACheckpointProbe(model_file).get_base_type()