feat: add itemstorage retrieval as dict (skip pydantic)

- Create `PaginatedDictResults`, a non-type-safe, non-generic version of `PaginatedResults`. (because `PaginatedResults` is a `pydantic.GenericlModel`, it requires a pydantic model as the generic type and is not suitable)
- Add `ItemStorageABC` and `SqliteItemStorage` methods to return the dict representation of items
- The existing methods are unchanged in what they output, but now they use the `dict` methods to retrieve items before parsing them
- `ImagesService` and some metadata stuff is updated to use the appropriate methods
- Sessions router updated to use the dict versions
- Client types regenerated
This commit is contained in:
psychedelicious 2023-08-18 12:55:13 +10:00
parent 832335998f
commit 0022e4d95d
6 changed files with 103 additions and 71 deletions

View File

@ -1,5 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import json
from typing import Annotated, List, Optional, Union from typing import Annotated, List, Optional, Union
from fastapi import Body, HTTPException, Path, Query, Response from fastapi import Body, HTTPException, Path, Query, Response
@ -8,14 +9,8 @@ from pydantic.fields import Field
from ...invocations import * from ...invocations import *
from ...invocations.baseinvocation import BaseInvocation from ...invocations.baseinvocation import BaseInvocation
from ...services.graph import ( from ...services.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
Edge, from ...services.item_storage import PaginatedDictResults, PaginatedResults
EdgeConnection,
Graph,
GraphExecutionState,
NodeAlreadyExecutedError,
)
from ...services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"]) session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
@ -40,18 +35,18 @@ async def create_session(
@session_router.get( @session_router.get(
"/", "/",
operation_id="list_sessions", operation_id="list_sessions",
responses={200: {"model": PaginatedResults[GraphExecutionState]}}, responses={200: {"model": PaginatedDictResults}},
) )
async def list_sessions( async def list_sessions(
page: int = Query(default=0, description="The page of results to get"), page: int = Query(default=0, description="The page of results to get"),
per_page: int = Query(default=10, description="The number of results per page"), per_page: int = Query(default=10, description="The number of results per page"),
query: str = Query(default="", description="The query string to search for"), query: str = Query(default="", description="The query string to search for"),
) -> PaginatedResults[GraphExecutionState]: ) -> PaginatedDictResults:
"""Gets a list of sessions, optionally searching""" """Gets a list of sessions, optionally searching"""
if query == "": if query == "":
result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page) result = ApiDependencies.invoker.services.graph_execution_manager.list_as_dict(page, per_page)
else: else:
result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page) result = ApiDependencies.invoker.services.graph_execution_manager.search_as_dict(query, page, per_page)
return result return result
@ -59,19 +54,18 @@ async def list_sessions(
"/{session_id}", "/{session_id}",
operation_id="get_session", operation_id="get_session",
responses={ responses={
200: {"model": GraphExecutionState}, 200: {"model": dict},
404: {"description": "Session not found"}, 404: {"description": "Session not found"},
}, },
) )
async def get_session( async def get_session(
session_id: str = Path(description="The id of the session to get"), session_id: str = Path(description="The id of the session to get"),
) -> GraphExecutionState: ) -> dict:
"""Gets a session""" """Gets a session"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) session_dict = ApiDependencies.invoker.services.graph_execution_manager.get_as_dict(session_id)
if session is None: if session_dict is None:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
else: return session_dict
return session
@session_router.post( @session_router.post(

View File

@ -35,7 +35,7 @@ from invokeai.app.services.models.image_record import (
) )
from invokeai.app.services.resource_name import NameServiceBase from invokeai.app.services.resource_name import NameServiceBase
from invokeai.app.services.urls import UrlServiceBase from invokeai.app.services.urls import UrlServiceBase
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session from invokeai.app.util.metadata import get_metadata_graph_from_session_dict
if TYPE_CHECKING: if TYPE_CHECKING:
from invokeai.app.services.graph import GraphExecutionState from invokeai.app.services.graph import GraphExecutionState
@ -190,10 +190,10 @@ class ImageService(ImageServiceABC):
graph = None graph = None
if session_id is not None: if session_id is not None:
session_raw = self._services.graph_execution_manager.get_raw(session_id) session_raw = self._services.graph_execution_manager.get_as_dict(session_id)
if session_raw is not None: if session_raw is not None:
try: try:
graph = get_metadata_graph_from_raw_session(session_raw) graph = get_metadata_graph_from_session_dict(session_raw)
except Exception as e: except Exception as e:
self._services.logger.warn(f"Failed to parse session graph: {e}") self._services.logger.warn(f"Failed to parse session graph: {e}")
graph = None graph = None
@ -294,12 +294,12 @@ class ImageService(ImageServiceABC):
if not image_record.session_id: if not image_record.session_id:
return ImageMetadata(metadata=metadata) return ImageMetadata(metadata=metadata)
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id) session_raw = self._services.graph_execution_manager.get_as_dict(image_record.session_id)
graph = None graph = None
if session_raw: if session_raw:
try: try:
graph = get_metadata_graph_from_raw_session(session_raw) graph = get_metadata_graph_from_session_dict(session_raw)
except Exception as e: except Exception as e:
self._services.logger.warn(f"Failed to parse session graph: {e}") self._services.logger.warn(f"Failed to parse session graph: {e}")
graph = None graph = None

View File

@ -19,6 +19,18 @@ class PaginatedResults(GenericModel, Generic[T]):
# fmt: on # fmt: on
class PaginatedDictResults(BaseModel):
"""Paginated raw results (dict)"""
# fmt: off
items: list[dict] = Field(description="Items")
page: int = Field(description="Current Page")
pages: int = Field(description="Total number of pages")
per_page: int = Field(description="Number of items per page")
total: int = Field(description="Total number of items in result")
# fmt: on
class ItemStorageABC(ABC, Generic[T]): class ItemStorageABC(ABC, Generic[T]):
_on_changed_callbacks: list[Callable[[T], None]] _on_changed_callbacks: list[Callable[[T], None]]
_on_deleted_callbacks: list[Callable[[str], None]] _on_deleted_callbacks: list[Callable[[str], None]]
@ -35,8 +47,8 @@ class ItemStorageABC(ABC, Generic[T]):
pass pass
@abstractmethod @abstractmethod
def get_raw(self, item_id: str) -> Optional[str]: def get_as_dict(self, item_id: str) -> Optional[dict]:
"""Gets the raw item as a string, skipping Pydantic parsing""" """Gets the item as a dict, skipping Pydantic parsing"""
pass pass
@abstractmethod @abstractmethod
@ -49,10 +61,19 @@ class ItemStorageABC(ABC, Generic[T]):
"""Gets a paginated list of items""" """Gets a paginated list of items"""
pass pass
@abstractmethod
def list_as_dict(self, page: int = 0, per_page: int = 10) -> PaginatedDictResults:
"""Gets a paginated list of items, skipping Pydantic parsing"""
pass
@abstractmethod @abstractmethod
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
pass pass
@abstractmethod
def search_as_dict(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedDictResults:
pass
def on_changed(self, on_changed: Callable[[T], None]) -> None: def on_changed(self, on_changed: Callable[[T], None]) -> None:
"""Register a callback for when an item is changed""" """Register a callback for when an item is changed"""
self._on_changed_callbacks.append(on_changed) self._on_changed_callbacks.append(on_changed)

View File

@ -1,10 +1,11 @@
import json
import sqlite3 import sqlite3
from threading import Lock from threading import Lock
from typing import Generic, Optional, TypeVar, get_args from typing import Generic, Optional, TypeVar, get_args
from pydantic import BaseModel, parse_raw_as from pydantic import BaseModel, parse_obj_as, parse_raw_as
from .item_storage import ItemStorageABC, PaginatedResults from .item_storage import ItemStorageABC, PaginatedDictResults, PaginatedResults
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel)
@ -47,9 +48,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
finally: finally:
self._lock.release() self._lock.release()
def _parse_item(self, item: str) -> T: def _parse_item_from_dict(self, item: dict) -> T:
item_type = get_args(self.__orig_class__)[0] item_type = get_args(self.__orig_class__)[0]
parsed = parse_raw_as(item_type, item) parsed = parse_obj_as(item_type, item)
return parsed return parsed
def set(self, item: T): def set(self, item: T):
@ -64,31 +65,25 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._lock.release() self._lock.release()
self._on_changed(item) self._on_changed(item)
def get_as_dict(self, id: str) -> Optional[dict]:
try:
self._lock.acquire()
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
result = self._cursor.fetchone()
finally:
self._lock.release()
if not result:
return None
return json.loads(result[0])
def get(self, id: str) -> Optional[T]: def get(self, id: str) -> Optional[T]:
try: item = self.get_as_dict(id)
self._lock.acquire() if not item:
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
result = self._cursor.fetchone()
finally:
self._lock.release()
if not result:
return None return None
return self._parse_item(result[0]) return self._parse_item_from_dict(item)
def get_raw(self, id: str) -> Optional[str]:
try:
self._lock.acquire()
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
result = self._cursor.fetchone()
finally:
self._lock.release()
if not result:
return None
return result[0]
def delete(self, id: str): def delete(self, id: str):
try: try:
@ -99,7 +94,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._lock.release() self._lock.release()
self._on_deleted(id) self._on_deleted(id)
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: def list_as_dict(self, page: int = 0, per_page: int = 10) -> PaginatedDictResults:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
@ -108,7 +103,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
) )
result = self._cursor.fetchall() result = self._cursor.fetchall()
items = list(map(lambda r: self._parse_item(r[0]), result)) items = list(map(lambda r: json.loads(r[0]), result))
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""") self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
count = self._cursor.fetchone()[0] count = self._cursor.fetchone()[0]
@ -117,9 +112,20 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
pageCount = int(count / per_page) + 1 pageCount = int(count / per_page) + 1
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count) return PaginatedDictResults(items=items, page=page, pages=pageCount, per_page=per_page, total=count)
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
paginated_raw_results = self.list_as_dict(page, per_page)
items = list(map(lambda r: self._parse_item_from_dict(r), paginated_raw_results.items))
return PaginatedResults[T](
items=items,
page=page,
pages=paginated_raw_results.pages,
per_page=per_page,
total=paginated_raw_results.total,
)
def search_as_dict(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedDictResults:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
@ -128,7 +134,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
) )
result = self._cursor.fetchall() result = self._cursor.fetchall()
items = list(map(lambda r: self._parse_item(r[0]), result)) items = list(map(lambda r: json.loads(r[0]), result))
self._cursor.execute( self._cursor.execute(
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""", f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
@ -140,4 +146,15 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
pageCount = int(count / per_page) + 1 pageCount = int(count / per_page) + 1
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count) return PaginatedDictResults(items=items, page=page, pages=pageCount, per_page=per_page, total=count)
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
paginated_raw_results = self.search_as_dict(query, page, per_page)
items = list(map(lambda r: self._parse_item_from_dict(r), paginated_raw_results.items))
return PaginatedResults[T](
items=items,
page=page,
pages=paginated_raw_results.pages,
per_page=per_page,
total=paginated_raw_results.total,
)

View File

@ -6,7 +6,7 @@ from pydantic import ValidationError
from invokeai.app.services.graph import Edge from invokeai.app.services.graph import Edge
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]: def get_metadata_graph_from_session_dict(session_dict: dict) -> Optional[dict]:
""" """
Parses raw session string, returning a dict of the graph. Parses raw session string, returning a dict of the graph.
@ -17,7 +17,7 @@ def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
Any validation failure will return None. Any validation failure will return None.
""" """
graph = json.loads(session_raw).get("graph", None) graph = session_dict.get("graph", None)
# sanity check make sure the graph is at least reasonably shaped # sanity check make sure the graph is at least reasonably shaped
if ( if (

View File

@ -4755,15 +4755,15 @@ export type components = {
image_resolution?: number; image_resolution?: number;
}; };
/** /**
* PaginatedResults[GraphExecutionState] * PaginatedDictResults
* @description Paginated results * @description Paginated raw results (dict)
*/ */
PaginatedResults_GraphExecutionState_: { PaginatedDictResults: {
/** /**
* Items * Items
* @description Items * @description Items
*/ */
items: (components["schemas"]["GraphExecutionState"])[]; items: (Record<string, never>)[];
/** /**
* Page * Page
* @description Current Page * @description Current Page
@ -6193,12 +6193,6 @@ export type components = {
ui_hidden: boolean; ui_hidden: boolean;
ui_type?: components["schemas"]["UIType"]; ui_type?: components["schemas"]["UIType"];
}; };
/**
* StableDiffusionOnnxModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionOnnxModelFormat: "olive" | "onnx";
/** /**
* StableDiffusion1ModelFormat * StableDiffusion1ModelFormat
* @description An enumeration. * @description An enumeration.
@ -6217,6 +6211,12 @@ export type components = {
* @enum {string} * @enum {string}
*/ */
StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionOnnxModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionOnnxModelFormat: "olive" | "onnx";
/** /**
* StableDiffusionXLModelFormat * StableDiffusionXLModelFormat
* @description An enumeration. * @description An enumeration.
@ -6254,7 +6254,7 @@ export type operations = {
/** @description Successful Response */ /** @description Successful Response */
200: { 200: {
content: { content: {
"application/json": components["schemas"]["PaginatedResults_GraphExecutionState_"]; "application/json": components["schemas"]["PaginatedDictResults"];
}; };
}; };
/** @description Validation Error */ /** @description Validation Error */
@ -6307,7 +6307,7 @@ export type operations = {
/** @description Successful Response */ /** @description Successful Response */
200: { 200: {
content: { content: {
"application/json": components["schemas"]["GraphExecutionState"]; "application/json": Record<string, never>;
}; };
}; };
/** @description Session not found */ /** @description Session not found */