Controlnet model detection

This commit is contained in:
Sergey Borisov 2023-07-08 14:26:25 +03:00 committed by Kent Keirsey
parent a328986b43
commit 67c8cf4bc2

View File

@ -13,6 +13,7 @@ from .base import (
calc_model_size_by_fs,
calc_model_size_by_data,
classproperty,
InvalidModelException,
)
class ControlNetModelFormat(str, Enum):
@ -73,11 +74,19 @@ class ControlNetModel(ModelBase):
@classmethod
def detect_format(cls, path: str):
if not os.path.exists(path):
raise ModelNotFoundException()
if os.path.isdir(path):
if os.path.exists(os.path.join(path, "config.json")):
return ControlNetModelFormat.Diffusers
else:
if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]]):
return ControlNetModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {path}")
@classmethod
def convert_if_required(
cls,