feat: add itemstorage retrieval as dict (skip pydantic)

- Create `PaginatedDictResults`, a non-type-safe, non-generic version of `PaginatedResults`. (because `PaginatedResults` is a `pydantic.GenericlModel`, it requires a pydantic model as the generic type and is not suitable)
- Add `ItemStorageABC` and `SqliteItemStorage` methods to return the dict representation of items
- The existing methods are unchanged in what they output, but now they use the `dict` methods to retrieve items before parsing them
- `ImagesService` and some metadata stuff is updated to use the appropriate methods
- Sessions router updated to use the dict versions
- Client types regenerated
This commit is contained in:
psychedelicious 2023-08-18 12:55:13 +10:00
parent 832335998f
commit 0022e4d95d
6 changed files with 103 additions and 71 deletions

View File

@ -1,5 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import json
from typing import Annotated, List, Optional, Union
from fastapi import Body, HTTPException, Path, Query, Response
@ -8,14 +9,8 @@ from pydantic.fields import Field
from ...invocations import *
from ...invocations.baseinvocation import BaseInvocation
from ...services.graph import (
Edge,
EdgeConnection,
Graph,
GraphExecutionState,
NodeAlreadyExecutedError,
)
from ...services.item_storage import PaginatedResults
from ...services.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
from ...services.item_storage import PaginatedDictResults, PaginatedResults
from ..dependencies import ApiDependencies
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
@ -40,18 +35,18 @@ async def create_session(
@session_router.get(
"/",
operation_id="list_sessions",
responses={200: {"model": PaginatedResults[GraphExecutionState]}},
responses={200: {"model": PaginatedDictResults}},
)
async def list_sessions(
page: int = Query(default=0, description="The page of results to get"),
per_page: int = Query(default=10, description="The number of results per page"),
query: str = Query(default="", description="The query string to search for"),
) -> PaginatedResults[GraphExecutionState]:
) -> PaginatedDictResults:
"""Gets a list of sessions, optionally searching"""
if query == "":
result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
result = ApiDependencies.invoker.services.graph_execution_manager.list_as_dict(page, per_page)
else:
result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
result = ApiDependencies.invoker.services.graph_execution_manager.search_as_dict(query, page, per_page)
return result
@ -59,19 +54,18 @@ async def list_sessions(
"/{session_id}",
operation_id="get_session",
responses={
200: {"model": GraphExecutionState},
200: {"model": dict},
404: {"description": "Session not found"},
},
)
async def get_session(
session_id: str = Path(description="The id of the session to get"),
) -> GraphExecutionState:
) -> dict:
"""Gets a session"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
session_dict = ApiDependencies.invoker.services.graph_execution_manager.get_as_dict(session_id)
if session_dict is None:
raise HTTPException(status_code=404)
else:
return session
return session_dict
@session_router.post(

View File

@ -35,7 +35,7 @@ from invokeai.app.services.models.image_record import (
)
from invokeai.app.services.resource_name import NameServiceBase
from invokeai.app.services.urls import UrlServiceBase
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
from invokeai.app.util.metadata import get_metadata_graph_from_session_dict
if TYPE_CHECKING:
from invokeai.app.services.graph import GraphExecutionState
@ -190,10 +190,10 @@ class ImageService(ImageServiceABC):
graph = None
if session_id is not None:
session_raw = self._services.graph_execution_manager.get_raw(session_id)
session_raw = self._services.graph_execution_manager.get_as_dict(session_id)
if session_raw is not None:
try:
graph = get_metadata_graph_from_raw_session(session_raw)
graph = get_metadata_graph_from_session_dict(session_raw)
except Exception as e:
self._services.logger.warn(f"Failed to parse session graph: {e}")
graph = None
@ -294,12 +294,12 @@ class ImageService(ImageServiceABC):
if not image_record.session_id:
return ImageMetadata(metadata=metadata)
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
session_raw = self._services.graph_execution_manager.get_as_dict(image_record.session_id)
graph = None
if session_raw:
try:
graph = get_metadata_graph_from_raw_session(session_raw)
graph = get_metadata_graph_from_session_dict(session_raw)
except Exception as e:
self._services.logger.warn(f"Failed to parse session graph: {e}")
graph = None

View File

@ -19,6 +19,18 @@ class PaginatedResults(GenericModel, Generic[T]):
# fmt: on
class PaginatedDictResults(BaseModel):
"""Paginated raw results (dict)"""
# fmt: off
items: list[dict] = Field(description="Items")
page: int = Field(description="Current Page")
pages: int = Field(description="Total number of pages")
per_page: int = Field(description="Number of items per page")
total: int = Field(description="Total number of items in result")
# fmt: on
class ItemStorageABC(ABC, Generic[T]):
_on_changed_callbacks: list[Callable[[T], None]]
_on_deleted_callbacks: list[Callable[[str], None]]
@ -35,8 +47,8 @@ class ItemStorageABC(ABC, Generic[T]):
pass
@abstractmethod
def get_raw(self, item_id: str) -> Optional[str]:
"""Gets the raw item as a string, skipping Pydantic parsing"""
def get_as_dict(self, item_id: str) -> Optional[dict]:
"""Gets the item as a dict, skipping Pydantic parsing"""
pass
@abstractmethod
@ -49,10 +61,19 @@ class ItemStorageABC(ABC, Generic[T]):
"""Gets a paginated list of items"""
pass
@abstractmethod
def list_as_dict(self, page: int = 0, per_page: int = 10) -> PaginatedDictResults:
"""Gets a paginated list of items, skipping Pydantic parsing"""
pass
@abstractmethod
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
pass
@abstractmethod
def search_as_dict(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedDictResults:
pass
def on_changed(self, on_changed: Callable[[T], None]) -> None:
"""Register a callback for when an item is changed"""
self._on_changed_callbacks.append(on_changed)

View File

@ -1,10 +1,11 @@
import json
import sqlite3
from threading import Lock
from typing import Generic, Optional, TypeVar, get_args
from pydantic import BaseModel, parse_raw_as
from pydantic import BaseModel, parse_obj_as, parse_raw_as
from .item_storage import ItemStorageABC, PaginatedResults
from .item_storage import ItemStorageABC, PaginatedDictResults, PaginatedResults
T = TypeVar("T", bound=BaseModel)
@ -47,9 +48,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
finally:
self._lock.release()
def _parse_item(self, item: str) -> T:
def _parse_item_from_dict(self, item: dict) -> T:
item_type = get_args(self.__orig_class__)[0]
parsed = parse_raw_as(item_type, item)
parsed = parse_obj_as(item_type, item)
return parsed
def set(self, item: T):
@ -64,31 +65,25 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._lock.release()
self._on_changed(item)
def get_as_dict(self, id: str) -> Optional[dict]:
try:
self._lock.acquire()
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
result = self._cursor.fetchone()
finally:
self._lock.release()
if not result:
return None
return json.loads(result[0])
def get(self, id: str) -> Optional[T]:
try:
self._lock.acquire()
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
result = self._cursor.fetchone()
finally:
self._lock.release()
if not result:
item = self.get_as_dict(id)
if not item:
return None
return self._parse_item(result[0])
def get_raw(self, id: str) -> Optional[str]:
try:
self._lock.acquire()
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
result = self._cursor.fetchone()
finally:
self._lock.release()
if not result:
return None
return result[0]
return self._parse_item_from_dict(item)
def delete(self, id: str):
try:
@ -99,7 +94,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._lock.release()
self._on_deleted(id)
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
def list_as_dict(self, page: int = 0, per_page: int = 10) -> PaginatedDictResults:
try:
self._lock.acquire()
self._cursor.execute(
@ -108,7 +103,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
)
result = self._cursor.fetchall()
items = list(map(lambda r: self._parse_item(r[0]), result))
items = list(map(lambda r: json.loads(r[0]), result))
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
count = self._cursor.fetchone()[0]
@ -117,9 +112,20 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
pageCount = int(count / per_page) + 1
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
return PaginatedDictResults(items=items, page=page, pages=pageCount, per_page=per_page, total=count)
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
paginated_raw_results = self.list_as_dict(page, per_page)
items = list(map(lambda r: self._parse_item_from_dict(r), paginated_raw_results.items))
return PaginatedResults[T](
items=items,
page=page,
pages=paginated_raw_results.pages,
per_page=per_page,
total=paginated_raw_results.total,
)
def search_as_dict(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedDictResults:
try:
self._lock.acquire()
self._cursor.execute(
@ -128,7 +134,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
)
result = self._cursor.fetchall()
items = list(map(lambda r: self._parse_item(r[0]), result))
items = list(map(lambda r: json.loads(r[0]), result))
self._cursor.execute(
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
@ -140,4 +146,15 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
pageCount = int(count / per_page) + 1
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
return PaginatedDictResults(items=items, page=page, pages=pageCount, per_page=per_page, total=count)
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
paginated_raw_results = self.search_as_dict(query, page, per_page)
items = list(map(lambda r: self._parse_item_from_dict(r), paginated_raw_results.items))
return PaginatedResults[T](
items=items,
page=page,
pages=paginated_raw_results.pages,
per_page=per_page,
total=paginated_raw_results.total,
)

View File

@ -6,7 +6,7 @@ from pydantic import ValidationError
from invokeai.app.services.graph import Edge
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
def get_metadata_graph_from_session_dict(session_dict: dict) -> Optional[dict]:
"""
Parses raw session string, returning a dict of the graph.
@ -17,7 +17,7 @@ def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
Any validation failure will return None.
"""
graph = json.loads(session_raw).get("graph", None)
graph = session_dict.get("graph", None)
# sanity check make sure the graph is at least reasonably shaped
if (

View File

@ -4755,15 +4755,15 @@ export type components = {
image_resolution?: number;
};
/**
* PaginatedResults[GraphExecutionState]
* @description Paginated results
* PaginatedDictResults
* @description Paginated raw results (dict)
*/
PaginatedResults_GraphExecutionState_: {
PaginatedDictResults: {
/**
* Items
* @description Items
*/
items: (components["schemas"]["GraphExecutionState"])[];
items: (Record<string, never>)[];
/**
* Page
* @description Current Page
@ -6193,12 +6193,6 @@ export type components = {
ui_hidden: boolean;
ui_type?: components["schemas"]["UIType"];
};
/**
* StableDiffusionOnnxModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionOnnxModelFormat: "olive" | "onnx";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
@ -6217,6 +6211,12 @@ export type components = {
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionOnnxModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionOnnxModelFormat: "olive" | "onnx";
/**
* StableDiffusionXLModelFormat
* @description An enumeration.
@ -6254,7 +6254,7 @@ export type operations = {
/** @description Successful Response */
200: {
content: {
"application/json": components["schemas"]["PaginatedResults_GraphExecutionState_"];
"application/json": components["schemas"]["PaginatedDictResults"];
};
};
/** @description Validation Error */
@ -6307,7 +6307,7 @@ export type operations = {
/** @description Successful Response */
200: {
content: {
"application/json": components["schemas"]["GraphExecutionState"];
"application/json": Record<string, never>;
};
};
/** @description Session not found */