Compare commits

...

2 Commits

Author SHA1 Message Date
f0a9a4fb88 feat(nodes): add ResultsServiceABC & SqliteResultsService
**Doesn't actually work bc of circular imports. Can't even test it.**

- add a base class for ResultsService and SQLite implementation
- use `graph_execution_manager` `on_changed` callback to keep `results` table in sync
2023-05-17 19:16:04 +10:00
34b50e11b6 feat(nodes): change ImageOutput type to image_output 2023-05-17 19:13:53 +10:00
4 changed files with 237 additions and 4 deletions

View File

@ -18,6 +18,7 @@ from ..services.invoker import Invoker
from ..services.processor import DefaultInvocationProcessor from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage from ..services.sqlite import SqliteItemStorage
from ..services.metadata import PngMetadataService from ..services.metadata import PngMetadataService
from ..services.results import SqliteResultsService
from .events import FastAPIEventService from .events import FastAPIEventService
@ -69,6 +70,14 @@ class ApiDependencies:
# TODO: build a file/path manager? # TODO: build a file/path manager?
db_location = os.path.join(output_folder, "invokeai.db") db_location = os.path.join(output_folder, "invokeai.db")
results = SqliteResultsService(filename=db_location)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
)
graph_execution_manager.on_changed(results.handle_graph_execution_state_change)
services = InvocationServices( services = InvocationServices(
model_manager=get_model_manager(config,logger), model_manager=get_model_manager(config,logger),
events=events, events=events,
@ -76,13 +85,12 @@ class ApiDependencies:
latents=latents, latents=latents,
images=images, images=images,
metadata=metadata, metadata=metadata,
results=results,
queue=MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph]( graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs" filename=db_location, table_name="graphs"
), ),
graph_execution_manager=SqliteItemStorage[GraphExecutionState]( graph_execution_manager=graph_execution_manager,
filename=db_location, table_name="graph_executions"
),
processor=DefaultInvocationProcessor(), processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config,logger), restoration=RestorationServices(config,logger),
) )

View File

@ -31,7 +31,7 @@ class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image""" """Base class for invocations that output an image"""
# fmt: off # fmt: off
type: Literal["image"] = "image" type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image") image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels") width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels") height: int = Field(description="The height of the image in pixels")

View File

@ -10,6 +10,7 @@ from .image_storage import ImageStorageBase
from .restoration_services import RestorationServices from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC from .item_storage import ItemStorageABC
from invokeai.app.services.results import ResultsServiceABC
class InvocationServices: class InvocationServices:
"""Services that can be used by invocations""" """Services that can be used by invocations"""
@ -21,6 +22,7 @@ class InvocationServices:
queue: InvocationQueueABC queue: InvocationQueueABC
model_manager: ModelManager model_manager: ModelManager
restoration: RestorationServices restoration: RestorationServices
results: ResultsServiceABC
# NOTE: we must forward-declare any types that include invocations, since invocations can use services # NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_library: ItemStorageABC["LibraryGraph"] graph_library: ItemStorageABC["LibraryGraph"]
@ -36,6 +38,7 @@ class InvocationServices:
images: ImageStorageBase, images: ImageStorageBase,
metadata: MetadataServiceBase, metadata: MetadataServiceBase,
queue: InvocationQueueABC, queue: InvocationQueueABC,
results: ResultsServiceABC,
graph_library: ItemStorageABC["LibraryGraph"], graph_library: ItemStorageABC["LibraryGraph"],
graph_execution_manager: ItemStorageABC["GraphExecutionState"], graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC", processor: "InvocationProcessorABC",
@ -48,6 +51,7 @@ class InvocationServices:
self.images = images self.images = images
self.metadata = metadata self.metadata = metadata
self.queue = queue self.queue = queue
self.results = results
self.graph_library = graph_library self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager self.graph_execution_manager = graph_execution_manager
self.processor = processor self.processor = processor

View File

@ -0,0 +1,221 @@
from __future__ import annotations
from abc import ABC, abstractmethod
import json
import sqlite3
from threading import Lock
from typing import Union
from pydantic import BaseModel, Field, parse_raw_as
from invokeai.app.models.image import ImageField
from invokeai.app.invocations.latent import LatentsField
from invokeai.app.services.graph import GraphExecutionState
from invokeai.app.services.item_storage import PaginatedResults
class Result(BaseModel):
id: str = Field(description="Result ID")
session_id: str = Field(description="Session ID")
node_id: str = Field(description="Node ID")
data: Union[LatentsField, ImageField] = Field(description="The result data")
class ResultWithSession(BaseModel):
result: Result = Field(description="The result")
session: GraphExecutionState = Field(description="The session")
class ResultsServiceABC(ABC):
@abstractmethod
def get(self, output_id: str) -> str:
pass
@abstractmethod
def list(
self, page: int = 0, per_page: int = 10
) -> PaginatedResults[ResultWithSession]:
pass
@abstractmethod
def search(
self, query: str, page: int = 0, per_page: int = 10
) -> PaginatedResults[ResultWithSession]:
pass
@abstractmethod
def handle_graph_execution_state_change(self, session: GraphExecutionState) -> None:
pass
class SqliteResultsService(ResultsServiceABC):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: Lock
def __init__(self, filename: str):
super().__init__()
self._filename = filename
self._lock = Lock()
self._conn = sqlite3.connect(
self._filename, check_same_thread=False
) # TODO: figure out a better threading solution
self._cursor = self._conn.cursor()
self._create_table()
def _create_table(self):
try:
self._lock.acquire()
self._cursor.execute(
"""
CREATE TABLE IF NOT EXISTS results (
id TEXT PRIMARY KEY,
node_id TEXT,
session_id TEXT,
data TEXT
);
"""
)
self._cursor.execute(
"""CREATE UNIQUE INDEX IF NOT EXISTS result_id ON result(id);"""
)
finally:
self._lock.release()
def get(self, id: str) -> Union[ResultWithSession, None]:
try:
self._lock.acquire()
self._cursor.execute(
"""
SELECT results.data, graph_executions.state
FROM results
JOIN graph_executions ON results.session_id = graph_executions.id
WHERE results.id = ?
""",
(id,),
)
result_row = self._cursor.fetchone()
if result_row is None:
return None
result_raw, graph_execution_state_raw = result_row
result = parse_raw_as(Result, result_raw)
graph_execution_state = parse_raw_as(
GraphExecutionState, graph_execution_state_raw
)
finally:
self._lock.release()
if not result:
return None
return ResultWithSession(result=result, session=graph_execution_state)
def list(
self, page: int = 0, per_page: int = 10
) -> PaginatedResults[ResultWithSession]:
try:
self._lock.acquire()
self._cursor.execute(
"""
SELECT results.data, graph_executions.state
FROM results
JOIN graph_executions ON results.session_id = graph_executions.id
LIMIT ? OFFSET ?;
""",
(per_page, page * per_page),
)
result_rows = self._cursor.fetchall()
items = list(
map(
lambda r: ResultWithSession(
result=parse_raw_as(Result, r[0]),
session=parse_raw_as(GraphExecutionState, r[1]),
),
result_rows,
)
)
self._cursor.execute("""SELECT count(*) FROM results;""")
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = int(count / per_page) + 1
return PaginatedResults[ResultWithSession](
items=items, page=page, pages=pageCount, per_page=per_page, total=count
)
def search(
self, query: str, page: int = 0, per_page: int = 10
) -> PaginatedResults[ResultWithSession]:
try:
self._lock.acquire()
self._cursor.execute(
"""
SELECT results.data, graph_executions.state
FROM results
JOIN graph_executions ON results.session_id = graph_executions.id
WHERE item LIKE ?
LIMIT ? OFFSET ?;
""",
(f"%{query}%", per_page, page * per_page),
)
result_rows = self._cursor.fetchall()
items = list(
map(
lambda r: ResultWithSession(
result=parse_raw_as(Result, r[0]),
session=parse_raw_as(GraphExecutionState, r[1]),
),
result_rows,
)
)
self._cursor.execute(
f"""SELECT count(*) FROM results WHERE item LIKE ?;""",
(f"%{query}%",),
)
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = int(count / per_page) + 1
return PaginatedResults[ResultWithSession](
items=items, page=page, pages=pageCount, per_page=per_page, total=count
)
def handle_graph_execution_state_change(self, session: GraphExecutionState) -> None:
with self._conn as conn:
for node_id, result in session.results.items():
# We'll only process 'image_output' or 'latents_output'
if result["type"] not in ["image_output", "latents_output"]:
continue
# The id depends on the result type
if result["type"] == "image_output":
id = result["image"]["image_name"]
else: # 'latents_output'
id = result["latents"]["latents_name"]
# Stringify the entire result object for the data column
data = json.dumps(result)
# Insert the result into the results table, ignoring if it already exists
conn.execute(
"""
INSERT OR IGNORE INTO results (id, node_id, session_id, data)
VALUES (?, ?, ?, ?)
""",
(id, node_id, session.id, data),
)