fixed ruff formatting issues

This commit is contained in:
Lincoln Stein 2023-11-13 18:15:17 -05:00
parent 38c1436f02
commit acc0a29dca
3 changed files with 11 additions and 35 deletions

View File

@ -41,23 +41,15 @@ class ModelsList(BaseModel):
operation_id="list_model_records", operation_id="list_model_records",
) )
async def list_model_records( async def list_model_records(
base_models: Optional[List[BaseModelType]] = Query( base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
default=None, description="Base models to include" model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
),
model_type: Optional[ModelType] = Query(
default=None, description="The type of model to get"
),
) -> ModelsList: ) -> ModelsList:
"""Get a list of models.""" """Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records record_store = ApiDependencies.invoker.services.model_records
found_models: list[AnyModelConfig] = [] found_models: list[AnyModelConfig] = []
if base_models: if base_models:
for base_model in base_models: for base_model in base_models:
found_models.extend( found_models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
record_store.search_by_attr(
base_model=base_model, model_type=model_type
)
)
else: else:
found_models.extend(record_store.search_by_attr(model_type=model_type)) found_models.extend(record_store.search_by_attr(model_type=model_type))
return ModelsList(models=found_models) return ModelsList(models=found_models)
@ -97,9 +89,7 @@ async def get_model_record(
) )
async def update_model_record( async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")], key: Annotated[str, Path(description="Unique key of model")],
info: Annotated[ info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
AnyModelConfig, Body(description="Model config", discriminator="type")
],
) -> AnyModelConfig: ) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
@ -145,17 +135,13 @@ async def del_model_record(
operation_id="add_model_record", operation_id="add_model_record",
responses={ responses={
201: {"description": "The model added successfully"}, 201: {"description": "The model added successfully"},
409: { 409: {"description": "There is already a model corresponding to this path or repo_id"},
"description": "There is already a model corresponding to this path or repo_id"
},
415: {"description": "Unrecognized file/folder format"}, 415: {"description": "Unrecognized file/folder format"},
}, },
status_code=201, status_code=201,
) )
async def add_model_record( async def add_model_record(
config: Annotated[ config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
AnyModelConfig, Body(description="Model config", discriminator="type")
]
) -> AnyModelConfig: ) -> AnyModelConfig:
""" """
Add a model using the configuration information appropriate for its type. Add a model using the configuration information appropriate for its type.

View File

@ -115,9 +115,7 @@ class ModelConfigBase(BaseModel):
description="current fasthash of model contents", default=None description="current fasthash of model contents", default=None
) # if model is converted or otherwise modified, this will hold updated hash ) # if model is converted or otherwise modified, this will hold updated hash
description: Optional[str] = Field(default=None) description: Optional[str] = Field(default=None)
source: Optional[str] = Field( source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None)
description="Model download source (URL or repo_id)", default=None
)
model_config = ConfigDict( model_config = ConfigDict(
use_enum_values=False, use_enum_values=False,
@ -251,19 +249,13 @@ class T2IConfig(ModelConfigBase):
format: Literal[ModelFormat.Diffusers] format: Literal[ModelFormat.Diffusers]
_ONNXConfig = Annotated[ _ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")]
Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")
]
_ControlNetConfig = Annotated[ _ControlNetConfig = Annotated[
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig], Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
Field(discriminator="format"), Field(discriminator="format"),
] ]
_VaeConfig = Annotated[ _VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format") _MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
]
_MainModelConfig = Annotated[
Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")
]
AnyModelConfig = Union[ AnyModelConfig = Union[
_MainModelConfig, _MainModelConfig,

View File

@ -159,9 +159,7 @@ def test_filter(store: ModelRecordServiceBase):
assert len(matches) == 1 assert len(matches) == 1
assert matches[0].name == "config3" assert matches[0].name == "config3"
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest() assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
assert isinstance( assert isinstance(matches[0].type, ModelType) # This tests that we get proper enums back
matches[0].type, ModelType
) # This tests that we get proper enums back
matches = store.search_by_hash("CONFIG1HASH") matches = store.search_by_hash("CONFIG1HASH")
assert len(matches) == 1 assert len(matches) == 1