mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
added add_model_record and get_model_record to router api
This commit is contained in:
parent
edeea5237b
commit
72c34aea75
@ -4,12 +4,14 @@
|
|||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from fastapi import Body, Path, Query
|
from fastapi import Body, Path, Query, Response
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from pydantic import BaseModel, ConfigDict, TypeAdapter
|
from pydantic import BaseModel, ConfigDict, TypeAdapter
|
||||||
|
from random import randbytes
|
||||||
|
from hashlib import sha1
|
||||||
from starlette.exceptions import HTTPException
|
from starlette.exceptions import HTTPException
|
||||||
|
|
||||||
from invokeai.app.services.model_records import UnknownModelException
|
from invokeai.app.services.model_records import UnknownModelException, DuplicateModelException, InvalidModelException
|
||||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
|
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
@ -53,6 +55,26 @@ async def list_model_records(
|
|||||||
return models
|
return models
|
||||||
|
|
||||||
|
|
||||||
|
@model_records_router.get(
|
||||||
|
"/i/{key}",
|
||||||
|
operation_id="get_model_record",
|
||||||
|
responses={
|
||||||
|
200: {"description": "Success"},
|
||||||
|
400: {"description": "Bad request"},
|
||||||
|
404: {"description": "The model could not be found"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def get_model_record(
|
||||||
|
key: str = Path(description="Key of the model record to fetch."),
|
||||||
|
) -> AnyModelConfig:
|
||||||
|
"""Get a model record"""
|
||||||
|
record_store = ApiDependencies.invoker.services.model_records
|
||||||
|
try:
|
||||||
|
return record_store.get_model(key)
|
||||||
|
except UnknownModelException as e:
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@model_records_router.patch(
|
@model_records_router.patch(
|
||||||
"/i/{key}",
|
"/i/{key}",
|
||||||
operation_id="update_model_record",
|
operation_id="update_model_record",
|
||||||
@ -80,3 +102,62 @@ async def update_model_record(
|
|||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=409, detail=str(e))
|
raise HTTPException(status_code=409, detail=str(e))
|
||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
|
|
||||||
|
@model_records_router.delete(
|
||||||
|
"/i/{key}",
|
||||||
|
operation_id="del_model_record",
|
||||||
|
responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
|
||||||
|
status_code=204,
|
||||||
|
response_model=None,
|
||||||
|
)
|
||||||
|
async def del_model_record(
|
||||||
|
key: str = Path(description="Unique key of model to remove from model registry."),
|
||||||
|
) -> Response:
|
||||||
|
"""Delete Model"""
|
||||||
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
record_store = ApiDependencies.invoker.services.model_records
|
||||||
|
record_store.del_model(key)
|
||||||
|
logger.info(f"Deleted model: {key}")
|
||||||
|
return Response(status_code=204)
|
||||||
|
except UnknownModelException as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@model_records_router.post(
|
||||||
|
"/i/",
|
||||||
|
operation_id="add_model_record",
|
||||||
|
responses={
|
||||||
|
201: {"description": "The model added successfully"},
|
||||||
|
404: {"description": "The model could not be found"},
|
||||||
|
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||||
|
415: {"description": "Unrecognized file/folder format"},
|
||||||
|
},
|
||||||
|
status_code=201,
|
||||||
|
response_model=AnyModelConfig,
|
||||||
|
)
|
||||||
|
async def add_model_record(
|
||||||
|
config: AnyModelConfig = Body(description="Model configuration"),
|
||||||
|
) -> AnyModelConfig:
|
||||||
|
"""
|
||||||
|
Add a model using the configuration information appropriate for its type.
|
||||||
|
"""
|
||||||
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
record_store = ApiDependencies.invoker.services.model_records
|
||||||
|
if config.key == "<NOKEY>":
|
||||||
|
config.key = sha1(randbytes(100)).hexdigest()
|
||||||
|
logger.info(f"Created model {config.key} for {config.name}")
|
||||||
|
try:
|
||||||
|
record_store.add_model(config.key, config)
|
||||||
|
except DuplicateModelException as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
except InvalidModelException as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=415)
|
||||||
|
|
||||||
|
# now fetch it out
|
||||||
|
return record_store.get_model(config.key)
|
||||||
|
@ -1,2 +1,7 @@
|
|||||||
from .model_records_base import DuplicateModelException, ModelRecordServiceBase, UnknownModelException
|
from .model_records_base import (
|
||||||
|
DuplicateModelException,
|
||||||
|
ModelRecordServiceBase,
|
||||||
|
UnknownModelException,
|
||||||
|
InvalidModelException,
|
||||||
|
)
|
||||||
from .model_records_sql import ModelRecordServiceSQL
|
from .model_records_sql import ModelRecordServiceSQL
|
||||||
|
@ -236,7 +236,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
(key,),
|
(key,),
|
||||||
)
|
)
|
||||||
if self._cursor.rowcount == 0:
|
if self._cursor.rowcount == 0:
|
||||||
raise UnknownModelException
|
raise UnknownModelException("model not found")
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
self._conn.rollback()
|
self._conn.rollback()
|
||||||
@ -267,7 +267,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
(record.base_model, record.type, record.name, record.path, json_serialized, key),
|
(record.base_model, record.type, record.name, record.path, json_serialized, key),
|
||||||
)
|
)
|
||||||
if self._cursor.rowcount == 0:
|
if self._cursor.rowcount == 0:
|
||||||
raise UnknownModelException
|
raise UnknownModelException("model not found")
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
self._conn.rollback()
|
self._conn.rollback()
|
||||||
@ -293,7 +293,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
)
|
)
|
||||||
rows = self._cursor.fetchone()
|
rows = self._cursor.fetchone()
|
||||||
if not rows:
|
if not rows:
|
||||||
raise UnknownModelException
|
raise UnknownModelException("model not found")
|
||||||
model = ModelConfigFactory.make_config(json.loads(rows[0]))
|
model = ModelConfigFactory.make_config(json.loads(rows[0]))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user