mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
blackify
This commit is contained in:
parent
024a156114
commit
ef8dcf5fae
@ -36,8 +36,8 @@ ModelsListValidator = TypeAdapter(ModelsList)
|
|||||||
operation_id="list_model_records",
|
operation_id="list_model_records",
|
||||||
)
|
)
|
||||||
async def list_model_records(
|
async def list_model_records(
|
||||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||||
) -> ModelsList:
|
) -> ModelsList:
|
||||||
"""Get a list of models."""
|
"""Get a list of models."""
|
||||||
record_store = ApiDependencies.invoker.services.model_records
|
record_store = ApiDependencies.invoker.services.model_records
|
||||||
@ -63,7 +63,7 @@ async def list_model_records(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def get_model_record(
|
async def get_model_record(
|
||||||
key: str = Path(description="Key of the model record to fetch."),
|
key: str = Path(description="Key of the model record to fetch."),
|
||||||
) -> AnyModelConfig:
|
) -> AnyModelConfig:
|
||||||
"""Get a model record"""
|
"""Get a model record"""
|
||||||
record_store = ApiDependencies.invoker.services.model_records
|
record_store = ApiDependencies.invoker.services.model_records
|
||||||
@ -86,8 +86,8 @@ async def get_model_record(
|
|||||||
response_model=AnyModelConfig,
|
response_model=AnyModelConfig,
|
||||||
)
|
)
|
||||||
async def update_model_record(
|
async def update_model_record(
|
||||||
key: Annotated[str, Path(description="Unique key of model")],
|
key: Annotated[str, Path(description="Unique key of model")],
|
||||||
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
|
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
|
||||||
) -> AnyModelConfig:
|
) -> AnyModelConfig:
|
||||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
@ -135,7 +135,7 @@ async def del_model_record(
|
|||||||
status_code=201,
|
status_code=201,
|
||||||
)
|
)
|
||||||
async def add_model_record(
|
async def add_model_record(
|
||||||
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
|
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
|
||||||
) -> AnyModelConfig:
|
) -> AnyModelConfig:
|
||||||
"""
|
"""
|
||||||
Add a model using the configuration information appropriate for its type.
|
Add a model using the configuration information appropriate for its type.
|
||||||
|
@ -49,7 +49,8 @@ def test_type(store: ModelRecordServiceBase):
|
|||||||
store.add_model("key1", config)
|
store.add_model("key1", config)
|
||||||
config1 = store.get_model("key1")
|
config1 = store.get_model("key1")
|
||||||
assert type(config1) == TextualInversionConfig
|
assert type(config1) == TextualInversionConfig
|
||||||
|
|
||||||
|
|
||||||
def test_add(store: ModelRecordServiceBase):
|
def test_add(store: ModelRecordServiceBase):
|
||||||
raw = dict(
|
raw = dict(
|
||||||
path="/tmp/foo.ckpt",
|
path="/tmp/foo.ckpt",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user