From 18098cc4b97ebc5e847883ef470b99dccdee78d5 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Thu, 21 Mar 2024 11:24:27 -0400 Subject: [PATCH] For loras/embeddings treat folders the same as single files --- invokeai/backend/model_manager/probe.py | 26 ++++++++++--------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index e539f5c45a..8ebe7c388f 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -644,18 +644,15 @@ class VaeFolderProbe(FolderProbeBase): return name -class TextualInversionFolderProbe(FolderProbeBase): - def get_format(self) -> ModelFormat: - return ModelFormat.EmbeddingFolder - - def get_base_type(self) -> BaseModelType: - files = os.scandir(self.model_path) +class TextualInversionFolderProbe(TextualInversionCheckpointProbe): + def __init__(self, model_path: Path): + files = os.scandir(model_path) files = [Path(f.path) for f in files if f.is_file() and f.name.endswith((".ckpt", ".pt", ".pth", ".bin", ".safetensors"))] if len(files) != 1: raise InvalidModelConfigException( - f"Unable to determine base type for {self.model_path}: expected exactly one valid model file, found {[f.name for f in files]}." + f"Unable to determine base type for {model_path}: expected exactly one valid model file, found {[f.name for f in files]}." ) - return TextualInversionCheckpointProbe(files.pop()).get_base_type() + super().__init__(model_path) class ONNXFolderProbe(PipelineFolderProbe): @@ -701,19 +698,16 @@ class ControlNetFolderProbe(FolderProbeBase): return base_model -class LoRAFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - model_file = None - files = os.scandir(self.model_path) +class LoRAFolderProbe(LoRACheckpointProbe): + def __init__(self, model_path: Path): + files = os.scandir(model_path) files = [Path(f.path) for f in files if f.is_file() and f.name.endswith((".bin", ".safetensors"))] if len(files) != 1: raise InvalidModelConfigException( - f"Unable to determine base type for {self.model_path}: expected exactly one valid model file, found {[f.name for f in files]}." + f"Unable to determine base type for lora {model_path}: expected exactly one valid model file, found {[f.name for f in files]}." ) model_file = files.pop() - if not model_file: - raise InvalidModelConfigException("Unknown LoRA format encountered") - return LoRACheckpointProbe(model_file).get_base_type() + super().__init__(model_file) class IPAdapterFolderProbe(FolderProbeBase):