mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
added add_model_record and get_model_record to router api
This commit is contained in:
parent
edeea5237b
commit
72c34aea75
@ -4,12 +4,14 @@
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import Body, Path, Query
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, ConfigDict, TypeAdapter
|
||||
from random import randbytes
|
||||
from hashlib import sha1
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
from invokeai.app.services.model_records import UnknownModelException, DuplicateModelException, InvalidModelException
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
@ -53,6 +55,26 @@ async def list_model_records(
|
||||
return models
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/i/{key}",
|
||||
operation_id="get_model_record",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The model could not be found"},
|
||||
},
|
||||
)
|
||||
async def get_model_record(
|
||||
key: str = Path(description="Key of the model record to fetch."),
|
||||
) -> AnyModelConfig:
|
||||
"""Get a model record"""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
try:
|
||||
return record_store.get_model(key)
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_records_router.patch(
|
||||
"/i/{key}",
|
||||
operation_id="update_model_record",
|
||||
@ -80,3 +102,62 @@ async def update_model_record(
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return model_response
|
||||
|
||||
|
||||
@model_records_router.delete(
|
||||
"/i/{key}",
|
||||
operation_id="del_model_record",
|
||||
responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
|
||||
status_code=204,
|
||||
response_model=None,
|
||||
)
|
||||
async def del_model_record(
|
||||
key: str = Path(description="Unique key of model to remove from model registry."),
|
||||
) -> Response:
|
||||
"""Delete Model"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
record_store.del_model(key)
|
||||
logger.info(f"Deleted model: {key}")
|
||||
return Response(status_code=204)
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_records_router.post(
|
||||
"/i/",
|
||||
operation_id="add_model_record",
|
||||
responses={
|
||||
201: {"description": "The model added successfully"},
|
||||
404: {"description": "The model could not be found"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=AnyModelConfig,
|
||||
)
|
||||
async def add_model_record(
|
||||
config: AnyModelConfig = Body(description="Model configuration"),
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model using the configuration information appropriate for its type.
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
if config.key == "<NOKEY>":
|
||||
config.key = sha1(randbytes(100)).hexdigest()
|
||||
logger.info(f"Created model {config.key} for {config.name}")
|
||||
try:
|
||||
record_store.add_model(config.key, config)
|
||||
except DuplicateModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
|
||||
# now fetch it out
|
||||
return record_store.get_model(config.key)
|
||||
|
@ -1,2 +1,7 @@
|
||||
from .model_records_base import DuplicateModelException, ModelRecordServiceBase, UnknownModelException
|
||||
from .model_records_base import (
|
||||
DuplicateModelException,
|
||||
ModelRecordServiceBase,
|
||||
UnknownModelException,
|
||||
InvalidModelException,
|
||||
)
|
||||
from .model_records_sql import ModelRecordServiceSQL
|
||||
|
@ -236,7 +236,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
(key,),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownModelException
|
||||
raise UnknownModelException("model not found")
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
@ -267,7 +267,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
(record.base_model, record.type, record.name, record.path, json_serialized, key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownModelException
|
||||
raise UnknownModelException("model not found")
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
@ -293,7 +293,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]))
|
||||
return model
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user