mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: add itemstorage retrieval as dict (skip pydantic)
- Create `PaginatedDictResults`, a non-type-safe, non-generic version of `PaginatedResults`. (because `PaginatedResults` is a `pydantic.GenericlModel`, it requires a pydantic model as the generic type and is not suitable) - Add `ItemStorageABC` and `SqliteItemStorage` methods to return the dict representation of items - The existing methods are unchanged in what they output, but now they use the `dict` methods to retrieve items before parsing them - `ImagesService` and some metadata stuff is updated to use the appropriate methods - Sessions router updated to use the dict versions - Client types regenerated
This commit is contained in:
parent
832335998f
commit
0022e4d95d
@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
import json
|
||||||
from typing import Annotated, List, Optional, Union
|
from typing import Annotated, List, Optional, Union
|
||||||
|
|
||||||
from fastapi import Body, HTTPException, Path, Query, Response
|
from fastapi import Body, HTTPException, Path, Query, Response
|
||||||
@ -8,14 +9,8 @@ from pydantic.fields import Field
|
|||||||
|
|
||||||
from ...invocations import *
|
from ...invocations import *
|
||||||
from ...invocations.baseinvocation import BaseInvocation
|
from ...invocations.baseinvocation import BaseInvocation
|
||||||
from ...services.graph import (
|
from ...services.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
|
||||||
Edge,
|
from ...services.item_storage import PaginatedDictResults, PaginatedResults
|
||||||
EdgeConnection,
|
|
||||||
Graph,
|
|
||||||
GraphExecutionState,
|
|
||||||
NodeAlreadyExecutedError,
|
|
||||||
)
|
|
||||||
from ...services.item_storage import PaginatedResults
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
||||||
@ -40,18 +35,18 @@ async def create_session(
|
|||||||
@session_router.get(
|
@session_router.get(
|
||||||
"/",
|
"/",
|
||||||
operation_id="list_sessions",
|
operation_id="list_sessions",
|
||||||
responses={200: {"model": PaginatedResults[GraphExecutionState]}},
|
responses={200: {"model": PaginatedDictResults}},
|
||||||
)
|
)
|
||||||
async def list_sessions(
|
async def list_sessions(
|
||||||
page: int = Query(default=0, description="The page of results to get"),
|
page: int = Query(default=0, description="The page of results to get"),
|
||||||
per_page: int = Query(default=10, description="The number of results per page"),
|
per_page: int = Query(default=10, description="The number of results per page"),
|
||||||
query: str = Query(default="", description="The query string to search for"),
|
query: str = Query(default="", description="The query string to search for"),
|
||||||
) -> PaginatedResults[GraphExecutionState]:
|
) -> PaginatedDictResults:
|
||||||
"""Gets a list of sessions, optionally searching"""
|
"""Gets a list of sessions, optionally searching"""
|
||||||
if query == "":
|
if query == "":
|
||||||
result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
|
result = ApiDependencies.invoker.services.graph_execution_manager.list_as_dict(page, per_page)
|
||||||
else:
|
else:
|
||||||
result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
|
result = ApiDependencies.invoker.services.graph_execution_manager.search_as_dict(query, page, per_page)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@ -59,19 +54,18 @@ async def list_sessions(
|
|||||||
"/{session_id}",
|
"/{session_id}",
|
||||||
operation_id="get_session",
|
operation_id="get_session",
|
||||||
responses={
|
responses={
|
||||||
200: {"model": GraphExecutionState},
|
200: {"model": dict},
|
||||||
404: {"description": "Session not found"},
|
404: {"description": "Session not found"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def get_session(
|
async def get_session(
|
||||||
session_id: str = Path(description="The id of the session to get"),
|
session_id: str = Path(description="The id of the session to get"),
|
||||||
) -> GraphExecutionState:
|
) -> dict:
|
||||||
"""Gets a session"""
|
"""Gets a session"""
|
||||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
session_dict = ApiDependencies.invoker.services.graph_execution_manager.get_as_dict(session_id)
|
||||||
if session is None:
|
if session_dict is None:
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
else:
|
return session_dict
|
||||||
return session
|
|
||||||
|
|
||||||
|
|
||||||
@session_router.post(
|
@session_router.post(
|
||||||
|
@ -35,7 +35,7 @@ from invokeai.app.services.models.image_record import (
|
|||||||
)
|
)
|
||||||
from invokeai.app.services.resource_name import NameServiceBase
|
from invokeai.app.services.resource_name import NameServiceBase
|
||||||
from invokeai.app.services.urls import UrlServiceBase
|
from invokeai.app.services.urls import UrlServiceBase
|
||||||
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
|
from invokeai.app.util.metadata import get_metadata_graph_from_session_dict
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from invokeai.app.services.graph import GraphExecutionState
|
from invokeai.app.services.graph import GraphExecutionState
|
||||||
@ -190,10 +190,10 @@ class ImageService(ImageServiceABC):
|
|||||||
graph = None
|
graph = None
|
||||||
|
|
||||||
if session_id is not None:
|
if session_id is not None:
|
||||||
session_raw = self._services.graph_execution_manager.get_raw(session_id)
|
session_raw = self._services.graph_execution_manager.get_as_dict(session_id)
|
||||||
if session_raw is not None:
|
if session_raw is not None:
|
||||||
try:
|
try:
|
||||||
graph = get_metadata_graph_from_raw_session(session_raw)
|
graph = get_metadata_graph_from_session_dict(session_raw)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.warn(f"Failed to parse session graph: {e}")
|
self._services.logger.warn(f"Failed to parse session graph: {e}")
|
||||||
graph = None
|
graph = None
|
||||||
@ -294,12 +294,12 @@ class ImageService(ImageServiceABC):
|
|||||||
if not image_record.session_id:
|
if not image_record.session_id:
|
||||||
return ImageMetadata(metadata=metadata)
|
return ImageMetadata(metadata=metadata)
|
||||||
|
|
||||||
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
|
session_raw = self._services.graph_execution_manager.get_as_dict(image_record.session_id)
|
||||||
graph = None
|
graph = None
|
||||||
|
|
||||||
if session_raw:
|
if session_raw:
|
||||||
try:
|
try:
|
||||||
graph = get_metadata_graph_from_raw_session(session_raw)
|
graph = get_metadata_graph_from_session_dict(session_raw)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.warn(f"Failed to parse session graph: {e}")
|
self._services.logger.warn(f"Failed to parse session graph: {e}")
|
||||||
graph = None
|
graph = None
|
||||||
|
@ -19,6 +19,18 @@ class PaginatedResults(GenericModel, Generic[T]):
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
class PaginatedDictResults(BaseModel):
|
||||||
|
"""Paginated raw results (dict)"""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
items: list[dict] = Field(description="Items")
|
||||||
|
page: int = Field(description="Current Page")
|
||||||
|
pages: int = Field(description="Total number of pages")
|
||||||
|
per_page: int = Field(description="Number of items per page")
|
||||||
|
total: int = Field(description="Total number of items in result")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
class ItemStorageABC(ABC, Generic[T]):
|
class ItemStorageABC(ABC, Generic[T]):
|
||||||
_on_changed_callbacks: list[Callable[[T], None]]
|
_on_changed_callbacks: list[Callable[[T], None]]
|
||||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||||
@ -35,8 +47,8 @@ class ItemStorageABC(ABC, Generic[T]):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_raw(self, item_id: str) -> Optional[str]:
|
def get_as_dict(self, item_id: str) -> Optional[dict]:
|
||||||
"""Gets the raw item as a string, skipping Pydantic parsing"""
|
"""Gets the item as a dict, skipping Pydantic parsing"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -49,10 +61,19 @@ class ItemStorageABC(ABC, Generic[T]):
|
|||||||
"""Gets a paginated list of items"""
|
"""Gets a paginated list of items"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list_as_dict(self, page: int = 0, per_page: int = 10) -> PaginatedDictResults:
|
||||||
|
"""Gets a paginated list of items, skipping Pydantic parsing"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def search_as_dict(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedDictResults:
|
||||||
|
pass
|
||||||
|
|
||||||
def on_changed(self, on_changed: Callable[[T], None]) -> None:
|
def on_changed(self, on_changed: Callable[[T], None]) -> None:
|
||||||
"""Register a callback for when an item is changed"""
|
"""Register a callback for when an item is changed"""
|
||||||
self._on_changed_callbacks.append(on_changed)
|
self._on_changed_callbacks.append(on_changed)
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
|
import json
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import Generic, Optional, TypeVar, get_args
|
from typing import Generic, Optional, TypeVar, get_args
|
||||||
|
|
||||||
from pydantic import BaseModel, parse_raw_as
|
from pydantic import BaseModel, parse_obj_as, parse_raw_as
|
||||||
|
|
||||||
from .item_storage import ItemStorageABC, PaginatedResults
|
from .item_storage import ItemStorageABC, PaginatedDictResults, PaginatedResults
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
@ -47,9 +48,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
def _parse_item(self, item: str) -> T:
|
def _parse_item_from_dict(self, item: dict) -> T:
|
||||||
item_type = get_args(self.__orig_class__)[0]
|
item_type = get_args(self.__orig_class__)[0]
|
||||||
parsed = parse_raw_as(item_type, item)
|
parsed = parse_obj_as(item_type, item)
|
||||||
return parsed
|
return parsed
|
||||||
|
|
||||||
def set(self, item: T):
|
def set(self, item: T):
|
||||||
@ -64,31 +65,25 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
self._lock.release()
|
self._lock.release()
|
||||||
self._on_changed(item)
|
self._on_changed(item)
|
||||||
|
|
||||||
|
def get_as_dict(self, id: str) -> Optional[dict]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
||||||
|
result = self._cursor.fetchone()
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return json.loads(result[0])
|
||||||
|
|
||||||
def get(self, id: str) -> Optional[T]:
|
def get(self, id: str) -> Optional[T]:
|
||||||
try:
|
item = self.get_as_dict(id)
|
||||||
self._lock.acquire()
|
if not item:
|
||||||
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
|
||||||
result = self._cursor.fetchone()
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
|
|
||||||
if not result:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return self._parse_item(result[0])
|
return self._parse_item_from_dict(item)
|
||||||
|
|
||||||
def get_raw(self, id: str) -> Optional[str]:
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
|
||||||
result = self._cursor.fetchone()
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
|
|
||||||
if not result:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return result[0]
|
|
||||||
|
|
||||||
def delete(self, id: str):
|
def delete(self, id: str):
|
||||||
try:
|
try:
|
||||||
@ -99,7 +94,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
self._lock.release()
|
self._lock.release()
|
||||||
self._on_deleted(id)
|
self._on_deleted(id)
|
||||||
|
|
||||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
def list_as_dict(self, page: int = 0, per_page: int = 10) -> PaginatedDictResults:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
@ -108,7 +103,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
)
|
)
|
||||||
result = self._cursor.fetchall()
|
result = self._cursor.fetchall()
|
||||||
|
|
||||||
items = list(map(lambda r: self._parse_item(r[0]), result))
|
items = list(map(lambda r: json.loads(r[0]), result))
|
||||||
|
|
||||||
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
|
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
|
||||||
count = self._cursor.fetchone()[0]
|
count = self._cursor.fetchone()[0]
|
||||||
@ -117,9 +112,20 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
|
|
||||||
pageCount = int(count / per_page) + 1
|
pageCount = int(count / per_page) + 1
|
||||||
|
|
||||||
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
|
return PaginatedDictResults(items=items, page=page, pages=pageCount, per_page=per_page, total=count)
|
||||||
|
|
||||||
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||||
|
paginated_raw_results = self.list_as_dict(page, per_page)
|
||||||
|
items = list(map(lambda r: self._parse_item_from_dict(r), paginated_raw_results.items))
|
||||||
|
return PaginatedResults[T](
|
||||||
|
items=items,
|
||||||
|
page=page,
|
||||||
|
pages=paginated_raw_results.pages,
|
||||||
|
per_page=per_page,
|
||||||
|
total=paginated_raw_results.total,
|
||||||
|
)
|
||||||
|
|
||||||
|
def search_as_dict(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedDictResults:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
@ -128,7 +134,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
)
|
)
|
||||||
result = self._cursor.fetchall()
|
result = self._cursor.fetchall()
|
||||||
|
|
||||||
items = list(map(lambda r: self._parse_item(r[0]), result))
|
items = list(map(lambda r: json.loads(r[0]), result))
|
||||||
|
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
|
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
|
||||||
@ -140,4 +146,15 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
|
|
||||||
pageCount = int(count / per_page) + 1
|
pageCount = int(count / per_page) + 1
|
||||||
|
|
||||||
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
|
return PaginatedDictResults(items=items, page=page, pages=pageCount, per_page=per_page, total=count)
|
||||||
|
|
||||||
|
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||||
|
paginated_raw_results = self.search_as_dict(query, page, per_page)
|
||||||
|
items = list(map(lambda r: self._parse_item_from_dict(r), paginated_raw_results.items))
|
||||||
|
return PaginatedResults[T](
|
||||||
|
items=items,
|
||||||
|
page=page,
|
||||||
|
pages=paginated_raw_results.pages,
|
||||||
|
per_page=per_page,
|
||||||
|
total=paginated_raw_results.total,
|
||||||
|
)
|
||||||
|
@ -6,7 +6,7 @@ from pydantic import ValidationError
|
|||||||
from invokeai.app.services.graph import Edge
|
from invokeai.app.services.graph import Edge
|
||||||
|
|
||||||
|
|
||||||
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
|
def get_metadata_graph_from_session_dict(session_dict: dict) -> Optional[dict]:
|
||||||
"""
|
"""
|
||||||
Parses raw session string, returning a dict of the graph.
|
Parses raw session string, returning a dict of the graph.
|
||||||
|
|
||||||
@ -17,7 +17,7 @@ def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
|
|||||||
Any validation failure will return None.
|
Any validation failure will return None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
graph = json.loads(session_raw).get("graph", None)
|
graph = session_dict.get("graph", None)
|
||||||
|
|
||||||
# sanity check make sure the graph is at least reasonably shaped
|
# sanity check make sure the graph is at least reasonably shaped
|
||||||
if (
|
if (
|
||||||
|
@ -4755,15 +4755,15 @@ export type components = {
|
|||||||
image_resolution?: number;
|
image_resolution?: number;
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* PaginatedResults[GraphExecutionState]
|
* PaginatedDictResults
|
||||||
* @description Paginated results
|
* @description Paginated raw results (dict)
|
||||||
*/
|
*/
|
||||||
PaginatedResults_GraphExecutionState_: {
|
PaginatedDictResults: {
|
||||||
/**
|
/**
|
||||||
* Items
|
* Items
|
||||||
* @description Items
|
* @description Items
|
||||||
*/
|
*/
|
||||||
items: (components["schemas"]["GraphExecutionState"])[];
|
items: (Record<string, never>)[];
|
||||||
/**
|
/**
|
||||||
* Page
|
* Page
|
||||||
* @description Current Page
|
* @description Current Page
|
||||||
@ -6193,12 +6193,6 @@ export type components = {
|
|||||||
ui_hidden: boolean;
|
ui_hidden: boolean;
|
||||||
ui_type?: components["schemas"]["UIType"];
|
ui_type?: components["schemas"]["UIType"];
|
||||||
};
|
};
|
||||||
/**
|
|
||||||
* StableDiffusionOnnxModelFormat
|
|
||||||
* @description An enumeration.
|
|
||||||
* @enum {string}
|
|
||||||
*/
|
|
||||||
StableDiffusionOnnxModelFormat: "olive" | "onnx";
|
|
||||||
/**
|
/**
|
||||||
* StableDiffusion1ModelFormat
|
* StableDiffusion1ModelFormat
|
||||||
* @description An enumeration.
|
* @description An enumeration.
|
||||||
@ -6217,6 +6211,12 @@ export type components = {
|
|||||||
* @enum {string}
|
* @enum {string}
|
||||||
*/
|
*/
|
||||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||||
|
/**
|
||||||
|
* StableDiffusionOnnxModelFormat
|
||||||
|
* @description An enumeration.
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
StableDiffusionOnnxModelFormat: "olive" | "onnx";
|
||||||
/**
|
/**
|
||||||
* StableDiffusionXLModelFormat
|
* StableDiffusionXLModelFormat
|
||||||
* @description An enumeration.
|
* @description An enumeration.
|
||||||
@ -6254,7 +6254,7 @@ export type operations = {
|
|||||||
/** @description Successful Response */
|
/** @description Successful Response */
|
||||||
200: {
|
200: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["PaginatedResults_GraphExecutionState_"];
|
"application/json": components["schemas"]["PaginatedDictResults"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
/** @description Validation Error */
|
/** @description Validation Error */
|
||||||
@ -6307,7 +6307,7 @@ export type operations = {
|
|||||||
/** @description Successful Response */
|
/** @description Successful Response */
|
||||||
200: {
|
200: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["GraphExecutionState"];
|
"application/json": Record<string, never>;
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
/** @description Session not found */
|
/** @description Session not found */
|
||||||
|
Loading…
Reference in New Issue
Block a user