mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
2 Commits
fix/diffus
...
feat/nodes
Author | SHA1 | Date | |
---|---|---|---|
f0a9a4fb88 | |||
34b50e11b6 |
@ -18,6 +18,7 @@ from ..services.invoker import Invoker
|
||||
from ..services.processor import DefaultInvocationProcessor
|
||||
from ..services.sqlite import SqliteItemStorage
|
||||
from ..services.metadata import PngMetadataService
|
||||
from ..services.results import SqliteResultsService
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
@ -69,6 +70,14 @@ class ApiDependencies:
|
||||
# TODO: build a file/path manager?
|
||||
db_location = os.path.join(output_folder, "invokeai.db")
|
||||
|
||||
results = SqliteResultsService(filename=db_location)
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
)
|
||||
|
||||
graph_execution_manager.on_changed(results.handle_graph_execution_state_change)
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=get_model_manager(config,logger),
|
||||
events=events,
|
||||
@ -76,13 +85,12 @@ class ApiDependencies:
|
||||
latents=latents,
|
||||
images=images,
|
||||
metadata=metadata,
|
||||
results=results,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
restoration=RestorationServices(config,logger),
|
||||
)
|
||||
|
@ -31,7 +31,7 @@ class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image"] = "image"
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
|
@ -10,6 +10,7 @@ from .image_storage import ImageStorageBase
|
||||
from .restoration_services import RestorationServices
|
||||
from .invocation_queue import InvocationQueueABC
|
||||
from .item_storage import ItemStorageABC
|
||||
from invokeai.app.services.results import ResultsServiceABC
|
||||
|
||||
class InvocationServices:
|
||||
"""Services that can be used by invocations"""
|
||||
@ -21,6 +22,7 @@ class InvocationServices:
|
||||
queue: InvocationQueueABC
|
||||
model_manager: ModelManager
|
||||
restoration: RestorationServices
|
||||
results: ResultsServiceABC
|
||||
|
||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||
graph_library: ItemStorageABC["LibraryGraph"]
|
||||
@ -36,6 +38,7 @@ class InvocationServices:
|
||||
images: ImageStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
queue: InvocationQueueABC,
|
||||
results: ResultsServiceABC,
|
||||
graph_library: ItemStorageABC["LibraryGraph"],
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
processor: "InvocationProcessorABC",
|
||||
@ -48,6 +51,7 @@ class InvocationServices:
|
||||
self.images = images
|
||||
self.metadata = metadata
|
||||
self.queue = queue
|
||||
self.results = results
|
||||
self.graph_library = graph_library
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.processor = processor
|
||||
|
221
invokeai/app/services/results.py
Normal file
221
invokeai/app/services/results.py
Normal file
@ -0,0 +1,221 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
import json
|
||||
import sqlite3
|
||||
from threading import Lock
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel, Field, parse_raw_as
|
||||
|
||||
from invokeai.app.models.image import ImageField
|
||||
from invokeai.app.invocations.latent import LatentsField
|
||||
from invokeai.app.services.graph import GraphExecutionState
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
|
||||
|
||||
class Result(BaseModel):
|
||||
id: str = Field(description="Result ID")
|
||||
session_id: str = Field(description="Session ID")
|
||||
node_id: str = Field(description="Node ID")
|
||||
data: Union[LatentsField, ImageField] = Field(description="The result data")
|
||||
|
||||
|
||||
class ResultWithSession(BaseModel):
|
||||
result: Result = Field(description="The result")
|
||||
session: GraphExecutionState = Field(description="The session")
|
||||
|
||||
|
||||
class ResultsServiceABC(ABC):
|
||||
@abstractmethod
|
||||
def get(self, output_id: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list(
|
||||
self, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[ResultWithSession]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self, query: str, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[ResultWithSession]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def handle_graph_execution_state_change(self, session: GraphExecutionState) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class SqliteResultsService(ResultsServiceABC):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: Lock
|
||||
|
||||
def __init__(self, filename: str):
|
||||
super().__init__()
|
||||
|
||||
self._filename = filename
|
||||
self._lock = Lock()
|
||||
|
||||
self._conn = sqlite3.connect(
|
||||
self._filename, check_same_thread=False
|
||||
) # TODO: figure out a better threading solution
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
self._create_table()
|
||||
|
||||
def _create_table(self):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS results (
|
||||
id TEXT PRIMARY KEY,
|
||||
node_id TEXT,
|
||||
session_id TEXT,
|
||||
data TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""CREATE UNIQUE INDEX IF NOT EXISTS result_id ON result(id);"""
|
||||
)
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get(self, id: str) -> Union[ResultWithSession, None]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""
|
||||
SELECT results.data, graph_executions.state
|
||||
FROM results
|
||||
JOIN graph_executions ON results.session_id = graph_executions.id
|
||||
WHERE results.id = ?
|
||||
""",
|
||||
(id,),
|
||||
)
|
||||
|
||||
result_row = self._cursor.fetchone()
|
||||
|
||||
if result_row is None:
|
||||
return None
|
||||
|
||||
result_raw, graph_execution_state_raw = result_row
|
||||
result = parse_raw_as(Result, result_raw)
|
||||
graph_execution_state = parse_raw_as(
|
||||
GraphExecutionState, graph_execution_state_raw
|
||||
)
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if not result:
|
||||
return None
|
||||
|
||||
return ResultWithSession(result=result, session=graph_execution_state)
|
||||
|
||||
def list(
|
||||
self, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[ResultWithSession]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""
|
||||
SELECT results.data, graph_executions.state
|
||||
FROM results
|
||||
JOIN graph_executions ON results.session_id = graph_executions.id
|
||||
LIMIT ? OFFSET ?;
|
||||
""",
|
||||
(per_page, page * per_page),
|
||||
)
|
||||
|
||||
result_rows = self._cursor.fetchall()
|
||||
|
||||
items = list(
|
||||
map(
|
||||
lambda r: ResultWithSession(
|
||||
result=parse_raw_as(Result, r[0]),
|
||||
session=parse_raw_as(GraphExecutionState, r[1]),
|
||||
),
|
||||
result_rows,
|
||||
)
|
||||
)
|
||||
|
||||
self._cursor.execute("""SELECT count(*) FROM results;""")
|
||||
count = self._cursor.fetchone()[0]
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedResults[ResultWithSession](
|
||||
items=items, page=page, pages=pageCount, per_page=per_page, total=count
|
||||
)
|
||||
|
||||
def search(
|
||||
self, query: str, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[ResultWithSession]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""
|
||||
SELECT results.data, graph_executions.state
|
||||
FROM results
|
||||
JOIN graph_executions ON results.session_id = graph_executions.id
|
||||
WHERE item LIKE ?
|
||||
LIMIT ? OFFSET ?;
|
||||
""",
|
||||
(f"%{query}%", per_page, page * per_page),
|
||||
)
|
||||
|
||||
result_rows = self._cursor.fetchall()
|
||||
|
||||
items = list(
|
||||
map(
|
||||
lambda r: ResultWithSession(
|
||||
result=parse_raw_as(Result, r[0]),
|
||||
session=parse_raw_as(GraphExecutionState, r[1]),
|
||||
),
|
||||
result_rows,
|
||||
)
|
||||
)
|
||||
self._cursor.execute(
|
||||
f"""SELECT count(*) FROM results WHERE item LIKE ?;""",
|
||||
(f"%{query}%",),
|
||||
)
|
||||
count = self._cursor.fetchone()[0]
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedResults[ResultWithSession](
|
||||
items=items, page=page, pages=pageCount, per_page=per_page, total=count
|
||||
)
|
||||
|
||||
def handle_graph_execution_state_change(self, session: GraphExecutionState) -> None:
|
||||
with self._conn as conn:
|
||||
for node_id, result in session.results.items():
|
||||
# We'll only process 'image_output' or 'latents_output'
|
||||
if result["type"] not in ["image_output", "latents_output"]:
|
||||
continue
|
||||
|
||||
# The id depends on the result type
|
||||
if result["type"] == "image_output":
|
||||
id = result["image"]["image_name"]
|
||||
else: # 'latents_output'
|
||||
id = result["latents"]["latents_name"]
|
||||
|
||||
# Stringify the entire result object for the data column
|
||||
data = json.dumps(result)
|
||||
|
||||
# Insert the result into the results table, ignoring if it already exists
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO results (id, node_id, session_id, data)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(id, node_id, session.id, data),
|
||||
)
|
Reference in New Issue
Block a user