Allow Embedding and Lora Folder Models to have Different Names

Brandon Rising 2024-03-21 11:11:21 -04:00 committed by Brandon
parent 3aad6f975b
commit bbecb99eb4


@@ -1,4 +1,5 @@
 import json
+import os
 import re
 from pathlib import Path
 from typing import Any, Dict, Literal, Optional, Union
@@ -648,12 +649,13 @@ class TextualInversionFolderProbe(FolderProbeBase):
         return ModelFormat.EmbeddingFolder
 
     def get_base_type(self) -> BaseModelType:
-        path = self.model_path / "learned_embeds.bin"
-        if not path.exists():
+        files = os.scandir(self.model_path)
+        files = [Path(f.path) for f in files if f.is_file() and f.name.endswith((".ckpt", ".pt", ".pth", ".bin", ".safetensors"))]
+        if len(files) != 1:
             raise InvalidModelConfigException(
-                f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file"
+                f"Unable to determine base type for {self.model_path}: expected exactly one valid model file, found {[f.name for f in files]}."
             )
-        return TextualInversionCheckpointProbe(path).get_base_type()
+        return TextualInversionCheckpointProbe(files.pop()).get_base_type()
 
 
 class ONNXFolderProbe(PipelineFolderProbe):
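The effect of this hunk: an embedding folder no longer has to contain a file named exactly `learned_embeds.bin`. The probe now scans the folder, accepts a single checkpoint-style file under any name, and raises if it finds zero or more than one candidate. Below is a minimal sketch of that scan-and-validate step in isolation; the helper name `find_single_weights_file` and the plain `ValueError` are invented for illustration and are not part of the codebase.

```python
from pathlib import Path


def find_single_weights_file(folder: Path, suffixes: tuple[str, ...]) -> Path:
    """Illustrative helper mirroring the probe's new discovery logic.

    Returns the one weights file in `folder`, raising if zero or several match.
    """
    candidates = [
        p for p in folder.iterdir()
        if p.is_file() and p.name.endswith(suffixes)
    ]
    if len(candidates) != 1:
        raise ValueError(
            f"Expected exactly one weights file in {folder}, found {[p.name for p in candidates]}"
        )
    return candidates[0]


# An embedding saved under an arbitrary name is now picked up, e.g.:
# find_single_weights_file(Path("embeddings/my-style"), (".ckpt", ".pt", ".pth", ".bin", ".safetensors"))
```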
@@ -702,11 +704,13 @@ class ControlNetFolderProbe(FolderProbeBase):
 class LoRAFolderProbe(FolderProbeBase):
     def get_base_type(self) -> BaseModelType:
-        model_file = None
-        for suffix in ["safetensors", "bin"]:
-            base_file = self.model_path / f"pytorch_lora_weights.{suffix}"
-            if base_file.exists():
-                model_file = base_file
-                break
+        files = os.scandir(self.model_path)
+        files = [Path(f.path) for f in files if f.is_file() and f.name.endswith((".bin", ".safetensors"))]
+        if len(files) != 1:
+            raise InvalidModelConfigException(
+                f"Unable to determine base type for {self.model_path}: expected exactly one valid model file, found {[f.name for f in files]}."
+            )
+        model_file = files.pop()
         if not model_file:
             raise InvalidModelConfigException("Unknown LoRA format encountered")
         return LoRACheckpointProbe(model_file).get_base_type()
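The LoRA probe gets the same treatment: instead of looking only for `pytorch_lora_weights.safetensors` or `pytorch_lora_weights.bin`, it now accepts exactly one `.bin` or `.safetensors` file of any name and errors out when the count is zero or greater than one. A self-contained sketch of that contract, with invented file names and a plain `ValueError` standing in for `InvalidModelConfigException`:

```python
import tempfile
from pathlib import Path


def scan_lora_folder(folder: Path) -> Path:
    # Illustrative stand-in for LoRAFolderProbe's new discovery step.
    files = [p for p in folder.iterdir() if p.is_file() and p.name.endswith((".bin", ".safetensors"))]
    if len(files) != 1:
        raise ValueError(f"expected exactly one LoRA weights file, found {[p.name for p in files]}")
    return files[0]


with tempfile.TemporaryDirectory() as tmp:
    folder = Path(tmp)
    (folder / "my_renamed_lora.safetensors").touch()   # arbitrary name, previously rejected
    print(scan_lora_folder(folder).name)               # -> my_renamed_lora.safetensors

    (folder / "extra_copy.bin").touch()                # a second weights file makes the folder ambiguous
    try:
        scan_lora_folder(folder)
    except ValueError as err:
        print(err)
```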