mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fixed ruff formatting issues
This commit is contained in:
parent
38c1436f02
commit
acc0a29dca
@ -41,23 +41,15 @@ class ModelsList(BaseModel):
|
|||||||
operation_id="list_model_records",
|
operation_id="list_model_records",
|
||||||
)
|
)
|
||||||
async def list_model_records(
|
async def list_model_records(
|
||||||
base_models: Optional[List[BaseModelType]] = Query(
|
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||||
default=None, description="Base models to include"
|
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||||
),
|
|
||||||
model_type: Optional[ModelType] = Query(
|
|
||||||
default=None, description="The type of model to get"
|
|
||||||
),
|
|
||||||
) -> ModelsList:
|
) -> ModelsList:
|
||||||
"""Get a list of models."""
|
"""Get a list of models."""
|
||||||
record_store = ApiDependencies.invoker.services.model_records
|
record_store = ApiDependencies.invoker.services.model_records
|
||||||
found_models: list[AnyModelConfig] = []
|
found_models: list[AnyModelConfig] = []
|
||||||
if base_models:
|
if base_models:
|
||||||
for base_model in base_models:
|
for base_model in base_models:
|
||||||
found_models.extend(
|
found_models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
|
||||||
record_store.search_by_attr(
|
|
||||||
base_model=base_model, model_type=model_type
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
found_models.extend(record_store.search_by_attr(model_type=model_type))
|
found_models.extend(record_store.search_by_attr(model_type=model_type))
|
||||||
return ModelsList(models=found_models)
|
return ModelsList(models=found_models)
|
||||||
@ -97,9 +89,7 @@ async def get_model_record(
|
|||||||
)
|
)
|
||||||
async def update_model_record(
|
async def update_model_record(
|
||||||
key: Annotated[str, Path(description="Unique key of model")],
|
key: Annotated[str, Path(description="Unique key of model")],
|
||||||
info: Annotated[
|
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
|
||||||
AnyModelConfig, Body(description="Model config", discriminator="type")
|
|
||||||
],
|
|
||||||
) -> AnyModelConfig:
|
) -> AnyModelConfig:
|
||||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
@ -145,17 +135,13 @@ async def del_model_record(
|
|||||||
operation_id="add_model_record",
|
operation_id="add_model_record",
|
||||||
responses={
|
responses={
|
||||||
201: {"description": "The model added successfully"},
|
201: {"description": "The model added successfully"},
|
||||||
409: {
|
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||||
"description": "There is already a model corresponding to this path or repo_id"
|
|
||||||
},
|
|
||||||
415: {"description": "Unrecognized file/folder format"},
|
415: {"description": "Unrecognized file/folder format"},
|
||||||
},
|
},
|
||||||
status_code=201,
|
status_code=201,
|
||||||
)
|
)
|
||||||
async def add_model_record(
|
async def add_model_record(
|
||||||
config: Annotated[
|
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
|
||||||
AnyModelConfig, Body(description="Model config", discriminator="type")
|
|
||||||
]
|
|
||||||
) -> AnyModelConfig:
|
) -> AnyModelConfig:
|
||||||
"""
|
"""
|
||||||
Add a model using the configuration information appropriate for its type.
|
Add a model using the configuration information appropriate for its type.
|
||||||
|
@ -115,9 +115,7 @@ class ModelConfigBase(BaseModel):
|
|||||||
description="current fasthash of model contents", default=None
|
description="current fasthash of model contents", default=None
|
||||||
) # if model is converted or otherwise modified, this will hold updated hash
|
) # if model is converted or otherwise modified, this will hold updated hash
|
||||||
description: Optional[str] = Field(default=None)
|
description: Optional[str] = Field(default=None)
|
||||||
source: Optional[str] = Field(
|
source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None)
|
||||||
description="Model download source (URL or repo_id)", default=None
|
|
||||||
)
|
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
use_enum_values=False,
|
use_enum_values=False,
|
||||||
@ -251,19 +249,13 @@ class T2IConfig(ModelConfigBase):
|
|||||||
format: Literal[ModelFormat.Diffusers]
|
format: Literal[ModelFormat.Diffusers]
|
||||||
|
|
||||||
|
|
||||||
_ONNXConfig = Annotated[
|
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")]
|
||||||
Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")
|
|
||||||
]
|
|
||||||
_ControlNetConfig = Annotated[
|
_ControlNetConfig = Annotated[
|
||||||
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
|
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
|
||||||
Field(discriminator="format"),
|
Field(discriminator="format"),
|
||||||
]
|
]
|
||||||
_VaeConfig = Annotated[
|
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
|
||||||
Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")
|
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
|
||||||
]
|
|
||||||
_MainModelConfig = Annotated[
|
|
||||||
Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")
|
|
||||||
]
|
|
||||||
|
|
||||||
AnyModelConfig = Union[
|
AnyModelConfig = Union[
|
||||||
_MainModelConfig,
|
_MainModelConfig,
|
||||||
|
@ -159,9 +159,7 @@ def test_filter(store: ModelRecordServiceBase):
|
|||||||
assert len(matches) == 1
|
assert len(matches) == 1
|
||||||
assert matches[0].name == "config3"
|
assert matches[0].name == "config3"
|
||||||
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
|
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
|
||||||
assert isinstance(
|
assert isinstance(matches[0].type, ModelType) # This tests that we get proper enums back
|
||||||
matches[0].type, ModelType
|
|
||||||
) # This tests that we get proper enums back
|
|
||||||
|
|
||||||
matches = store.search_by_hash("CONFIG1HASH")
|
matches = store.search_by_hash("CONFIG1HASH")
|
||||||
assert len(matches) == 1
|
assert len(matches) == 1
|
||||||
|
Loading…
x
Reference in New Issue
Block a user