mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fixed ruff formatting issues
This commit is contained in:
parent
38c1436f02
commit
acc0a29dca
@ -41,23 +41,15 @@ class ModelsList(BaseModel):
|
||||
operation_id="list_model_records",
|
||||
)
|
||||
async def list_model_records(
|
||||
base_models: Optional[List[BaseModelType]] = Query(
|
||||
default=None, description="Base models to include"
|
||||
),
|
||||
model_type: Optional[ModelType] = Query(
|
||||
default=None, description="The type of model to get"
|
||||
),
|
||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
found_models: list[AnyModelConfig] = []
|
||||
if base_models:
|
||||
for base_model in base_models:
|
||||
found_models.extend(
|
||||
record_store.search_by_attr(
|
||||
base_model=base_model, model_type=model_type
|
||||
)
|
||||
)
|
||||
found_models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
|
||||
else:
|
||||
found_models.extend(record_store.search_by_attr(model_type=model_type))
|
||||
return ModelsList(models=found_models)
|
||||
@ -97,9 +89,7 @@ async def get_model_record(
|
||||
)
|
||||
async def update_model_record(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
info: Annotated[
|
||||
AnyModelConfig, Body(description="Model config", discriminator="type")
|
||||
],
|
||||
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
|
||||
) -> AnyModelConfig:
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
@ -145,17 +135,13 @@ async def del_model_record(
|
||||
operation_id="add_model_record",
|
||||
responses={
|
||||
201: {"description": "The model added successfully"},
|
||||
409: {
|
||||
"description": "There is already a model corresponding to this path or repo_id"
|
||||
},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def add_model_record(
|
||||
config: Annotated[
|
||||
AnyModelConfig, Body(description="Model config", discriminator="type")
|
||||
]
|
||||
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model using the configuration information appropriate for its type.
|
||||
|
@ -115,9 +115,7 @@ class ModelConfigBase(BaseModel):
|
||||
description="current fasthash of model contents", default=None
|
||||
) # if model is converted or otherwise modified, this will hold updated hash
|
||||
description: Optional[str] = Field(default=None)
|
||||
source: Optional[str] = Field(
|
||||
description="Model download source (URL or repo_id)", default=None
|
||||
)
|
||||
source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None)
|
||||
|
||||
model_config = ConfigDict(
|
||||
use_enum_values=False,
|
||||
@ -251,19 +249,13 @@ class T2IConfig(ModelConfigBase):
|
||||
format: Literal[ModelFormat.Diffusers]
|
||||
|
||||
|
||||
_ONNXConfig = Annotated[
|
||||
Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")
|
||||
]
|
||||
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")]
|
||||
_ControlNetConfig = Annotated[
|
||||
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
|
||||
Field(discriminator="format"),
|
||||
]
|
||||
_VaeConfig = Annotated[
|
||||
Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")
|
||||
]
|
||||
_MainModelConfig = Annotated[
|
||||
Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")
|
||||
]
|
||||
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
|
||||
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
|
||||
|
||||
AnyModelConfig = Union[
|
||||
_MainModelConfig,
|
||||
|
@ -159,9 +159,7 @@ def test_filter(store: ModelRecordServiceBase):
|
||||
assert len(matches) == 1
|
||||
assert matches[0].name == "config3"
|
||||
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
|
||||
assert isinstance(
|
||||
matches[0].type, ModelType
|
||||
) # This tests that we get proper enums back
|
||||
assert isinstance(matches[0].type, ModelType) # This tests that we get proper enums back
|
||||
|
||||
matches = store.search_by_hash("CONFIG1HASH")
|
||||
assert len(matches) == 1
|
||||
|
Loading…
Reference in New Issue
Block a user