From 46dc751139581dcb9fcf1adde2313672267d07c5 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 20 Jun 2023 03:30:09 +0300 Subject: [PATCH] Update model format field to use enums --- .../backend/model_management/models/base.py | 34 ++++++++----------- .../model_management/models/controlnet.py | 13 ++++--- .../backend/model_management/models/lora.py | 13 ++++--- .../models/stable_diffusion.py | 31 ++++++++++------- .../backend/model_management/models/vae.py | 13 ++++--- 5 files changed, 60 insertions(+), 44 deletions(-) diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py index 06cc3db50f..ef354ecc07 100644 --- a/invokeai/backend/model_management/models/base.py +++ b/invokeai/backend/model_management/models/base.py @@ -125,30 +125,24 @@ class ModelBase(metaclass=ABCMeta): continue fields = inspect.get_annotations(value) - if "model_format" not in fields: - raise Exception("Invalid config definition - model_format field not found") + try: + field = fields["model_format"] + except: + raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})") - format_type = typing.get_origin(fields["model_format"]) - if format_type not in {None, Literal, Union}: - raise Exception(f"Invalid config definition - unknown format type: {fields['model_format']}") + if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum): + for model_format in field: + configs[model_format.value] = value - if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["model_format"].__args__): - raise Exception(f"Invalid config definition - unknown format type: {fields['model_format']}") + elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__): + for model_format in field.__args__: + configs[model_format.value] = value + + elif field is None: + configs[None] = value - if format_type == Union: - f_fields = fields["model_format"].__args__ else: - f_fields = (fields["model_format"],) - - - for field in f_fields: - if field is None: - format_name = None - else: - format_name = field.__args__[0] - - configs[format_name] = value # TODO: error when override(multiple)? - + raise Exception(f"Unsupported format definition in {cls.__qualname__}") cls.__configs = configs return cls.__configs diff --git a/invokeai/backend/model_management/models/controlnet.py b/invokeai/backend/model_management/models/controlnet.py index e452621eba..9563f87afd 100644 --- a/invokeai/backend/model_management/models/controlnet.py +++ b/invokeai/backend/model_management/models/controlnet.py @@ -1,5 +1,6 @@ import os import torch +from enum import Enum from pathlib import Path from typing import Optional, Union, Literal from .base import ( @@ -14,12 +15,16 @@ from .base import ( classproperty, ) +class ControlNetModelFormat(str, Enum): + Checkpoint = "checkpoint" + Diffusers = "diffusers" + class ControlNetModel(ModelBase): #model_class: Type #model_size: int class Config(ModelConfigBase): - model_format: Union[Literal["checkpoint"], Literal["diffusers"]] + model_format: ControlNetModelFormat def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert model_type == ModelType.ControlNet @@ -69,9 +74,9 @@ class ControlNetModel(ModelBase): @classmethod def detect_format(cls, path: str): if os.path.isdir(path): - return "diffusers" + return ControlNetModelFormat.Diffusers else: - return "checkpoint" + return ControlNetModelFormat.Checkpoint @classmethod def convert_if_required( @@ -81,7 +86,7 @@ class ControlNetModel(ModelBase): config: ModelConfigBase, # empty config or config of parent model base_model: BaseModelType, ) -> str: - if cls.detect_format(model_path) != "diffusers": + if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers: raise NotImplementedError("Checkpoint controlnet models currently unsupported") else: return model_path diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py index 2e4309a161..59feacde06 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/model_management/models/lora.py @@ -1,5 +1,6 @@ import os import torch +from enum import Enum from typing import Optional, Union, Literal from .base import ( ModelBase, @@ -12,11 +13,15 @@ from .base import ( # TODO: naming from ..lora import LoRAModel as LoRAModelRaw +class LoRAModelFormat(str, Enum): + LyCORIS = "lycoris" + Diffusers = "diffusers" + class LoRAModel(ModelBase): #model_size: int class Config(ModelConfigBase): - model_format: Union[Literal["lycoris"], Literal["diffusers"]] + model_format: LoRAModelFormat # TODO: def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert model_type == ModelType.Lora @@ -52,9 +57,9 @@ class LoRAModel(ModelBase): @classmethod def detect_format(cls, path: str): if os.path.isdir(path): - return "diffusers" + return LoRAModelFormat.Diffusers else: - return "lycoris" + return LoRAModelFormat.LyCORIS @classmethod def convert_if_required( @@ -64,7 +69,7 @@ class LoRAModel(ModelBase): config: ModelConfigBase, base_model: BaseModelType, ) -> str: - if cls.detect_format(model_path) == "diffusers": + if cls.detect_format(model_path) == LoRAModelFormat.Diffusers: # TODO: add diffusers lora when it stabilizes a bit raise NotImplementedError("Diffusers lora not supported") else: diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py index 50089b3338..f169326571 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_management/models/stable_diffusion.py @@ -1,5 +1,6 @@ import os import json +from enum import Enum from pydantic import Field from pathlib import Path from typing import Literal, Optional, Union @@ -19,16 +20,19 @@ from .base import ( from invokeai.app.services.config import InvokeAIAppConfig from omegaconf import OmegaConf +class StableDiffusion1ModelFormat(str, Enum): + Checkpoint = "checkpoint" + Diffusers = "diffusers" class StableDiffusion1Model(DiffusersModel): class DiffusersConfig(ModelConfigBase): - model_format: Literal["diffusers"] + model_format: Literal[StableDiffusion1ModelFormat.Diffusers] vae: Optional[str] = Field(None) variant: ModelVariantType class CheckpointConfig(ModelConfigBase): - model_format: Literal["checkpoint"] + model_format: Literal[StableDiffusion1ModelFormat.Checkpoint] vae: Optional[str] = Field(None) config: Optional[str] = Field(None) variant: ModelVariantType @@ -47,7 +51,7 @@ class StableDiffusion1Model(DiffusersModel): def probe_config(cls, path: str, **kwargs): model_format = cls.detect_format(path) ckpt_config_path = kwargs.get("config", None) - if model_format == "checkpoint": + if model_format == StableDiffusion1ModelFormat.Checkpoint: if ckpt_config_path: ckpt_config = OmegaConf.load(ckpt_config_path) ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] @@ -57,7 +61,7 @@ class StableDiffusion1Model(DiffusersModel): checkpoint = checkpoint.get('state_dict', checkpoint) in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - elif model_format == "diffusers": + elif model_format == StableDiffusion1ModelFormat.Diffusers: unet_config_path = os.path.join(path, "unet", "config.json") if os.path.exists(unet_config_path): with open(unet_config_path, "r") as f: @@ -93,9 +97,9 @@ class StableDiffusion1Model(DiffusersModel): @classmethod def detect_format(cls, model_path: str): if os.path.isdir(model_path): - return "diffusers" + return StableDiffusion1ModelFormat.Diffusers else: - return "checkpoint" + return StableDiffusion1ModelFormat.Checkpoint @classmethod def convert_if_required( @@ -116,19 +120,22 @@ class StableDiffusion1Model(DiffusersModel): else: return model_path +class StableDiffusion2ModelFormat(str, Enum): + Checkpoint = "checkpoint" + Diffusers = "diffusers" class StableDiffusion2Model(DiffusersModel): # TODO: check that configs overwriten properly class DiffusersConfig(ModelConfigBase): - model_format: Literal["diffusers"] + model_format: Literal[StableDiffusion2ModelFormat.Diffusers] vae: Optional[str] = Field(None) variant: ModelVariantType prediction_type: SchedulerPredictionType upcast_attention: bool class CheckpointConfig(ModelConfigBase): - model_format: Literal["checkpoint"] + model_format: Literal[StableDiffusion2ModelFormat.Checkpoint] vae: Optional[str] = Field(None) config: Optional[str] = Field(None) variant: ModelVariantType @@ -149,7 +156,7 @@ class StableDiffusion2Model(DiffusersModel): def probe_config(cls, path: str, **kwargs): model_format = cls.detect_format(path) ckpt_config_path = kwargs.get("config", None) - if model_format == "checkpoint": + if model_format == StableDiffusion2ModelFormat.Checkpoint: if ckpt_config_path: ckpt_config = OmegaConf.load(ckpt_config_path) ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] @@ -159,7 +166,7 @@ class StableDiffusion2Model(DiffusersModel): checkpoint = checkpoint.get('state_dict', checkpoint) in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - elif model_format == "diffusers": + elif model_format == StableDiffusion2ModelFormat.Diffusers: unet_config_path = os.path.join(path, "unet", "config.json") if os.path.exists(unet_config_path): with open(unet_config_path, "r") as f: @@ -206,9 +213,9 @@ class StableDiffusion2Model(DiffusersModel): @classmethod def detect_format(cls, model_path: str): if os.path.isdir(model_path): - return "diffusers" + return StableDiffusion2ModelFormat.Diffusers else: - return "checkpoint" + return StableDiffusion2ModelFormat.Checkpoint @classmethod def convert_if_required( diff --git a/invokeai/backend/model_management/models/vae.py b/invokeai/backend/model_management/models/vae.py index e86bc00ecd..76133b074d 100644 --- a/invokeai/backend/model_management/models/vae.py +++ b/invokeai/backend/model_management/models/vae.py @@ -1,6 +1,7 @@ import os import torch import safetensors +from enum import Enum from pathlib import Path from typing import Optional, Union, Literal from .base import ( @@ -19,12 +20,16 @@ from invokeai.app.services.config import InvokeAIAppConfig from diffusers.utils import is_safetensors_available from omegaconf import OmegaConf +class VaeModelFormat(str, Enum): + Checkpoint = "checkpoint" + Diffusers = "diffusers" + class VaeModel(ModelBase): #vae_class: Type #model_size: int class Config(ModelConfigBase): - model_format: Union[Literal["checkpoint"], Literal["diffusers"]] + model_format: VaeModelFormat def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert model_type == ModelType.Vae @@ -71,9 +76,9 @@ class VaeModel(ModelBase): @classmethod def detect_format(cls, path: str): if os.path.isdir(path): - return "diffusers" + return VaeModelFormat.Diffusers else: - return "checkpoint" + return VaeModelFormat.Checkpoint @classmethod def convert_if_required( @@ -83,7 +88,7 @@ class VaeModel(ModelBase): config: ModelConfigBase, # empty config or config of parent model base_model: BaseModelType, ) -> str: - if cls.detect_format(model_path) != "diffusers": + if cls.detect_format(model_path) == VaeModelFormat.Checkpoint: return _convert_vae_ckpt_and_cache( weights_path=model_path, output_path=output_path,