Update model format field to use enums

This commit is contained in:
Sergey Borisov 2023-06-20 03:30:09 +03:00
parent 4cefe37723
commit 46dc751139
5 changed files with 60 additions and 44 deletions

View File

@ -125,30 +125,24 @@ class ModelBase(metaclass=ABCMeta):
continue
fields = inspect.get_annotations(value)
if "model_format" not in fields:
raise Exception("Invalid config definition - model_format field not found")
try:
field = fields["model_format"]
except:
raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")
format_type = typing.get_origin(fields["model_format"])
if format_type not in {None, Literal, Union}:
raise Exception(f"Invalid config definition - unknown format type: {fields['model_format']}")
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
for model_format in field:
configs[model_format.value] = value
if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["model_format"].__args__):
raise Exception(f"Invalid config definition - unknown format type: {fields['model_format']}")
elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
for model_format in field.__args__:
configs[model_format.value] = value
elif field is None:
configs[None] = value
if format_type == Union:
f_fields = fields["model_format"].__args__
else:
f_fields = (fields["model_format"],)
for field in f_fields:
if field is None:
format_name = None
else:
format_name = field.__args__[0]
configs[format_name] = value # TODO: error when override(multiple)?
raise Exception(f"Unsupported format definition in {cls.__qualname__}")
cls.__configs = configs
return cls.__configs

View File

@ -1,5 +1,6 @@
import os
import torch
from enum import Enum
from pathlib import Path
from typing import Optional, Union, Literal
from .base import (
@ -14,12 +15,16 @@ from .base import (
classproperty,
)
class ControlNetModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class ControlNetModel(ModelBase):
#model_class: Type
#model_size: int
class Config(ModelConfigBase):
model_format: Union[Literal["checkpoint"], Literal["diffusers"]]
model_format: ControlNetModelFormat
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.ControlNet
@ -69,9 +74,9 @@ class ControlNetModel(ModelBase):
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
return "diffusers"
return ControlNetModelFormat.Diffusers
else:
return "checkpoint"
return ControlNetModelFormat.Checkpoint
@classmethod
def convert_if_required(
@ -81,7 +86,7 @@ class ControlNetModel(ModelBase):
config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) != "diffusers":
if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers:
raise NotImplementedError("Checkpoint controlnet models currently unsupported")
else:
return model_path

View File

@ -1,5 +1,6 @@
import os
import torch
from enum import Enum
from typing import Optional, Union, Literal
from .base import (
ModelBase,
@ -12,11 +13,15 @@ from .base import (
# TODO: naming
from ..lora import LoRAModel as LoRAModelRaw
class LoRAModelFormat(str, Enum):
LyCORIS = "lycoris"
Diffusers = "diffusers"
class LoRAModel(ModelBase):
#model_size: int
class Config(ModelConfigBase):
model_format: Union[Literal["lycoris"], Literal["diffusers"]]
model_format: LoRAModelFormat # TODO:
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Lora
@ -52,9 +57,9 @@ class LoRAModel(ModelBase):
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
return "diffusers"
return LoRAModelFormat.Diffusers
else:
return "lycoris"
return LoRAModelFormat.LyCORIS
@classmethod
def convert_if_required(
@ -64,7 +69,7 @@ class LoRAModel(ModelBase):
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == "diffusers":
if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
# TODO: add diffusers lora when it stabilizes a bit
raise NotImplementedError("Diffusers lora not supported")
else:

View File

@ -1,5 +1,6 @@
import os
import json
from enum import Enum
from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
@ -19,16 +20,19 @@ from .base import (
from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf
class StableDiffusion1ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion1Model(DiffusersModel):
class DiffusersConfig(ModelConfigBase):
model_format: Literal["diffusers"]
model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
vae: Optional[str] = Field(None)
variant: ModelVariantType
class CheckpointConfig(ModelConfigBase):
model_format: Literal["checkpoint"]
model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
variant: ModelVariantType
@ -47,7 +51,7 @@ class StableDiffusion1Model(DiffusersModel):
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None)
if model_format == "checkpoint":
if model_format == StableDiffusion1ModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
@ -57,7 +61,7 @@ class StableDiffusion1Model(DiffusersModel):
checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == "diffusers":
elif model_format == StableDiffusion1ModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
@ -93,9 +97,9 @@ class StableDiffusion1Model(DiffusersModel):
@classmethod
def detect_format(cls, model_path: str):
if os.path.isdir(model_path):
return "diffusers"
return StableDiffusion1ModelFormat.Diffusers
else:
return "checkpoint"
return StableDiffusion1ModelFormat.Checkpoint
@classmethod
def convert_if_required(
@ -116,19 +120,22 @@ class StableDiffusion1Model(DiffusersModel):
else:
return model_path
class StableDiffusion2ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion2Model(DiffusersModel):
# TODO: check that configs overwriten properly
class DiffusersConfig(ModelConfigBase):
model_format: Literal["diffusers"]
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
vae: Optional[str] = Field(None)
variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
class CheckpointConfig(ModelConfigBase):
model_format: Literal["checkpoint"]
model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
variant: ModelVariantType
@ -149,7 +156,7 @@ class StableDiffusion2Model(DiffusersModel):
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None)
if model_format == "checkpoint":
if model_format == StableDiffusion2ModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
@ -159,7 +166,7 @@ class StableDiffusion2Model(DiffusersModel):
checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == "diffusers":
elif model_format == StableDiffusion2ModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
@ -206,9 +213,9 @@ class StableDiffusion2Model(DiffusersModel):
@classmethod
def detect_format(cls, model_path: str):
if os.path.isdir(model_path):
return "diffusers"
return StableDiffusion2ModelFormat.Diffusers
else:
return "checkpoint"
return StableDiffusion2ModelFormat.Checkpoint
@classmethod
def convert_if_required(

View File

@ -1,6 +1,7 @@
import os
import torch
import safetensors
from enum import Enum
from pathlib import Path
from typing import Optional, Union, Literal
from .base import (
@ -19,12 +20,16 @@ from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf
class VaeModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class VaeModel(ModelBase):
#vae_class: Type
#model_size: int
class Config(ModelConfigBase):
model_format: Union[Literal["checkpoint"], Literal["diffusers"]]
model_format: VaeModelFormat
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Vae
@ -71,9 +76,9 @@ class VaeModel(ModelBase):
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
return "diffusers"
return VaeModelFormat.Diffusers
else:
return "checkpoint"
return VaeModelFormat.Checkpoint
@classmethod
def convert_if_required(
@ -83,7 +88,7 @@ class VaeModel(ModelBase):
config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) != "diffusers":
if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
return _convert_vae_ckpt_and_cache(
weights_path=model_path,
output_path=output_path,