From 0022e4d95dc346a56446de3b4191e8f90fc42b61 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 18 Aug 2023 12:55:13 +1000 Subject: [PATCH] feat: add itemstorage retrieval as dict (skip pydantic) - Create `PaginatedDictResults`, a non-type-safe, non-generic version of `PaginatedResults`. (because `PaginatedResults` is a `pydantic.GenericlModel`, it requires a pydantic model as the generic type and is not suitable) - Add `ItemStorageABC` and `SqliteItemStorage` methods to return the dict representation of items - The existing methods are unchanged in what they output, but now they use the `dict` methods to retrieve items before parsing them - `ImagesService` and some metadata stuff is updated to use the appropriate methods - Sessions router updated to use the dict versions - Client types regenerated --- invokeai/app/api/routers/sessions.py | 30 +++---- invokeai/app/services/images.py | 10 +-- invokeai/app/services/item_storage.py | 25 +++++- invokeai/app/services/sqlite.py | 81 +++++++++++-------- invokeai/app/util/metadata.py | 4 +- .../frontend/web/src/services/api/schema.d.ts | 24 +++--- 6 files changed, 103 insertions(+), 71 deletions(-) diff --git a/invokeai/app/api/routers/sessions.py b/invokeai/app/api/routers/sessions.py index e4ba2a353e..4847ef5d40 100644 --- a/invokeai/app/api/routers/sessions.py +++ b/invokeai/app/api/routers/sessions.py @@ -1,5 +1,6 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) +import json from typing import Annotated, List, Optional, Union from fastapi import Body, HTTPException, Path, Query, Response @@ -8,14 +9,8 @@ from pydantic.fields import Field from ...invocations import * from ...invocations.baseinvocation import BaseInvocation -from ...services.graph import ( - Edge, - EdgeConnection, - Graph, - GraphExecutionState, - NodeAlreadyExecutedError, -) -from ...services.item_storage import PaginatedResults +from ...services.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError +from ...services.item_storage import PaginatedDictResults, PaginatedResults from ..dependencies import ApiDependencies session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"]) @@ -40,18 +35,18 @@ async def create_session( @session_router.get( "/", operation_id="list_sessions", - responses={200: {"model": PaginatedResults[GraphExecutionState]}}, + responses={200: {"model": PaginatedDictResults}}, ) async def list_sessions( page: int = Query(default=0, description="The page of results to get"), per_page: int = Query(default=10, description="The number of results per page"), query: str = Query(default="", description="The query string to search for"), -) -> PaginatedResults[GraphExecutionState]: +) -> PaginatedDictResults: """Gets a list of sessions, optionally searching""" if query == "": - result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page) + result = ApiDependencies.invoker.services.graph_execution_manager.list_as_dict(page, per_page) else: - result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page) + result = ApiDependencies.invoker.services.graph_execution_manager.search_as_dict(query, page, per_page) return result @@ -59,19 +54,18 @@ async def list_sessions( "/{session_id}", operation_id="get_session", responses={ - 200: {"model": GraphExecutionState}, + 200: {"model": dict}, 404: {"description": "Session not found"}, }, ) async def get_session( session_id: str = Path(description="The id of the session to get"), -) -> GraphExecutionState: +) -> dict: """Gets a session""" - session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) - if session is None: + session_dict = ApiDependencies.invoker.services.graph_execution_manager.get_as_dict(session_id) + if session_dict is None: raise HTTPException(status_code=404) - else: - return session + return session_dict @session_router.post( diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index 2240846dac..278e071733 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -35,7 +35,7 @@ from invokeai.app.services.models.image_record import ( ) from invokeai.app.services.resource_name import NameServiceBase from invokeai.app.services.urls import UrlServiceBase -from invokeai.app.util.metadata import get_metadata_graph_from_raw_session +from invokeai.app.util.metadata import get_metadata_graph_from_session_dict if TYPE_CHECKING: from invokeai.app.services.graph import GraphExecutionState @@ -190,10 +190,10 @@ class ImageService(ImageServiceABC): graph = None if session_id is not None: - session_raw = self._services.graph_execution_manager.get_raw(session_id) + session_raw = self._services.graph_execution_manager.get_as_dict(session_id) if session_raw is not None: try: - graph = get_metadata_graph_from_raw_session(session_raw) + graph = get_metadata_graph_from_session_dict(session_raw) except Exception as e: self._services.logger.warn(f"Failed to parse session graph: {e}") graph = None @@ -294,12 +294,12 @@ class ImageService(ImageServiceABC): if not image_record.session_id: return ImageMetadata(metadata=metadata) - session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id) + session_raw = self._services.graph_execution_manager.get_as_dict(image_record.session_id) graph = None if session_raw: try: - graph = get_metadata_graph_from_raw_session(session_raw) + graph = get_metadata_graph_from_session_dict(session_raw) except Exception as e: self._services.logger.warn(f"Failed to parse session graph: {e}") graph = None diff --git a/invokeai/app/services/item_storage.py b/invokeai/app/services/item_storage.py index 5fe4eb7456..84b2fabfa5 100644 --- a/invokeai/app/services/item_storage.py +++ b/invokeai/app/services/item_storage.py @@ -19,6 +19,18 @@ class PaginatedResults(GenericModel, Generic[T]): # fmt: on +class PaginatedDictResults(BaseModel): + """Paginated raw results (dict)""" + + # fmt: off + items: list[dict] = Field(description="Items") + page: int = Field(description="Current Page") + pages: int = Field(description="Total number of pages") + per_page: int = Field(description="Number of items per page") + total: int = Field(description="Total number of items in result") + # fmt: on + + class ItemStorageABC(ABC, Generic[T]): _on_changed_callbacks: list[Callable[[T], None]] _on_deleted_callbacks: list[Callable[[str], None]] @@ -35,8 +47,8 @@ class ItemStorageABC(ABC, Generic[T]): pass @abstractmethod - def get_raw(self, item_id: str) -> Optional[str]: - """Gets the raw item as a string, skipping Pydantic parsing""" + def get_as_dict(self, item_id: str) -> Optional[dict]: + """Gets the item as a dict, skipping Pydantic parsing""" pass @abstractmethod @@ -49,10 +61,19 @@ class ItemStorageABC(ABC, Generic[T]): """Gets a paginated list of items""" pass + @abstractmethod + def list_as_dict(self, page: int = 0, per_page: int = 10) -> PaginatedDictResults: + """Gets a paginated list of items, skipping Pydantic parsing""" + pass + @abstractmethod def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: pass + @abstractmethod + def search_as_dict(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedDictResults: + pass + def on_changed(self, on_changed: Callable[[T], None]) -> None: """Register a callback for when an item is changed""" self._on_changed_callbacks.append(on_changed) diff --git a/invokeai/app/services/sqlite.py b/invokeai/app/services/sqlite.py index 3c46b1c2a0..1c8ed18cb3 100644 --- a/invokeai/app/services/sqlite.py +++ b/invokeai/app/services/sqlite.py @@ -1,10 +1,11 @@ +import json import sqlite3 from threading import Lock from typing import Generic, Optional, TypeVar, get_args -from pydantic import BaseModel, parse_raw_as +from pydantic import BaseModel, parse_obj_as, parse_raw_as -from .item_storage import ItemStorageABC, PaginatedResults +from .item_storage import ItemStorageABC, PaginatedDictResults, PaginatedResults T = TypeVar("T", bound=BaseModel) @@ -47,9 +48,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): finally: self._lock.release() - def _parse_item(self, item: str) -> T: + def _parse_item_from_dict(self, item: dict) -> T: item_type = get_args(self.__orig_class__)[0] - parsed = parse_raw_as(item_type, item) + parsed = parse_obj_as(item_type, item) return parsed def set(self, item: T): @@ -64,31 +65,25 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): self._lock.release() self._on_changed(item) + def get_as_dict(self, id: str) -> Optional[dict]: + try: + self._lock.acquire() + self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)) + result = self._cursor.fetchone() + finally: + self._lock.release() + + if not result: + return None + + return json.loads(result[0]) + def get(self, id: str) -> Optional[T]: - try: - self._lock.acquire() - self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)) - result = self._cursor.fetchone() - finally: - self._lock.release() - - if not result: + item = self.get_as_dict(id) + if not item: return None - return self._parse_item(result[0]) - - def get_raw(self, id: str) -> Optional[str]: - try: - self._lock.acquire() - self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)) - result = self._cursor.fetchone() - finally: - self._lock.release() - - if not result: - return None - - return result[0] + return self._parse_item_from_dict(item) def delete(self, id: str): try: @@ -99,7 +94,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): self._lock.release() self._on_deleted(id) - def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: + def list_as_dict(self, page: int = 0, per_page: int = 10) -> PaginatedDictResults: try: self._lock.acquire() self._cursor.execute( @@ -108,7 +103,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): ) result = self._cursor.fetchall() - items = list(map(lambda r: self._parse_item(r[0]), result)) + items = list(map(lambda r: json.loads(r[0]), result)) self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""") count = self._cursor.fetchone()[0] @@ -117,9 +112,20 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): pageCount = int(count / per_page) + 1 - return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count) + return PaginatedDictResults(items=items, page=page, pages=pageCount, per_page=per_page, total=count) - def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: + def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: + paginated_raw_results = self.list_as_dict(page, per_page) + items = list(map(lambda r: self._parse_item_from_dict(r), paginated_raw_results.items)) + return PaginatedResults[T]( + items=items, + page=page, + pages=paginated_raw_results.pages, + per_page=per_page, + total=paginated_raw_results.total, + ) + + def search_as_dict(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedDictResults: try: self._lock.acquire() self._cursor.execute( @@ -128,7 +134,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): ) result = self._cursor.fetchall() - items = list(map(lambda r: self._parse_item(r[0]), result)) + items = list(map(lambda r: json.loads(r[0]), result)) self._cursor.execute( f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""", @@ -140,4 +146,15 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): pageCount = int(count / per_page) + 1 - return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count) + return PaginatedDictResults(items=items, page=page, pages=pageCount, per_page=per_page, total=count) + + def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: + paginated_raw_results = self.search_as_dict(query, page, per_page) + items = list(map(lambda r: self._parse_item_from_dict(r), paginated_raw_results.items)) + return PaginatedResults[T]( + items=items, + page=page, + pages=paginated_raw_results.pages, + per_page=per_page, + total=paginated_raw_results.total, + ) diff --git a/invokeai/app/util/metadata.py b/invokeai/app/util/metadata.py index 5ca5f14e12..d8ad41d8d9 100644 --- a/invokeai/app/util/metadata.py +++ b/invokeai/app/util/metadata.py @@ -6,7 +6,7 @@ from pydantic import ValidationError from invokeai.app.services.graph import Edge -def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]: +def get_metadata_graph_from_session_dict(session_dict: dict) -> Optional[dict]: """ Parses raw session string, returning a dict of the graph. @@ -17,7 +17,7 @@ def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]: Any validation failure will return None. """ - graph = json.loads(session_raw).get("graph", None) + graph = session_dict.get("graph", None) # sanity check make sure the graph is at least reasonably shaped if ( diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index 316ee0c085..e4b6160763 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -4755,15 +4755,15 @@ export type components = { image_resolution?: number; }; /** - * PaginatedResults[GraphExecutionState] - * @description Paginated results + * PaginatedDictResults + * @description Paginated raw results (dict) */ - PaginatedResults_GraphExecutionState_: { + PaginatedDictResults: { /** * Items * @description Items */ - items: (components["schemas"]["GraphExecutionState"])[]; + items: (Record)[]; /** * Page * @description Current Page @@ -6193,12 +6193,6 @@ export type components = { ui_hidden: boolean; ui_type?: components["schemas"]["UIType"]; }; - /** - * StableDiffusionOnnxModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusionOnnxModelFormat: "olive" | "onnx"; /** * StableDiffusion1ModelFormat * @description An enumeration. @@ -6217,6 +6211,12 @@ export type components = { * @enum {string} */ StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; + /** + * StableDiffusionOnnxModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusionOnnxModelFormat: "olive" | "onnx"; /** * StableDiffusionXLModelFormat * @description An enumeration. @@ -6254,7 +6254,7 @@ export type operations = { /** @description Successful Response */ 200: { content: { - "application/json": components["schemas"]["PaginatedResults_GraphExecutionState_"]; + "application/json": components["schemas"]["PaginatedDictResults"]; }; }; /** @description Validation Error */ @@ -6307,7 +6307,7 @@ export type operations = { /** @description Successful Response */ 200: { content: { - "application/json": components["schemas"]["GraphExecutionState"]; + "application/json": Record; }; }; /** @description Session not found */