mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Update model format field to use enums
This commit is contained in:
parent
4cefe37723
commit
46dc751139
@ -125,30 +125,24 @@ class ModelBase(metaclass=ABCMeta):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
fields = inspect.get_annotations(value)
|
fields = inspect.get_annotations(value)
|
||||||
if "model_format" not in fields:
|
try:
|
||||||
raise Exception("Invalid config definition - model_format field not found")
|
field = fields["model_format"]
|
||||||
|
except:
|
||||||
|
raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")
|
||||||
|
|
||||||
format_type = typing.get_origin(fields["model_format"])
|
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
|
||||||
if format_type not in {None, Literal, Union}:
|
for model_format in field:
|
||||||
raise Exception(f"Invalid config definition - unknown format type: {fields['model_format']}")
|
configs[model_format.value] = value
|
||||||
|
|
||||||
if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["model_format"].__args__):
|
elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
|
||||||
raise Exception(f"Invalid config definition - unknown format type: {fields['model_format']}")
|
for model_format in field.__args__:
|
||||||
|
configs[model_format.value] = value
|
||||||
|
|
||||||
|
elif field is None:
|
||||||
|
configs[None] = value
|
||||||
|
|
||||||
if format_type == Union:
|
|
||||||
f_fields = fields["model_format"].__args__
|
|
||||||
else:
|
else:
|
||||||
f_fields = (fields["model_format"],)
|
raise Exception(f"Unsupported format definition in {cls.__qualname__}")
|
||||||
|
|
||||||
|
|
||||||
for field in f_fields:
|
|
||||||
if field is None:
|
|
||||||
format_name = None
|
|
||||||
else:
|
|
||||||
format_name = field.__args__[0]
|
|
||||||
|
|
||||||
configs[format_name] = value # TODO: error when override(multiple)?
|
|
||||||
|
|
||||||
|
|
||||||
cls.__configs = configs
|
cls.__configs = configs
|
||||||
return cls.__configs
|
return cls.__configs
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union, Literal
|
from typing import Optional, Union, Literal
|
||||||
from .base import (
|
from .base import (
|
||||||
@ -14,12 +15,16 @@ from .base import (
|
|||||||
classproperty,
|
classproperty,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class ControlNetModelFormat(str, Enum):
|
||||||
|
Checkpoint = "checkpoint"
|
||||||
|
Diffusers = "diffusers"
|
||||||
|
|
||||||
class ControlNetModel(ModelBase):
|
class ControlNetModel(ModelBase):
|
||||||
#model_class: Type
|
#model_class: Type
|
||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
class Config(ModelConfigBase):
|
||||||
model_format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
model_format: ControlNetModelFormat
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert model_type == ModelType.ControlNet
|
assert model_type == ModelType.ControlNet
|
||||||
@ -69,9 +74,9 @@ class ControlNetModel(ModelBase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def detect_format(cls, path: str):
|
def detect_format(cls, path: str):
|
||||||
if os.path.isdir(path):
|
if os.path.isdir(path):
|
||||||
return "diffusers"
|
return ControlNetModelFormat.Diffusers
|
||||||
else:
|
else:
|
||||||
return "checkpoint"
|
return ControlNetModelFormat.Checkpoint
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_if_required(
|
def convert_if_required(
|
||||||
@ -81,7 +86,7 @@ class ControlNetModel(ModelBase):
|
|||||||
config: ModelConfigBase, # empty config or config of parent model
|
config: ModelConfigBase, # empty config or config of parent model
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
) -> str:
|
) -> str:
|
||||||
if cls.detect_format(model_path) != "diffusers":
|
if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers:
|
||||||
raise NotImplementedError("Checkpoint controlnet models currently unsupported")
|
raise NotImplementedError("Checkpoint controlnet models currently unsupported")
|
||||||
else:
|
else:
|
||||||
return model_path
|
return model_path
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
from enum import Enum
|
||||||
from typing import Optional, Union, Literal
|
from typing import Optional, Union, Literal
|
||||||
from .base import (
|
from .base import (
|
||||||
ModelBase,
|
ModelBase,
|
||||||
@ -12,11 +13,15 @@ from .base import (
|
|||||||
# TODO: naming
|
# TODO: naming
|
||||||
from ..lora import LoRAModel as LoRAModelRaw
|
from ..lora import LoRAModel as LoRAModelRaw
|
||||||
|
|
||||||
|
class LoRAModelFormat(str, Enum):
|
||||||
|
LyCORIS = "lycoris"
|
||||||
|
Diffusers = "diffusers"
|
||||||
|
|
||||||
class LoRAModel(ModelBase):
|
class LoRAModel(ModelBase):
|
||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
class Config(ModelConfigBase):
|
||||||
model_format: Union[Literal["lycoris"], Literal["diffusers"]]
|
model_format: LoRAModelFormat # TODO:
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert model_type == ModelType.Lora
|
assert model_type == ModelType.Lora
|
||||||
@ -52,9 +57,9 @@ class LoRAModel(ModelBase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def detect_format(cls, path: str):
|
def detect_format(cls, path: str):
|
||||||
if os.path.isdir(path):
|
if os.path.isdir(path):
|
||||||
return "diffusers"
|
return LoRAModelFormat.Diffusers
|
||||||
else:
|
else:
|
||||||
return "lycoris"
|
return LoRAModelFormat.LyCORIS
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_if_required(
|
def convert_if_required(
|
||||||
@ -64,7 +69,7 @@ class LoRAModel(ModelBase):
|
|||||||
config: ModelConfigBase,
|
config: ModelConfigBase,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
) -> str:
|
) -> str:
|
||||||
if cls.detect_format(model_path) == "diffusers":
|
if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
|
||||||
# TODO: add diffusers lora when it stabilizes a bit
|
# TODO: add diffusers lora when it stabilizes a bit
|
||||||
raise NotImplementedError("Diffusers lora not supported")
|
raise NotImplementedError("Diffusers lora not supported")
|
||||||
else:
|
else:
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
from enum import Enum
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
@ -19,16 +20,19 @@ from .base import (
|
|||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
class StableDiffusion1ModelFormat(str, Enum):
|
||||||
|
Checkpoint = "checkpoint"
|
||||||
|
Diffusers = "diffusers"
|
||||||
|
|
||||||
class StableDiffusion1Model(DiffusersModel):
|
class StableDiffusion1Model(DiffusersModel):
|
||||||
|
|
||||||
class DiffusersConfig(ModelConfigBase):
|
class DiffusersConfig(ModelConfigBase):
|
||||||
model_format: Literal["diffusers"]
|
model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
|
|
||||||
class CheckpointConfig(ModelConfigBase):
|
class CheckpointConfig(ModelConfigBase):
|
||||||
model_format: Literal["checkpoint"]
|
model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
config: Optional[str] = Field(None)
|
config: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
@ -47,7 +51,7 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
def probe_config(cls, path: str, **kwargs):
|
def probe_config(cls, path: str, **kwargs):
|
||||||
model_format = cls.detect_format(path)
|
model_format = cls.detect_format(path)
|
||||||
ckpt_config_path = kwargs.get("config", None)
|
ckpt_config_path = kwargs.get("config", None)
|
||||||
if model_format == "checkpoint":
|
if model_format == StableDiffusion1ModelFormat.Checkpoint:
|
||||||
if ckpt_config_path:
|
if ckpt_config_path:
|
||||||
ckpt_config = OmegaConf.load(ckpt_config_path)
|
ckpt_config = OmegaConf.load(ckpt_config_path)
|
||||||
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
||||||
@ -57,7 +61,7 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
checkpoint = checkpoint.get('state_dict', checkpoint)
|
checkpoint = checkpoint.get('state_dict', checkpoint)
|
||||||
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||||
|
|
||||||
elif model_format == "diffusers":
|
elif model_format == StableDiffusion1ModelFormat.Diffusers:
|
||||||
unet_config_path = os.path.join(path, "unet", "config.json")
|
unet_config_path = os.path.join(path, "unet", "config.json")
|
||||||
if os.path.exists(unet_config_path):
|
if os.path.exists(unet_config_path):
|
||||||
with open(unet_config_path, "r") as f:
|
with open(unet_config_path, "r") as f:
|
||||||
@ -93,9 +97,9 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def detect_format(cls, model_path: str):
|
def detect_format(cls, model_path: str):
|
||||||
if os.path.isdir(model_path):
|
if os.path.isdir(model_path):
|
||||||
return "diffusers"
|
return StableDiffusion1ModelFormat.Diffusers
|
||||||
else:
|
else:
|
||||||
return "checkpoint"
|
return StableDiffusion1ModelFormat.Checkpoint
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_if_required(
|
def convert_if_required(
|
||||||
@ -116,19 +120,22 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
else:
|
else:
|
||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
|
class StableDiffusion2ModelFormat(str, Enum):
|
||||||
|
Checkpoint = "checkpoint"
|
||||||
|
Diffusers = "diffusers"
|
||||||
|
|
||||||
class StableDiffusion2Model(DiffusersModel):
|
class StableDiffusion2Model(DiffusersModel):
|
||||||
|
|
||||||
# TODO: check that configs overwriten properly
|
# TODO: check that configs overwriten properly
|
||||||
class DiffusersConfig(ModelConfigBase):
|
class DiffusersConfig(ModelConfigBase):
|
||||||
model_format: Literal["diffusers"]
|
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
prediction_type: SchedulerPredictionType
|
prediction_type: SchedulerPredictionType
|
||||||
upcast_attention: bool
|
upcast_attention: bool
|
||||||
|
|
||||||
class CheckpointConfig(ModelConfigBase):
|
class CheckpointConfig(ModelConfigBase):
|
||||||
model_format: Literal["checkpoint"]
|
model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
config: Optional[str] = Field(None)
|
config: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
@ -149,7 +156,7 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
def probe_config(cls, path: str, **kwargs):
|
def probe_config(cls, path: str, **kwargs):
|
||||||
model_format = cls.detect_format(path)
|
model_format = cls.detect_format(path)
|
||||||
ckpt_config_path = kwargs.get("config", None)
|
ckpt_config_path = kwargs.get("config", None)
|
||||||
if model_format == "checkpoint":
|
if model_format == StableDiffusion2ModelFormat.Checkpoint:
|
||||||
if ckpt_config_path:
|
if ckpt_config_path:
|
||||||
ckpt_config = OmegaConf.load(ckpt_config_path)
|
ckpt_config = OmegaConf.load(ckpt_config_path)
|
||||||
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
||||||
@ -159,7 +166,7 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
checkpoint = checkpoint.get('state_dict', checkpoint)
|
checkpoint = checkpoint.get('state_dict', checkpoint)
|
||||||
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||||
|
|
||||||
elif model_format == "diffusers":
|
elif model_format == StableDiffusion2ModelFormat.Diffusers:
|
||||||
unet_config_path = os.path.join(path, "unet", "config.json")
|
unet_config_path = os.path.join(path, "unet", "config.json")
|
||||||
if os.path.exists(unet_config_path):
|
if os.path.exists(unet_config_path):
|
||||||
with open(unet_config_path, "r") as f:
|
with open(unet_config_path, "r") as f:
|
||||||
@ -206,9 +213,9 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def detect_format(cls, model_path: str):
|
def detect_format(cls, model_path: str):
|
||||||
if os.path.isdir(model_path):
|
if os.path.isdir(model_path):
|
||||||
return "diffusers"
|
return StableDiffusion2ModelFormat.Diffusers
|
||||||
else:
|
else:
|
||||||
return "checkpoint"
|
return StableDiffusion2ModelFormat.Checkpoint
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_if_required(
|
def convert_if_required(
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import safetensors
|
import safetensors
|
||||||
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union, Literal
|
from typing import Optional, Union, Literal
|
||||||
from .base import (
|
from .base import (
|
||||||
@ -19,12 +20,16 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
|||||||
from diffusers.utils import is_safetensors_available
|
from diffusers.utils import is_safetensors_available
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
class VaeModelFormat(str, Enum):
|
||||||
|
Checkpoint = "checkpoint"
|
||||||
|
Diffusers = "diffusers"
|
||||||
|
|
||||||
class VaeModel(ModelBase):
|
class VaeModel(ModelBase):
|
||||||
#vae_class: Type
|
#vae_class: Type
|
||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
class Config(ModelConfigBase):
|
||||||
model_format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
model_format: VaeModelFormat
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert model_type == ModelType.Vae
|
assert model_type == ModelType.Vae
|
||||||
@ -71,9 +76,9 @@ class VaeModel(ModelBase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def detect_format(cls, path: str):
|
def detect_format(cls, path: str):
|
||||||
if os.path.isdir(path):
|
if os.path.isdir(path):
|
||||||
return "diffusers"
|
return VaeModelFormat.Diffusers
|
||||||
else:
|
else:
|
||||||
return "checkpoint"
|
return VaeModelFormat.Checkpoint
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_if_required(
|
def convert_if_required(
|
||||||
@ -83,7 +88,7 @@ class VaeModel(ModelBase):
|
|||||||
config: ModelConfigBase, # empty config or config of parent model
|
config: ModelConfigBase, # empty config or config of parent model
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
) -> str:
|
) -> str:
|
||||||
if cls.detect_format(model_path) != "diffusers":
|
if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
|
||||||
return _convert_vae_ckpt_and_cache(
|
return _convert_vae_ckpt_and_cache(
|
||||||
weights_path=model_path,
|
weights_path=model_path,
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
|
Loading…
Reference in New Issue
Block a user