This commit is contained in:
Lincoln Stein 2023-11-12 14:20:32 -05:00
parent 024a156114
commit ef8dcf5fae
2 changed files with 8 additions and 7 deletions

View File

@ -36,8 +36,8 @@ ModelsListValidator = TypeAdapter(ModelsList)
operation_id="list_model_records",
)
async def list_model_records(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
@ -63,7 +63,7 @@ async def list_model_records(
},
)
async def get_model_record(
key: str = Path(description="Key of the model record to fetch."),
key: str = Path(description="Key of the model record to fetch."),
) -> AnyModelConfig:
"""Get a model record"""
record_store = ApiDependencies.invoker.services.model_records
@ -86,8 +86,8 @@ async def get_model_record(
response_model=AnyModelConfig,
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
key: Annotated[str, Path(description="Unique key of model")],
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
@ -135,7 +135,7 @@ async def del_model_record(
status_code=201,
)
async def add_model_record(
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
) -> AnyModelConfig:
"""
Add a model using the configuration information appropriate for its type.

View File

@ -49,7 +49,8 @@ def test_type(store: ModelRecordServiceBase):
store.add_model("key1", config)
config1 = store.get_model("key1")
assert type(config1) == TextualInversionConfig
def test_add(store: ModelRecordServiceBase):
raw = dict(
path="/tmp/foo.ckpt",