tidy(mm): tidy class names in config.py

This commit is contained in:
psychedelicious 2024-03-01 13:18:31 +11:00
parent 5b74117836
commit af9298f0ef

View File

@ -155,14 +155,14 @@ class ModelConfigBase(BaseModel):
setattr(self, key, value) # may raise a validation error
class _CheckpointConfig(ModelConfigBase):
class CheckpointConfigBase(ModelConfigBase):
"""Model config for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
config: str = Field(description="path to the checkpoint model config file")
class _DiffusersConfig(ModelConfigBase):
class DiffusersConfigBase(ModelConfigBase):
"""Model config for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@ -213,7 +213,7 @@ class VaeDiffusersConfig(ModelConfigBase):
return Tag(f"{ModelType.Vae}.{ModelFormat.Diffusers}")
class ControlNetDiffusersConfig(_DiffusersConfig):
class ControlNetDiffusersConfig(DiffusersConfigBase):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
@ -224,7 +224,7 @@ class ControlNetDiffusersConfig(_DiffusersConfig):
return Tag(f"{ModelType.ControlNet}.{ModelFormat.Diffusers}")
class ControlNetCheckpointConfig(_CheckpointConfig):
class ControlNetCheckpointConfig(CheckpointConfigBase):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
@ -265,7 +265,7 @@ class _MainConfig(ModelConfigBase):
upcast_attention: bool = False
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
class MainCheckpointConfig(CheckpointConfigBase, _MainConfig):
"""Model config for main checkpoint models."""
type: Literal[ModelType.Main] = ModelType.Main
@ -275,7 +275,7 @@ class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
return Tag(f"{ModelType.Main}.{ModelFormat.Checkpoint}")
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
class MainDiffusersConfig(DiffusersConfigBase, _MainConfig):
"""Model config for main diffusers models."""
type: Literal[ModelType.Main] = ModelType.Main
@ -350,27 +350,6 @@ AnyModelConfig = Annotated[
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
# This is a known issue. Please see:
# https://github.com/tiangolo/fastapi/discussions/9761 and
# https://github.com/tiangolo/fastapi/discussions/9287
# AnyModelConfig = Annotated[
# Union[
# _MainModelConfig,
# _ONNXConfig,
# _VaeConfig,
# _ControlNetConfig,
# LoRAConfig,
# TextualInversionConfig,
# IPAdapterConfig,
# CLIPVisionDiffusersConfig,
# T2IConfig,
# ],
# Field(discriminator="type"),
# ]
class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects."""