Update model format field to use enums

This commit is contained in:
Sergey Borisov 2023-06-20 03:30:09 +03:00
parent 4cefe37723
commit 46dc751139
5 changed files with 60 additions and 44 deletions

View File

@ -125,30 +125,24 @@ class ModelBase(metaclass=ABCMeta):
continue continue
fields = inspect.get_annotations(value) fields = inspect.get_annotations(value)
if "model_format" not in fields: try:
raise Exception("Invalid config definition - model_format field not found") field = fields["model_format"]
except:
raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")
format_type = typing.get_origin(fields["model_format"]) if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
if format_type not in {None, Literal, Union}: for model_format in field:
raise Exception(f"Invalid config definition - unknown format type: {fields['model_format']}") configs[model_format.value] = value
if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["model_format"].__args__): elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
raise Exception(f"Invalid config definition - unknown format type: {fields['model_format']}") for model_format in field.__args__:
configs[model_format.value] = value
elif field is None:
configs[None] = value
if format_type == Union:
f_fields = fields["model_format"].__args__
else: else:
f_fields = (fields["model_format"],) raise Exception(f"Unsupported format definition in {cls.__qualname__}")
for field in f_fields:
if field is None:
format_name = None
else:
format_name = field.__args__[0]
configs[format_name] = value # TODO: error when override(multiple)?
cls.__configs = configs cls.__configs = configs
return cls.__configs return cls.__configs

View File

@ -1,5 +1,6 @@
import os import os
import torch import torch
from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Optional, Union, Literal from typing import Optional, Union, Literal
from .base import ( from .base import (
@ -14,12 +15,16 @@ from .base import (
classproperty, classproperty,
) )
class ControlNetModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class ControlNetModel(ModelBase): class ControlNetModel(ModelBase):
#model_class: Type #model_class: Type
#model_size: int #model_size: int
class Config(ModelConfigBase): class Config(ModelConfigBase):
model_format: Union[Literal["checkpoint"], Literal["diffusers"]] model_format: ControlNetModelFormat
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.ControlNet assert model_type == ModelType.ControlNet
@ -69,9 +74,9 @@ class ControlNetModel(ModelBase):
@classmethod @classmethod
def detect_format(cls, path: str): def detect_format(cls, path: str):
if os.path.isdir(path): if os.path.isdir(path):
return "diffusers" return ControlNetModelFormat.Diffusers
else: else:
return "checkpoint" return ControlNetModelFormat.Checkpoint
@classmethod @classmethod
def convert_if_required( def convert_if_required(
@ -81,7 +86,7 @@ class ControlNetModel(ModelBase):
config: ModelConfigBase, # empty config or config of parent model config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType, base_model: BaseModelType,
) -> str: ) -> str:
if cls.detect_format(model_path) != "diffusers": if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers:
raise NotImplementedError("Checkpoint controlnet models currently unsupported") raise NotImplementedError("Checkpoint controlnet models currently unsupported")
else: else:
return model_path return model_path

View File

@ -1,5 +1,6 @@
import os import os
import torch import torch
from enum import Enum
from typing import Optional, Union, Literal from typing import Optional, Union, Literal
from .base import ( from .base import (
ModelBase, ModelBase,
@ -12,11 +13,15 @@ from .base import (
# TODO: naming # TODO: naming
from ..lora import LoRAModel as LoRAModelRaw from ..lora import LoRAModel as LoRAModelRaw
class LoRAModelFormat(str, Enum):
LyCORIS = "lycoris"
Diffusers = "diffusers"
class LoRAModel(ModelBase): class LoRAModel(ModelBase):
#model_size: int #model_size: int
class Config(ModelConfigBase): class Config(ModelConfigBase):
model_format: Union[Literal["lycoris"], Literal["diffusers"]] model_format: LoRAModelFormat # TODO:
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Lora assert model_type == ModelType.Lora
@ -52,9 +57,9 @@ class LoRAModel(ModelBase):
@classmethod @classmethod
def detect_format(cls, path: str): def detect_format(cls, path: str):
if os.path.isdir(path): if os.path.isdir(path):
return "diffusers" return LoRAModelFormat.Diffusers
else: else:
return "lycoris" return LoRAModelFormat.LyCORIS
@classmethod @classmethod
def convert_if_required( def convert_if_required(
@ -64,7 +69,7 @@ class LoRAModel(ModelBase):
config: ModelConfigBase, config: ModelConfigBase,
base_model: BaseModelType, base_model: BaseModelType,
) -> str: ) -> str:
if cls.detect_format(model_path) == "diffusers": if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
# TODO: add diffusers lora when it stabilizes a bit # TODO: add diffusers lora when it stabilizes a bit
raise NotImplementedError("Diffusers lora not supported") raise NotImplementedError("Diffusers lora not supported")
else: else:

View File

@ -1,5 +1,6 @@
import os import os
import json import json
from enum import Enum
from pydantic import Field from pydantic import Field
from pathlib import Path from pathlib import Path
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
@ -19,16 +20,19 @@ from .base import (
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf from omegaconf import OmegaConf
class StableDiffusion1ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion1Model(DiffusersModel): class StableDiffusion1Model(DiffusersModel):
class DiffusersConfig(ModelConfigBase): class DiffusersConfig(ModelConfigBase):
model_format: Literal["diffusers"] model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
variant: ModelVariantType variant: ModelVariantType
class CheckpointConfig(ModelConfigBase): class CheckpointConfig(ModelConfigBase):
model_format: Literal["checkpoint"] model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
config: Optional[str] = Field(None) config: Optional[str] = Field(None)
variant: ModelVariantType variant: ModelVariantType
@ -47,7 +51,7 @@ class StableDiffusion1Model(DiffusersModel):
def probe_config(cls, path: str, **kwargs): def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path) model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None) ckpt_config_path = kwargs.get("config", None)
if model_format == "checkpoint": if model_format == StableDiffusion1ModelFormat.Checkpoint:
if ckpt_config_path: if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path) ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
@ -57,7 +61,7 @@ class StableDiffusion1Model(DiffusersModel):
checkpoint = checkpoint.get('state_dict', checkpoint) checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == "diffusers": elif model_format == StableDiffusion1ModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json") unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path): if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f: with open(unet_config_path, "r") as f:
@ -93,9 +97,9 @@ class StableDiffusion1Model(DiffusersModel):
@classmethod @classmethod
def detect_format(cls, model_path: str): def detect_format(cls, model_path: str):
if os.path.isdir(model_path): if os.path.isdir(model_path):
return "diffusers" return StableDiffusion1ModelFormat.Diffusers
else: else:
return "checkpoint" return StableDiffusion1ModelFormat.Checkpoint
@classmethod @classmethod
def convert_if_required( def convert_if_required(
@ -116,19 +120,22 @@ class StableDiffusion1Model(DiffusersModel):
else: else:
return model_path return model_path
class StableDiffusion2ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion2Model(DiffusersModel): class StableDiffusion2Model(DiffusersModel):
# TODO: check that configs overwriten properly # TODO: check that configs overwriten properly
class DiffusersConfig(ModelConfigBase): class DiffusersConfig(ModelConfigBase):
model_format: Literal["diffusers"] model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
variant: ModelVariantType variant: ModelVariantType
prediction_type: SchedulerPredictionType prediction_type: SchedulerPredictionType
upcast_attention: bool upcast_attention: bool
class CheckpointConfig(ModelConfigBase): class CheckpointConfig(ModelConfigBase):
model_format: Literal["checkpoint"] model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
config: Optional[str] = Field(None) config: Optional[str] = Field(None)
variant: ModelVariantType variant: ModelVariantType
@ -149,7 +156,7 @@ class StableDiffusion2Model(DiffusersModel):
def probe_config(cls, path: str, **kwargs): def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path) model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None) ckpt_config_path = kwargs.get("config", None)
if model_format == "checkpoint": if model_format == StableDiffusion2ModelFormat.Checkpoint:
if ckpt_config_path: if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path) ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
@ -159,7 +166,7 @@ class StableDiffusion2Model(DiffusersModel):
checkpoint = checkpoint.get('state_dict', checkpoint) checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == "diffusers": elif model_format == StableDiffusion2ModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json") unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path): if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f: with open(unet_config_path, "r") as f:
@ -206,9 +213,9 @@ class StableDiffusion2Model(DiffusersModel):
@classmethod @classmethod
def detect_format(cls, model_path: str): def detect_format(cls, model_path: str):
if os.path.isdir(model_path): if os.path.isdir(model_path):
return "diffusers" return StableDiffusion2ModelFormat.Diffusers
else: else:
return "checkpoint" return StableDiffusion2ModelFormat.Checkpoint
@classmethod @classmethod
def convert_if_required( def convert_if_required(

View File

@ -1,6 +1,7 @@
import os import os
import torch import torch
import safetensors import safetensors
from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Optional, Union, Literal from typing import Optional, Union, Literal
from .base import ( from .base import (
@ -19,12 +20,16 @@ from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf from omegaconf import OmegaConf
class VaeModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class VaeModel(ModelBase): class VaeModel(ModelBase):
#vae_class: Type #vae_class: Type
#model_size: int #model_size: int
class Config(ModelConfigBase): class Config(ModelConfigBase):
model_format: Union[Literal["checkpoint"], Literal["diffusers"]] model_format: VaeModelFormat
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Vae assert model_type == ModelType.Vae
@ -71,9 +76,9 @@ class VaeModel(ModelBase):
@classmethod @classmethod
def detect_format(cls, path: str): def detect_format(cls, path: str):
if os.path.isdir(path): if os.path.isdir(path):
return "diffusers" return VaeModelFormat.Diffusers
else: else:
return "checkpoint" return VaeModelFormat.Checkpoint
@classmethod @classmethod
def convert_if_required( def convert_if_required(
@ -83,7 +88,7 @@ class VaeModel(ModelBase):
config: ModelConfigBase, # empty config or config of parent model config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType, base_model: BaseModelType,
) -> str: ) -> str:
if cls.detect_format(model_path) != "diffusers": if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
return _convert_vae_ckpt_and_cache( return _convert_vae_ckpt_and_cache(
weights_path=model_path, weights_path=model_path,
output_path=output_path, output_path=output_path,