Allow Embedding and LoRA Folder Models to Have Different Names

This commit is contained in:
Brandon Rising 2024-03-21 11:11:21 -04:00 committed by Brandon
parent 3aad6f975b
commit bbecb99eb4

View File

@ -1,4 +1,5 @@
import json import json
import os
import re import re
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union from typing import Any, Dict, Literal, Optional, Union
@ -648,12 +649,13 @@ class TextualInversionFolderProbe(FolderProbeBase):
return ModelFormat.EmbeddingFolder return ModelFormat.EmbeddingFolder
def get_base_type(self) -> BaseModelType:
    """Determine the base model type of a folder-style textual-inversion embedding.

    The embedding file is no longer required to be named ``learned_embeds.bin``:
    any single file with a recognized checkpoint extension is accepted and
    probed for its base type.

    Returns:
        BaseModelType: the base type reported by probing the weights file.

    Raises:
        InvalidModelConfigException: if the folder does not contain exactly one
            file with a recognized model extension.
    """
    # Path.iterdir() instead of os.scandir(): scandir returns an iterator that
    # should be explicitly closed; iterdir avoids the leaked handle entirely.
    weight_files = [
        f
        for f in self.model_path.iterdir()
        if f.is_file() and f.name.endswith((".ckpt", ".pt", ".pth", ".bin", ".safetensors"))
    ]
    # Exactly one candidate is required; zero or several is ambiguous.
    if len(weight_files) != 1:
        raise InvalidModelConfigException(
            f"Unable to determine base type for {self.model_path}: expected exactly one valid model file, found {[f.name for f in weight_files]}."
        )
    return TextualInversionCheckpointProbe(weight_files[0]).get_base_type()
class ONNXFolderProbe(PipelineFolderProbe): class ONNXFolderProbe(PipelineFolderProbe):
@ -702,11 +704,13 @@ class ControlNetFolderProbe(FolderProbeBase):
class LoRAFolderProbe(FolderProbeBase):
    """Probe for a LoRA stored as a folder containing a single weights file."""

    def get_base_type(self) -> BaseModelType:
        """Determine the base model type of a folder-style LoRA.

        The weights file is no longer required to be named
        ``pytorch_lora_weights.*``: any single ``.bin`` or ``.safetensors``
        file in the folder is accepted and probed.

        Returns:
            BaseModelType: the base type reported by probing the weights file.

        Raises:
            InvalidModelConfigException: if the folder does not contain exactly
                one ``.bin`` / ``.safetensors`` file.
        """
        # Path.iterdir() instead of os.scandir(): scandir returns an iterator
        # that should be explicitly closed; iterdir avoids the leaked handle.
        weight_files = [
            f
            for f in self.model_path.iterdir()
            if f.is_file() and f.name.endswith((".bin", ".safetensors"))
        ]
        # Exactly one candidate is required; zero or several is ambiguous.
        if len(weight_files) != 1:
            raise InvalidModelConfigException(
                f"Unable to determine base type for {self.model_path}: expected exactly one valid model file, found {[f.name for f in weight_files]}."
            )
        # NOTE: the old `if not model_file` fallback ("Unknown LoRA format
        # encountered") was unreachable — the guard above guarantees exactly
        # one Path, which is always truthy — so it has been removed.
        return LoRACheckpointProbe(weight_files[0]).get_base_type()