fixed ruff formatting issues

This commit is contained in:
Lincoln Stein 2023-11-13 18:15:17 -05:00
parent 38c1436f02
commit acc0a29dca
3 changed files with 11 additions and 35 deletions

View File

@ -41,23 +41,15 @@ class ModelsList(BaseModel):
operation_id="list_model_records",
)
async def list_model_records(
base_models: Optional[List[BaseModelType]] = Query(
default=None, description="Base models to include"
),
model_type: Optional[ModelType] = Query(
default=None, description="The type of model to get"
),
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
found_models: list[AnyModelConfig] = []
if base_models:
for base_model in base_models:
found_models.extend(
record_store.search_by_attr(
base_model=base_model, model_type=model_type
)
)
found_models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
else:
found_models.extend(record_store.search_by_attr(model_type=model_type))
return ModelsList(models=found_models)
@ -97,9 +89,7 @@ async def get_model_record(
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
info: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type")
],
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
@ -145,17 +135,13 @@ async def del_model_record(
operation_id="add_model_record",
responses={
201: {"description": "The model added successfully"},
409: {
"description": "There is already a model corresponding to this path or repo_id"
},
409: {"description": "There is already a model corresponding to this path or repo_id"},
415: {"description": "Unrecognized file/folder format"},
},
status_code=201,
)
async def add_model_record(
config: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type")
]
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
) -> AnyModelConfig:
"""
Add a model using the configuration information appropriate for its type.

View File

@ -115,9 +115,7 @@ class ModelConfigBase(BaseModel):
description="current fasthash of model contents", default=None
) # if model is converted or otherwise modified, this will hold updated hash
description: Optional[str] = Field(default=None)
source: Optional[str] = Field(
description="Model download source (URL or repo_id)", default=None
)
source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None)
model_config = ConfigDict(
use_enum_values=False,
@ -251,19 +249,13 @@ class T2IConfig(ModelConfigBase):
format: Literal[ModelFormat.Diffusers]
_ONNXConfig = Annotated[
Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")
]
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")]
_ControlNetConfig = Annotated[
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
Field(discriminator="format"),
]
_VaeConfig = Annotated[
Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")
]
_MainModelConfig = Annotated[
Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")
]
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
AnyModelConfig = Union[
_MainModelConfig,

View File

@ -159,9 +159,7 @@ def test_filter(store: ModelRecordServiceBase):
assert len(matches) == 1
assert matches[0].name == "config3"
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
assert isinstance(
matches[0].type, ModelType
) # This tests that we get proper enums back
assert isinstance(matches[0].type, ModelType) # This tests that we get proper enums back
matches = store.search_by_hash("CONFIG1HASH")
assert len(matches) == 1