From 72c34aea751787825c7373a99ed6a754acbc54c5 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 4 Nov 2023 23:42:44 -0400 Subject: [PATCH] added add_model_record and get_model_record to router api --- invokeai/app/api/routers/model_records.py | 85 ++++++++++++++++++- .../app/services/model_records/__init__.py | 7 +- .../model_records/model_records_sql.py | 6 +- 3 files changed, 92 insertions(+), 6 deletions(-) diff --git a/invokeai/app/api/routers/model_records.py b/invokeai/app/api/routers/model_records.py index e36b002d1a..f144f87e01 100644 --- a/invokeai/app/api/routers/model_records.py +++ b/invokeai/app/api/routers/model_records.py @@ -4,12 +4,14 @@ from typing import List, Optional -from fastapi import Body, Path, Query +from fastapi import Body, Path, Query, Response from fastapi.routing import APIRouter from pydantic import BaseModel, ConfigDict, TypeAdapter +from random import randbytes +from hashlib import sha1 from starlette.exceptions import HTTPException -from invokeai.app.services.model_records import UnknownModelException +from invokeai.app.services.model_records import UnknownModelException, DuplicateModelException, InvalidModelException from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType from ..dependencies import ApiDependencies @@ -53,6 +55,26 @@ async def list_model_records( return models +@model_records_router.get( + "/i/{key}", + operation_id="get_model_record", + responses={ + 200: {"description": "Success"}, + 400: {"description": "Bad request"}, + 404: {"description": "The model could not be found"}, + }, +) +async def get_model_record( + key: str = Path(description="Key of the model record to fetch."), +) -> AnyModelConfig: + """Get a model record""" + record_store = ApiDependencies.invoker.services.model_records + try: + return record_store.get_model(key) + except UnknownModelException as e: + raise HTTPException(status_code=404, detail=str(e)) + + @model_records_router.patch( "/i/{key}", operation_id="update_model_record", @@ -80,3 +102,62 @@ async def update_model_record( logger.error(str(e)) raise HTTPException(status_code=409, detail=str(e)) return model_response + + +@model_records_router.delete( + "/i/{key}", + operation_id="del_model_record", + responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}}, + status_code=204, + response_model=None, +) +async def del_model_record( + key: str = Path(description="Unique key of model to remove from model registry."), +) -> Response: + """Delete Model""" + logger = ApiDependencies.invoker.services.logger + + try: + record_store = ApiDependencies.invoker.services.model_records + record_store.del_model(key) + logger.info(f"Deleted model: {key}") + return Response(status_code=204) + except UnknownModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=404, detail=str(e)) + + +@model_records_router.post( + "/i/", + operation_id="add_model_record", + responses={ + 201: {"description": "The model added successfully"}, + 404: {"description": "The model could not be found"}, + 409: {"description": "There is already a model corresponding to this path or repo_id"}, + 415: {"description": "Unrecognized file/folder format"}, + }, + status_code=201, + response_model=AnyModelConfig, +) +async def add_model_record( + config: AnyModelConfig = Body(description="Model configuration"), +) -> AnyModelConfig: + """ + Add a model using the configuration information appropriate for its type. + """ + logger = ApiDependencies.invoker.services.logger + record_store = ApiDependencies.invoker.services.model_records + if config.key == "": + config.key = sha1(randbytes(100)).hexdigest() + logger.info(f"Created model {config.key} for {config.name}") + try: + record_store.add_model(config.key, config) + except DuplicateModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + except InvalidModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=415) + + # now fetch it out + return record_store.get_model(config.key) diff --git a/invokeai/app/services/model_records/__init__.py b/invokeai/app/services/model_records/__init__.py index 15f22ff099..7f181af5aa 100644 --- a/invokeai/app/services/model_records/__init__.py +++ b/invokeai/app/services/model_records/__init__.py @@ -1,2 +1,7 @@ -from .model_records_base import DuplicateModelException, ModelRecordServiceBase, UnknownModelException +from .model_records_base import ( + DuplicateModelException, + ModelRecordServiceBase, + UnknownModelException, + InvalidModelException, +) from .model_records_sql import ModelRecordServiceSQL diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 7fe0a07727..897047a518 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -236,7 +236,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): (key,), ) if self._cursor.rowcount == 0: - raise UnknownModelException + raise UnknownModelException("model not found") self._conn.commit() except sqlite3.Error as e: self._conn.rollback() @@ -267,7 +267,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): (record.base_model, record.type, record.name, record.path, json_serialized, key), ) if self._cursor.rowcount == 0: - raise UnknownModelException + raise UnknownModelException("model not found") self._conn.commit() except sqlite3.Error as e: self._conn.rollback() @@ -293,7 +293,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): ) rows = self._cursor.fetchone() if not rows: - raise UnknownModelException + raise UnknownModelException("model not found") model = ModelConfigFactory.make_config(json.loads(rows[0])) return model