mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
For loras/embeddings treat folders the same as single files
This commit is contained in:
parent
bbecb99eb4
commit
18098cc4b9
@ -644,18 +644,15 @@ class VaeFolderProbe(FolderProbeBase):
|
|||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionFolderProbe(FolderProbeBase):
|
class TextualInversionFolderProbe(TextualInversionCheckpointProbe):
|
||||||
def get_format(self) -> ModelFormat:
|
def __init__(self, model_path: Path):
|
||||||
return ModelFormat.EmbeddingFolder
|
files = os.scandir(model_path)
|
||||||
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
|
||||||
files = os.scandir(self.model_path)
|
|
||||||
files = [Path(f.path) for f in files if f.is_file() and f.name.endswith((".ckpt", ".pt", ".pth", ".bin", ".safetensors"))]
|
files = [Path(f.path) for f in files if f.is_file() and f.name.endswith((".ckpt", ".pt", ".pth", ".bin", ".safetensors"))]
|
||||||
if len(files) != 1:
|
if len(files) != 1:
|
||||||
raise InvalidModelConfigException(
|
raise InvalidModelConfigException(
|
||||||
f"Unable to determine base type for {self.model_path}: expected exactly one valid model file, found {[f.name for f in files]}."
|
f"Unable to determine base type for {model_path}: expected exactly one valid model file, found {[f.name for f in files]}."
|
||||||
)
|
)
|
||||||
return TextualInversionCheckpointProbe(files.pop()).get_base_type()
|
super().__init__(model_path)
|
||||||
|
|
||||||
|
|
||||||
class ONNXFolderProbe(PipelineFolderProbe):
|
class ONNXFolderProbe(PipelineFolderProbe):
|
||||||
@ -701,19 +698,16 @@ class ControlNetFolderProbe(FolderProbeBase):
|
|||||||
return base_model
|
return base_model
|
||||||
|
|
||||||
|
|
||||||
class LoRAFolderProbe(FolderProbeBase):
|
class LoRAFolderProbe(LoRACheckpointProbe):
|
||||||
def get_base_type(self) -> BaseModelType:
|
def __init__(self, model_path: Path):
|
||||||
model_file = None
|
files = os.scandir(model_path)
|
||||||
files = os.scandir(self.model_path)
|
|
||||||
files = [Path(f.path) for f in files if f.is_file() and f.name.endswith((".bin", ".safetensors"))]
|
files = [Path(f.path) for f in files if f.is_file() and f.name.endswith((".bin", ".safetensors"))]
|
||||||
if len(files) != 1:
|
if len(files) != 1:
|
||||||
raise InvalidModelConfigException(
|
raise InvalidModelConfigException(
|
||||||
f"Unable to determine base type for {self.model_path}: expected exactly one valid model file, found {[f.name for f in files]}."
|
f"Unable to determine base type for lora {model_path}: expected exactly one valid model file, found {[f.name for f in files]}."
|
||||||
)
|
)
|
||||||
model_file = files.pop()
|
model_file = files.pop()
|
||||||
if not model_file:
|
super().__init__(model_file)
|
||||||
raise InvalidModelConfigException("Unknown LoRA format encountered")
|
|
||||||
return LoRACheckpointProbe(model_file).get_base_type()
|
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterFolderProbe(FolderProbeBase):
|
class IPAdapterFolderProbe(FolderProbeBase):
|
||||||
|
Loading…
Reference in New Issue
Block a user