This commit is contained in:
Lincoln Stein 2023-11-10 19:14:29 -05:00
parent 3a6ba236f5
commit f1c846ba5c
2 changed files with 9 additions and 5 deletions

View File

@ -249,10 +249,12 @@ class T2IConfig(ModelConfigBase):
format: Literal[ModelFormat.Diffusers] format: Literal[ModelFormat.Diffusers]
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator='base')] _ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")]
_ControlNetConfig = Annotated[Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig], Field(discriminator='format')] _ControlNetConfig = Annotated[
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator='format')] Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig], Field(discriminator="format")
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator='format')] ]
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
AnyModelConfig = Annotated[ AnyModelConfig = Annotated[
Union[ Union[
@ -266,11 +268,12 @@ AnyModelConfig = Annotated[
CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig,
T2IConfig, T2IConfig,
], ],
Field(discriminator='type') Field(discriminator="type"),
] ]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig) AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
class ModelConfigFactory(object): class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects.""" """Class for parsing config dicts into StableDiffusion Config obects."""

View File

@ -83,6 +83,7 @@ def test_update(store: ModelRecordServiceBase):
new_config = store.get_model("key1") new_config = store.get_model("key1")
assert new_config.name == "new name" assert new_config.name == "new name"
def test_rename(store: ModelRecordServiceBase): def test_rename(store: ModelRecordServiceBase):
config = example_config() config = example_config()
store.add_model("key1", config) store.add_model("key1", config)