mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: add itemstorage retrieval as dict (skip pydantic)
- Create `PaginatedDictResults`, a non-type-safe, non-generic version of `PaginatedResults`. (because `PaginatedResults` is a `pydantic.GenericlModel`, it requires a pydantic model as the generic type and is not suitable) - Add `ItemStorageABC` and `SqliteItemStorage` methods to return the dict representation of items - The existing methods are unchanged in what they output, but now they use the `dict` methods to retrieve items before parsing them - `ImagesService` and some metadata stuff is updated to use the appropriate methods - Sessions router updated to use the dict versions - Client types regenerated
This commit is contained in:
parent
832335998f
commit
0022e4d95d
@ -1,5 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import json
|
||||
from typing import Annotated, List, Optional, Union
|
||||
|
||||
from fastapi import Body, HTTPException, Path, Query, Response
|
||||
@ -8,14 +9,8 @@ from pydantic.fields import Field
|
||||
|
||||
from ...invocations import *
|
||||
from ...invocations.baseinvocation import BaseInvocation
|
||||
from ...services.graph import (
|
||||
Edge,
|
||||
EdgeConnection,
|
||||
Graph,
|
||||
GraphExecutionState,
|
||||
NodeAlreadyExecutedError,
|
||||
)
|
||||
from ...services.item_storage import PaginatedResults
|
||||
from ...services.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
|
||||
from ...services.item_storage import PaginatedDictResults, PaginatedResults
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
||||
@ -40,18 +35,18 @@ async def create_session(
|
||||
@session_router.get(
|
||||
"/",
|
||||
operation_id="list_sessions",
|
||||
responses={200: {"model": PaginatedResults[GraphExecutionState]}},
|
||||
responses={200: {"model": PaginatedDictResults}},
|
||||
)
|
||||
async def list_sessions(
|
||||
page: int = Query(default=0, description="The page of results to get"),
|
||||
per_page: int = Query(default=10, description="The number of results per page"),
|
||||
query: str = Query(default="", description="The query string to search for"),
|
||||
) -> PaginatedResults[GraphExecutionState]:
|
||||
) -> PaginatedDictResults:
|
||||
"""Gets a list of sessions, optionally searching"""
|
||||
if query == "":
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.list_as_dict(page, per_page)
|
||||
else:
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.search_as_dict(query, page, per_page)
|
||||
return result
|
||||
|
||||
|
||||
@ -59,19 +54,18 @@ async def list_sessions(
|
||||
"/{session_id}",
|
||||
operation_id="get_session",
|
||||
responses={
|
||||
200: {"model": GraphExecutionState},
|
||||
200: {"model": dict},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
)
|
||||
async def get_session(
|
||||
session_id: str = Path(description="The id of the session to get"),
|
||||
) -> GraphExecutionState:
|
||||
) -> dict:
|
||||
"""Gets a session"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
session_dict = ApiDependencies.invoker.services.graph_execution_manager.get_as_dict(session_id)
|
||||
if session_dict is None:
|
||||
raise HTTPException(status_code=404)
|
||||
else:
|
||||
return session
|
||||
return session_dict
|
||||
|
||||
|
||||
@session_router.post(
|
||||
|
@ -35,7 +35,7 @@ from invokeai.app.services.models.image_record import (
|
||||
)
|
||||
from invokeai.app.services.resource_name import NameServiceBase
|
||||
from invokeai.app.services.urls import UrlServiceBase
|
||||
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
|
||||
from invokeai.app.util.metadata import get_metadata_graph_from_session_dict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.graph import GraphExecutionState
|
||||
@ -190,10 +190,10 @@ class ImageService(ImageServiceABC):
|
||||
graph = None
|
||||
|
||||
if session_id is not None:
|
||||
session_raw = self._services.graph_execution_manager.get_raw(session_id)
|
||||
session_raw = self._services.graph_execution_manager.get_as_dict(session_id)
|
||||
if session_raw is not None:
|
||||
try:
|
||||
graph = get_metadata_graph_from_raw_session(session_raw)
|
||||
graph = get_metadata_graph_from_session_dict(session_raw)
|
||||
except Exception as e:
|
||||
self._services.logger.warn(f"Failed to parse session graph: {e}")
|
||||
graph = None
|
||||
@ -294,12 +294,12 @@ class ImageService(ImageServiceABC):
|
||||
if not image_record.session_id:
|
||||
return ImageMetadata(metadata=metadata)
|
||||
|
||||
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
|
||||
session_raw = self._services.graph_execution_manager.get_as_dict(image_record.session_id)
|
||||
graph = None
|
||||
|
||||
if session_raw:
|
||||
try:
|
||||
graph = get_metadata_graph_from_raw_session(session_raw)
|
||||
graph = get_metadata_graph_from_session_dict(session_raw)
|
||||
except Exception as e:
|
||||
self._services.logger.warn(f"Failed to parse session graph: {e}")
|
||||
graph = None
|
||||
|
@ -19,6 +19,18 @@ class PaginatedResults(GenericModel, Generic[T]):
|
||||
# fmt: on
|
||||
|
||||
|
||||
class PaginatedDictResults(BaseModel):
|
||||
"""Paginated raw results (dict)"""
|
||||
|
||||
# fmt: off
|
||||
items: list[dict] = Field(description="Items")
|
||||
page: int = Field(description="Current Page")
|
||||
pages: int = Field(description="Total number of pages")
|
||||
per_page: int = Field(description="Number of items per page")
|
||||
total: int = Field(description="Total number of items in result")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class ItemStorageABC(ABC, Generic[T]):
|
||||
_on_changed_callbacks: list[Callable[[T], None]]
|
||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||
@ -35,8 +47,8 @@ class ItemStorageABC(ABC, Generic[T]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_raw(self, item_id: str) -> Optional[str]:
|
||||
"""Gets the raw item as a string, skipping Pydantic parsing"""
|
||||
def get_as_dict(self, item_id: str) -> Optional[dict]:
|
||||
"""Gets the item as a dict, skipping Pydantic parsing"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -49,10 +61,19 @@ class ItemStorageABC(ABC, Generic[T]):
|
||||
"""Gets a paginated list of items"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_as_dict(self, page: int = 0, per_page: int = 10) -> PaginatedDictResults:
|
||||
"""Gets a paginated list of items, skipping Pydantic parsing"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_as_dict(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedDictResults:
|
||||
pass
|
||||
|
||||
def on_changed(self, on_changed: Callable[[T], None]) -> None:
|
||||
"""Register a callback for when an item is changed"""
|
||||
self._on_changed_callbacks.append(on_changed)
|
||||
|
@ -1,10 +1,11 @@
|
||||
import json
|
||||
import sqlite3
|
||||
from threading import Lock
|
||||
from typing import Generic, Optional, TypeVar, get_args
|
||||
|
||||
from pydantic import BaseModel, parse_raw_as
|
||||
from pydantic import BaseModel, parse_obj_as, parse_raw_as
|
||||
|
||||
from .item_storage import ItemStorageABC, PaginatedResults
|
||||
from .item_storage import ItemStorageABC, PaginatedDictResults, PaginatedResults
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
@ -47,9 +48,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _parse_item(self, item: str) -> T:
|
||||
def _parse_item_from_dict(self, item: dict) -> T:
|
||||
item_type = get_args(self.__orig_class__)[0]
|
||||
parsed = parse_raw_as(item_type, item)
|
||||
parsed = parse_obj_as(item_type, item)
|
||||
return parsed
|
||||
|
||||
def set(self, item: T):
|
||||
@ -64,31 +65,25 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
self._lock.release()
|
||||
self._on_changed(item)
|
||||
|
||||
def get_as_dict(self, id: str) -> Optional[dict]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
||||
result = self._cursor.fetchone()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if not result:
|
||||
return None
|
||||
|
||||
return json.loads(result[0])
|
||||
|
||||
def get(self, id: str) -> Optional[T]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
||||
result = self._cursor.fetchone()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if not result:
|
||||
item = self.get_as_dict(id)
|
||||
if not item:
|
||||
return None
|
||||
|
||||
return self._parse_item(result[0])
|
||||
|
||||
def get_raw(self, id: str) -> Optional[str]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
||||
result = self._cursor.fetchone()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if not result:
|
||||
return None
|
||||
|
||||
return result[0]
|
||||
return self._parse_item_from_dict(item)
|
||||
|
||||
def delete(self, id: str):
|
||||
try:
|
||||
@ -99,7 +94,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
self._lock.release()
|
||||
self._on_deleted(id)
|
||||
|
||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
def list_as_dict(self, page: int = 0, per_page: int = 10) -> PaginatedDictResults:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
@ -108,7 +103,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
items = list(map(lambda r: self._parse_item(r[0]), result))
|
||||
items = list(map(lambda r: json.loads(r[0]), result))
|
||||
|
||||
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
|
||||
count = self._cursor.fetchone()[0]
|
||||
@ -117,9 +112,20 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
|
||||
return PaginatedDictResults(items=items, page=page, pages=pageCount, per_page=per_page, total=count)
|
||||
|
||||
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
paginated_raw_results = self.list_as_dict(page, per_page)
|
||||
items = list(map(lambda r: self._parse_item_from_dict(r), paginated_raw_results.items))
|
||||
return PaginatedResults[T](
|
||||
items=items,
|
||||
page=page,
|
||||
pages=paginated_raw_results.pages,
|
||||
per_page=per_page,
|
||||
total=paginated_raw_results.total,
|
||||
)
|
||||
|
||||
def search_as_dict(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedDictResults:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
@ -128,7 +134,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
items = list(map(lambda r: self._parse_item(r[0]), result))
|
||||
items = list(map(lambda r: json.loads(r[0]), result))
|
||||
|
||||
self._cursor.execute(
|
||||
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
|
||||
@ -140,4 +146,15 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
|
||||
return PaginatedDictResults(items=items, page=page, pages=pageCount, per_page=per_page, total=count)
|
||||
|
||||
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
paginated_raw_results = self.search_as_dict(query, page, per_page)
|
||||
items = list(map(lambda r: self._parse_item_from_dict(r), paginated_raw_results.items))
|
||||
return PaginatedResults[T](
|
||||
items=items,
|
||||
page=page,
|
||||
pages=paginated_raw_results.pages,
|
||||
per_page=per_page,
|
||||
total=paginated_raw_results.total,
|
||||
)
|
||||
|
@ -6,7 +6,7 @@ from pydantic import ValidationError
|
||||
from invokeai.app.services.graph import Edge
|
||||
|
||||
|
||||
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
|
||||
def get_metadata_graph_from_session_dict(session_dict: dict) -> Optional[dict]:
|
||||
"""
|
||||
Parses raw session string, returning a dict of the graph.
|
||||
|
||||
@ -17,7 +17,7 @@ def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
|
||||
Any validation failure will return None.
|
||||
"""
|
||||
|
||||
graph = json.loads(session_raw).get("graph", None)
|
||||
graph = session_dict.get("graph", None)
|
||||
|
||||
# sanity check make sure the graph is at least reasonably shaped
|
||||
if (
|
||||
|
@ -4755,15 +4755,15 @@ export type components = {
|
||||
image_resolution?: number;
|
||||
};
|
||||
/**
|
||||
* PaginatedResults[GraphExecutionState]
|
||||
* @description Paginated results
|
||||
* PaginatedDictResults
|
||||
* @description Paginated raw results (dict)
|
||||
*/
|
||||
PaginatedResults_GraphExecutionState_: {
|
||||
PaginatedDictResults: {
|
||||
/**
|
||||
* Items
|
||||
* @description Items
|
||||
*/
|
||||
items: (components["schemas"]["GraphExecutionState"])[];
|
||||
items: (Record<string, never>)[];
|
||||
/**
|
||||
* Page
|
||||
* @description Current Page
|
||||
@ -6193,12 +6193,6 @@ export type components = {
|
||||
ui_hidden: boolean;
|
||||
ui_type?: components["schemas"]["UIType"];
|
||||
};
|
||||
/**
|
||||
* StableDiffusionOnnxModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusionOnnxModelFormat: "olive" | "onnx";
|
||||
/**
|
||||
* StableDiffusion1ModelFormat
|
||||
* @description An enumeration.
|
||||
@ -6217,6 +6211,12 @@ export type components = {
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusionOnnxModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusionOnnxModelFormat: "olive" | "onnx";
|
||||
/**
|
||||
* StableDiffusionXLModelFormat
|
||||
* @description An enumeration.
|
||||
@ -6254,7 +6254,7 @@ export type operations = {
|
||||
/** @description Successful Response */
|
||||
200: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["PaginatedResults_GraphExecutionState_"];
|
||||
"application/json": components["schemas"]["PaginatedDictResults"];
|
||||
};
|
||||
};
|
||||
/** @description Validation Error */
|
||||
@ -6307,7 +6307,7 @@ export type operations = {
|
||||
/** @description Successful Response */
|
||||
200: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["GraphExecutionState"];
|
||||
"application/json": Record<string, never>;
|
||||
};
|
||||
};
|
||||
/** @description Session not found */
|
||||
|
Loading…
Reference in New Issue
Block a user