added add_model_record and get_model_record to router api

This commit is contained in:
Lincoln Stein 2023-11-04 23:42:44 -04:00
parent edeea5237b
commit 72c34aea75
3 changed files with 92 additions and 6 deletions

View File

@ -4,12 +4,14 @@
from typing import List, Optional
from fastapi import Body, Path, Query
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict, TypeAdapter
from random import randbytes
from hashlib import sha1
from starlette.exceptions import HTTPException
from invokeai.app.services.model_records import UnknownModelException
from invokeai.app.services.model_records import UnknownModelException, DuplicateModelException, InvalidModelException
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
from ..dependencies import ApiDependencies
@ -53,6 +55,26 @@ async def list_model_records(
return models
@model_records_router.get(
"/i/{key}",
operation_id="get_model_record",
responses={
200: {"description": "Success"},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
},
)
async def get_model_record(
key: str = Path(description="Key of the model record to fetch."),
) -> AnyModelConfig:
"""Get a model record"""
record_store = ApiDependencies.invoker.services.model_records
try:
return record_store.get_model(key)
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.patch(
"/i/{key}",
operation_id="update_model_record",
@ -80,3 +102,62 @@ async def update_model_record(
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return model_response
@model_records_router.delete(
"/i/{key}",
operation_id="del_model_record",
responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
status_code=204,
response_model=None,
)
async def del_model_record(
key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response:
"""Delete Model"""
logger = ApiDependencies.invoker.services.logger
try:
record_store = ApiDependencies.invoker.services.model_records
record_store.del_model(key)
logger.info(f"Deleted model: {key}")
return Response(status_code=204)
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.post(
"/i/",
operation_id="add_model_record",
responses={
201: {"description": "The model added successfully"},
404: {"description": "The model could not be found"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
415: {"description": "Unrecognized file/folder format"},
},
status_code=201,
response_model=AnyModelConfig,
)
async def add_model_record(
config: AnyModelConfig = Body(description="Model configuration"),
) -> AnyModelConfig:
"""
Add a model using the configuration information appropriate for its type.
"""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_records
if config.key == "<NOKEY>":
config.key = sha1(randbytes(100)).hexdigest()
logger.info(f"Created model {config.key} for {config.name}")
try:
record_store.add_model(config.key, config)
except DuplicateModelException as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
# now fetch it out
return record_store.get_model(config.key)

View File

@ -1,2 +1,7 @@
from .model_records_base import DuplicateModelException, ModelRecordServiceBase, UnknownModelException
from .model_records_base import (
DuplicateModelException,
ModelRecordServiceBase,
UnknownModelException,
InvalidModelException,
)
from .model_records_sql import ModelRecordServiceSQL

View File

@ -236,7 +236,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
(key,),
)
if self._cursor.rowcount == 0:
raise UnknownModelException
raise UnknownModelException("model not found")
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
@ -267,7 +267,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
(record.base_model, record.type, record.name, record.path, json_serialized, key),
)
if self._cursor.rowcount == 0:
raise UnknownModelException
raise UnknownModelException("model not found")
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
@ -293,7 +293,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]))
return model