Compare commits

...

2 Commits

Author SHA1 Message Date
f0a9a4fb88 feat(nodes): add ResultsServiceABC & SqliteResultsService
**Doesn't actually work because of circular imports. Can't even test it.**

- add a base class for ResultsService and SQLite implementation
- use `graph_execution_manager` `on_changed` callback to keep `results` table in sync
2023-05-17 19:16:04 +10:00
34b50e11b6 feat(nodes): change ImageOutput type to image_output 2023-05-17 19:13:53 +10:00
4 changed files with 237 additions and 4 deletions

View File

@ -18,6 +18,7 @@ from ..services.invoker import Invoker
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.metadata import PngMetadataService
from ..services.results import SqliteResultsService
from .events import FastAPIEventService
@ -69,6 +70,14 @@ class ApiDependencies:
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, "invokeai.db")
results = SqliteResultsService(filename=db_location)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
)
graph_execution_manager.on_changed(results.handle_graph_execution_state_change)
services = InvocationServices(
model_manager=get_model_manager(config,logger),
events=events,
@ -76,13 +85,12 @@ class ApiDependencies:
latents=latents,
images=images,
metadata=metadata,
results=results,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
),
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config,logger),
)

View File

@ -31,7 +31,7 @@ class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image"] = "image"
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")

View File

@ -10,6 +10,7 @@ from .image_storage import ImageStorageBase
from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC
from invokeai.app.services.results import ResultsServiceABC
class InvocationServices:
"""Services that can be used by invocations"""
@ -21,6 +22,7 @@ class InvocationServices:
queue: InvocationQueueABC
model_manager: ModelManager
restoration: RestorationServices
results: ResultsServiceABC
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_library: ItemStorageABC["LibraryGraph"]
@ -36,6 +38,7 @@ class InvocationServices:
images: ImageStorageBase,
metadata: MetadataServiceBase,
queue: InvocationQueueABC,
results: ResultsServiceABC,
graph_library: ItemStorageABC["LibraryGraph"],
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC",
@ -48,6 +51,7 @@ class InvocationServices:
self.images = images
self.metadata = metadata
self.queue = queue
self.results = results
self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager
self.processor = processor

View File

@ -0,0 +1,221 @@
from __future__ import annotations
from abc import ABC, abstractmethod
import json
import sqlite3
from threading import Lock
from typing import Union
from pydantic import BaseModel, Field, parse_raw_as
from invokeai.app.models.image import ImageField
from invokeai.app.invocations.latent import LatentsField
from invokeai.app.services.graph import GraphExecutionState
from invokeai.app.services.item_storage import PaginatedResults
class Result(BaseModel):
    """A single stored node result, as persisted in the `results` table."""
    # Result ID — the image name or latents name, depending on the result type
    id: str = Field(description="Result ID")
    # ID of the graph-execution session that produced this result
    session_id: str = Field(description="Session ID")
    # ID of the node within the session that produced this result
    node_id: str = Field(description="Node ID")
    # The node's output payload; either latents or an image reference
    data: Union[LatentsField, ImageField] = Field(description="The result data")
class ResultWithSession(BaseModel):
    """A result paired with the full session (graph execution state) it came from."""
    result: Result = Field(description="The result")
    session: GraphExecutionState = Field(description="The session")
class ResultsServiceABC(ABC):
    """Abstract interface for a service that stores and retrieves node results.

    Implementations keep a `results` table in sync with graph execution state
    via `handle_graph_execution_state_change`, and expose paginated retrieval.
    """

    @abstractmethod
    def get(self, id: str) -> Union[ResultWithSession, None]:
        """Return the result with the given ID joined with its session, or None.

        NOTE: signature aligned with the concrete implementation — the original
        declaration `(output_id: str) -> str` disagreed with
        `SqliteResultsService.get(id) -> Union[ResultWithSession, None]`.
        """
        pass

    @abstractmethod
    def list(
        self, page: int = 0, per_page: int = 10
    ) -> PaginatedResults[ResultWithSession]:
        """Return one page of results joined with their sessions."""
        pass

    @abstractmethod
    def search(
        self, query: str, page: int = 0, per_page: int = 10
    ) -> PaginatedResults[ResultWithSession]:
        """Return one page of results whose data matches `query`."""
        pass

    @abstractmethod
    def handle_graph_execution_state_change(self, session: GraphExecutionState) -> None:
        """Persist any new results found in the given session."""
        pass
class SqliteResultsService(ResultsServiceABC):
    """SQLite-backed implementation of `ResultsServiceABC`.

    Stores one row per node result in a `results` table and joins against the
    `graph_executions` table (managed elsewhere in the same database file) to
    return each result together with its originating session.
    """

    _filename: str
    _conn: sqlite3.Connection
    _cursor: sqlite3.Cursor
    _lock: Lock

    def __init__(self, filename: str):
        super().__init__()
        self._filename = filename
        self._lock = Lock()
        # check_same_thread=False because writes arrive from the graph
        # execution callback thread; all access is serialized via self._lock.
        self._conn = sqlite3.connect(
            self._filename, check_same_thread=False
        )  # TODO: figure out a better threading solution
        self._cursor = self._conn.cursor()
        self._create_table()

    def _create_table(self):
        """Create the `results` table and its unique id index if absent."""
        with self._lock:
            self._cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS results (
                  id TEXT PRIMARY KEY,
                  node_id TEXT,
                  session_id TEXT,
                  data TEXT
                );
                """
            )
            # Fix: the index must target the `results` table — the original
            # statement referenced a nonexistent `result` table and raised.
            self._cursor.execute(
                """CREATE UNIQUE INDEX IF NOT EXISTS result_id ON results(id);"""
            )

    def get(self, id: str) -> Union[ResultWithSession, None]:
        """Return the result with the given id joined with its session, or None."""
        with self._lock:
            self._cursor.execute(
                """
                SELECT results.data, graph_executions.state
                FROM results
                JOIN graph_executions ON results.session_id = graph_executions.id
                WHERE results.id = ?
                """,
                (id,),
            )
            result_row = self._cursor.fetchone()
            if result_row is None:
                return None
            result_raw, graph_execution_state_raw = result_row
            result = parse_raw_as(Result, result_raw)
            graph_execution_state = parse_raw_as(
                GraphExecutionState, graph_execution_state_raw
            )
        if not result:
            return None
        return ResultWithSession(result=result, session=graph_execution_state)

    def list(
        self, page: int = 0, per_page: int = 10
    ) -> PaginatedResults[ResultWithSession]:
        """Return one page of all results, joined with their sessions."""
        with self._lock:
            self._cursor.execute(
                """
                SELECT results.data, graph_executions.state
                FROM results
                JOIN graph_executions ON results.session_id = graph_executions.id
                ORDER BY results.id
                LIMIT ? OFFSET ?;
                """,
                (per_page, page * per_page),
            )
            result_rows = self._cursor.fetchall()
            items = [
                ResultWithSession(
                    result=parse_raw_as(Result, data_raw),
                    session=parse_raw_as(GraphExecutionState, state_raw),
                )
                for data_raw, state_raw in result_rows
            ]
            self._cursor.execute("""SELECT count(*) FROM results;""")
            count = self._cursor.fetchone()[0]
        page_count = self._page_count(count, per_page)
        return PaginatedResults[ResultWithSession](
            items=items, page=page, pages=page_count, per_page=per_page, total=count
        )

    def search(
        self, query: str, page: int = 0, per_page: int = 10
    ) -> PaginatedResults[ResultWithSession]:
        """Return one page of results whose serialized data contains `query`."""
        with self._lock:
            # Fix: filter on the `data` column — the original queries used a
            # nonexistent `item` column, so every search raised OperationalError.
            self._cursor.execute(
                """
                SELECT results.data, graph_executions.state
                FROM results
                JOIN graph_executions ON results.session_id = graph_executions.id
                WHERE results.data LIKE ?
                ORDER BY results.id
                LIMIT ? OFFSET ?;
                """,
                (f"%{query}%", per_page, page * per_page),
            )
            result_rows = self._cursor.fetchall()
            items = [
                ResultWithSession(
                    result=parse_raw_as(Result, data_raw),
                    session=parse_raw_as(GraphExecutionState, state_raw),
                )
                for data_raw, state_raw in result_rows
            ]
            self._cursor.execute(
                """SELECT count(*) FROM results WHERE data LIKE ?;""",
                (f"%{query}%",),
            )
            count = self._cursor.fetchone()[0]
        page_count = self._page_count(count, per_page)
        return PaginatedResults[ResultWithSession](
            items=items, page=page, pages=page_count, per_page=per_page, total=count
        )

    @staticmethod
    def _page_count(count: int, per_page: int) -> int:
        """Ceiling division; the original `int(count / per_page) + 1` reported
        an extra page whenever count was an exact multiple of per_page."""
        return -(-count // per_page)

    def handle_graph_execution_state_change(self, session: GraphExecutionState) -> None:
        """Upsert (insert-or-ignore) every image/latents result from `session`.

        Registered as the `on_changed` callback of the graph execution manager,
        keeping the `results` table in sync as sessions progress.
        """
        # Take the lock for consistency with every other method — this runs on
        # the execution callback thread against the same shared connection.
        with self._lock:
            # Connection-as-context-manager commits the transaction on exit.
            with self._conn as conn:
                for node_id, result in session.results.items():
                    # Session results may be pydantic models rather than plain
                    # dicts — TODO confirm; normalize so field access works
                    # either way.
                    result_dict = result if isinstance(result, dict) else result.dict()
                    result_type = result_dict.get("type")
                    # We'll only process 'image_output' or 'latents_output'
                    if result_type not in ("image_output", "latents_output"):
                        continue
                    # The id depends on the result type
                    if result_type == "image_output":
                        result_id = result_dict["image"]["image_name"]
                    else:  # 'latents_output'
                        result_id = result_dict["latents"]["latents_name"]
                    # Stringify the entire result object for the data column
                    data = json.dumps(result_dict)
                    # Insert the result, ignoring if it already exists
                    conn.execute(
                        """
                        INSERT OR IGNORE INTO results (id, node_id, session_id, data)
                        VALUES (?, ?, ?, ?)
                        """,
                        (result_id, node_id, session.id, data),
                    )