This commit is contained in:
Lincoln Stein 2023-11-10 19:14:29 -05:00
parent 3a6ba236f5
commit f1c846ba5c
2 changed files with 9 additions and 5 deletions

View File

@ -249,10 +249,12 @@ class T2IConfig(ModelConfigBase):
format: Literal[ModelFormat.Diffusers]
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator='base')]
_ControlNetConfig = Annotated[Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig], Field(discriminator='format')]
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator='format')]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator='format')]
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")]
_ControlNetConfig = Annotated[
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig], Field(discriminator="format")
]
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
AnyModelConfig = Annotated[
Union[
@ -266,11 +268,12 @@ AnyModelConfig = Annotated[
CLIPVisionDiffusersConfig,
T2IConfig,
],
Field(discriminator='type')
Field(discriminator="type"),
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects."""

View File

@ -83,6 +83,7 @@ def test_update(store: ModelRecordServiceBase):
new_config = store.get_model("key1")
assert new_config.name == "new name"
def test_rename(store: ModelRecordServiceBase):
config = example_config()
store.add_model("key1", config)