multiple small fixes suggested in reviews from psychedelicious and ryan

This commit is contained in:
Lincoln Stein
2023-11-10 18:25:37 -05:00
parent fdaa661245
commit 0544917161
4 changed files with 48 additions and 37 deletions

View File

@ -127,14 +127,14 @@ class ModelConfigBase(BaseModel):
setattr(self, key, value) # may raise a validation error
class CheckpointConfig(ModelConfigBase):
class _CheckpointConfig(ModelConfigBase):
"""Model config for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
config: str = Field(description="path to the checkpoint model config file")
class DiffusersConfig(ModelConfigBase):
class _DiffusersConfig(ModelConfigBase):
"""Model config for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@ -158,13 +158,13 @@ class VaeDiffusersConfig(ModelConfigBase):
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetDiffusersConfig(DiffusersConfig):
class ControlNetDiffusersConfig(_DiffusersConfig):
"""Model config for ControlNet models (diffusers version)."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetCheckpointConfig(CheckpointConfig):
class ControlNetCheckpointConfig(_CheckpointConfig):
"""Model config for ControlNet models (diffusers version)."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@ -176,29 +176,29 @@ class TextualInversionConfig(ModelConfigBase):
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
class MainConfig(ModelConfigBase):
class _MainConfig(ModelConfigBase):
"""Model config for main models."""
vae: Optional[str] = Field(None)
vae: Optional[str] = Field(default=None)
variant: ModelVariantType = ModelVariantType.Normal
ztsnr_training: bool = False
class MainCheckpointConfig(CheckpointConfig, MainConfig):
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
"""Model config for main checkpoint models."""
# Note that we do not need prediction_type or upcast_attention here
# because they are provided in the checkpoint's own config file.
class MainDiffusersConfig(DiffusersConfig, MainConfig):
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
"""Model config for main diffusers models."""
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class ONNXSD1Config(MainConfig):
class ONNXSD1Config(_MainConfig):
"""Model config for ONNX format models based on sd-1."""
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
@ -206,7 +206,7 @@ class ONNXSD1Config(MainConfig):
upcast_attention: bool = False
class ONNXSD2Config(MainConfig):
class ONNXSD2Config(_MainConfig):
"""Model config for ONNX format models based on sd-2."""
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]