mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
2 Commits
fix/diffus
...
feat/nodes
Author | SHA1 | Date | |
---|---|---|---|
f0a9a4fb88 | |||
34b50e11b6 |
@ -18,6 +18,7 @@ from ..services.invoker import Invoker
|
|||||||
from ..services.processor import DefaultInvocationProcessor
|
from ..services.processor import DefaultInvocationProcessor
|
||||||
from ..services.sqlite import SqliteItemStorage
|
from ..services.sqlite import SqliteItemStorage
|
||||||
from ..services.metadata import PngMetadataService
|
from ..services.metadata import PngMetadataService
|
||||||
|
from ..services.results import SqliteResultsService
|
||||||
from .events import FastAPIEventService
|
from .events import FastAPIEventService
|
||||||
|
|
||||||
|
|
||||||
@ -69,6 +70,14 @@ class ApiDependencies:
|
|||||||
# TODO: build a file/path manager?
|
# TODO: build a file/path manager?
|
||||||
db_location = os.path.join(output_folder, "invokeai.db")
|
db_location = os.path.join(output_folder, "invokeai.db")
|
||||||
|
|
||||||
|
results = SqliteResultsService(filename=db_location)
|
||||||
|
|
||||||
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||||
|
filename=db_location, table_name="graph_executions"
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_execution_manager.on_changed(results.handle_graph_execution_state_change)
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
model_manager=get_model_manager(config,logger),
|
model_manager=get_model_manager(config,logger),
|
||||||
events=events,
|
events=events,
|
||||||
@ -76,13 +85,12 @@ class ApiDependencies:
|
|||||||
latents=latents,
|
latents=latents,
|
||||||
images=images,
|
images=images,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
|
results=results,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](
|
graph_library=SqliteItemStorage[LibraryGraph](
|
||||||
filename=db_location, table_name="graphs"
|
filename=db_location, table_name="graphs"
|
||||||
),
|
),
|
||||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
graph_execution_manager=graph_execution_manager,
|
||||||
filename=db_location, table_name="graph_executions"
|
|
||||||
),
|
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
restoration=RestorationServices(config,logger),
|
restoration=RestorationServices(config,logger),
|
||||||
)
|
)
|
||||||
|
@ -31,7 +31,7 @@ class ImageOutput(BaseInvocationOutput):
|
|||||||
"""Base class for invocations that output an image"""
|
"""Base class for invocations that output an image"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["image"] = "image"
|
type: Literal["image_output"] = "image_output"
|
||||||
image: ImageField = Field(default=None, description="The output image")
|
image: ImageField = Field(default=None, description="The output image")
|
||||||
width: int = Field(description="The width of the image in pixels")
|
width: int = Field(description="The width of the image in pixels")
|
||||||
height: int = Field(description="The height of the image in pixels")
|
height: int = Field(description="The height of the image in pixels")
|
||||||
|
@ -10,6 +10,7 @@ from .image_storage import ImageStorageBase
|
|||||||
from .restoration_services import RestorationServices
|
from .restoration_services import RestorationServices
|
||||||
from .invocation_queue import InvocationQueueABC
|
from .invocation_queue import InvocationQueueABC
|
||||||
from .item_storage import ItemStorageABC
|
from .item_storage import ItemStorageABC
|
||||||
|
from invokeai.app.services.results import ResultsServiceABC
|
||||||
|
|
||||||
class InvocationServices:
|
class InvocationServices:
|
||||||
"""Services that can be used by invocations"""
|
"""Services that can be used by invocations"""
|
||||||
@ -21,6 +22,7 @@ class InvocationServices:
|
|||||||
queue: InvocationQueueABC
|
queue: InvocationQueueABC
|
||||||
model_manager: ModelManager
|
model_manager: ModelManager
|
||||||
restoration: RestorationServices
|
restoration: RestorationServices
|
||||||
|
results: ResultsServiceABC
|
||||||
|
|
||||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||||
graph_library: ItemStorageABC["LibraryGraph"]
|
graph_library: ItemStorageABC["LibraryGraph"]
|
||||||
@ -36,6 +38,7 @@ class InvocationServices:
|
|||||||
images: ImageStorageBase,
|
images: ImageStorageBase,
|
||||||
metadata: MetadataServiceBase,
|
metadata: MetadataServiceBase,
|
||||||
queue: InvocationQueueABC,
|
queue: InvocationQueueABC,
|
||||||
|
results: ResultsServiceABC,
|
||||||
graph_library: ItemStorageABC["LibraryGraph"],
|
graph_library: ItemStorageABC["LibraryGraph"],
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||||
processor: "InvocationProcessorABC",
|
processor: "InvocationProcessorABC",
|
||||||
@ -48,6 +51,7 @@ class InvocationServices:
|
|||||||
self.images = images
|
self.images = images
|
||||||
self.metadata = metadata
|
self.metadata = metadata
|
||||||
self.queue = queue
|
self.queue = queue
|
||||||
|
self.results = results
|
||||||
self.graph_library = graph_library
|
self.graph_library = graph_library
|
||||||
self.graph_execution_manager = graph_execution_manager
|
self.graph_execution_manager = graph_execution_manager
|
||||||
self.processor = processor
|
self.processor = processor
|
||||||
|
221
invokeai/app/services/results.py
Normal file
221
invokeai/app/services/results.py
Normal file
@ -0,0 +1,221 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import json
|
||||||
|
import sqlite3
|
||||||
|
from threading import Lock
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, parse_raw_as
|
||||||
|
|
||||||
|
from invokeai.app.models.image import ImageField
|
||||||
|
from invokeai.app.invocations.latent import LatentsField
|
||||||
|
from invokeai.app.services.graph import GraphExecutionState
|
||||||
|
from invokeai.app.services.item_storage import PaginatedResults
|
||||||
|
|
||||||
|
|
||||||
|
class Result(BaseModel):
|
||||||
|
id: str = Field(description="Result ID")
|
||||||
|
session_id: str = Field(description="Session ID")
|
||||||
|
node_id: str = Field(description="Node ID")
|
||||||
|
data: Union[LatentsField, ImageField] = Field(description="The result data")
|
||||||
|
|
||||||
|
|
||||||
|
class ResultWithSession(BaseModel):
|
||||||
|
result: Result = Field(description="The result")
|
||||||
|
session: GraphExecutionState = Field(description="The session")
|
||||||
|
|
||||||
|
|
||||||
|
class ResultsServiceABC(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def get(self, output_id: str) -> str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list(
|
||||||
|
self, page: int = 0, per_page: int = 10
|
||||||
|
) -> PaginatedResults[ResultWithSession]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def search(
|
||||||
|
self, query: str, page: int = 0, per_page: int = 10
|
||||||
|
) -> PaginatedResults[ResultWithSession]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def handle_graph_execution_state_change(self, session: GraphExecutionState) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteResultsService(ResultsServiceABC):
|
||||||
|
_filename: str
|
||||||
|
_conn: sqlite3.Connection
|
||||||
|
_cursor: sqlite3.Cursor
|
||||||
|
_lock: Lock
|
||||||
|
|
||||||
|
def __init__(self, filename: str):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self._filename = filename
|
||||||
|
self._lock = Lock()
|
||||||
|
|
||||||
|
self._conn = sqlite3.connect(
|
||||||
|
self._filename, check_same_thread=False
|
||||||
|
) # TODO: figure out a better threading solution
|
||||||
|
self._cursor = self._conn.cursor()
|
||||||
|
|
||||||
|
self._create_table()
|
||||||
|
|
||||||
|
def _create_table(self):
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS results (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
node_id TEXT,
|
||||||
|
session_id TEXT,
|
||||||
|
data TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
self._cursor.execute(
|
||||||
|
"""CREATE UNIQUE INDEX IF NOT EXISTS result_id ON result(id);"""
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def get(self, id: str) -> Union[ResultWithSession, None]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""
|
||||||
|
SELECT results.data, graph_executions.state
|
||||||
|
FROM results
|
||||||
|
JOIN graph_executions ON results.session_id = graph_executions.id
|
||||||
|
WHERE results.id = ?
|
||||||
|
""",
|
||||||
|
(id,),
|
||||||
|
)
|
||||||
|
|
||||||
|
result_row = self._cursor.fetchone()
|
||||||
|
|
||||||
|
if result_row is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result_raw, graph_execution_state_raw = result_row
|
||||||
|
result = parse_raw_as(Result, result_raw)
|
||||||
|
graph_execution_state = parse_raw_as(
|
||||||
|
GraphExecutionState, graph_execution_state_raw
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return ResultWithSession(result=result, session=graph_execution_state)
|
||||||
|
|
||||||
|
def list(
|
||||||
|
self, page: int = 0, per_page: int = 10
|
||||||
|
) -> PaginatedResults[ResultWithSession]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""
|
||||||
|
SELECT results.data, graph_executions.state
|
||||||
|
FROM results
|
||||||
|
JOIN graph_executions ON results.session_id = graph_executions.id
|
||||||
|
LIMIT ? OFFSET ?;
|
||||||
|
""",
|
||||||
|
(per_page, page * per_page),
|
||||||
|
)
|
||||||
|
|
||||||
|
result_rows = self._cursor.fetchall()
|
||||||
|
|
||||||
|
items = list(
|
||||||
|
map(
|
||||||
|
lambda r: ResultWithSession(
|
||||||
|
result=parse_raw_as(Result, r[0]),
|
||||||
|
session=parse_raw_as(GraphExecutionState, r[1]),
|
||||||
|
),
|
||||||
|
result_rows,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self._cursor.execute("""SELECT count(*) FROM results;""")
|
||||||
|
count = self._cursor.fetchone()[0]
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
pageCount = int(count / per_page) + 1
|
||||||
|
|
||||||
|
return PaginatedResults[ResultWithSession](
|
||||||
|
items=items, page=page, pages=pageCount, per_page=per_page, total=count
|
||||||
|
)
|
||||||
|
|
||||||
|
def search(
|
||||||
|
self, query: str, page: int = 0, per_page: int = 10
|
||||||
|
) -> PaginatedResults[ResultWithSession]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""
|
||||||
|
SELECT results.data, graph_executions.state
|
||||||
|
FROM results
|
||||||
|
JOIN graph_executions ON results.session_id = graph_executions.id
|
||||||
|
WHERE item LIKE ?
|
||||||
|
LIMIT ? OFFSET ?;
|
||||||
|
""",
|
||||||
|
(f"%{query}%", per_page, page * per_page),
|
||||||
|
)
|
||||||
|
|
||||||
|
result_rows = self._cursor.fetchall()
|
||||||
|
|
||||||
|
items = list(
|
||||||
|
map(
|
||||||
|
lambda r: ResultWithSession(
|
||||||
|
result=parse_raw_as(Result, r[0]),
|
||||||
|
session=parse_raw_as(GraphExecutionState, r[1]),
|
||||||
|
),
|
||||||
|
result_rows,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._cursor.execute(
|
||||||
|
f"""SELECT count(*) FROM results WHERE item LIKE ?;""",
|
||||||
|
(f"%{query}%",),
|
||||||
|
)
|
||||||
|
count = self._cursor.fetchone()[0]
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
pageCount = int(count / per_page) + 1
|
||||||
|
|
||||||
|
return PaginatedResults[ResultWithSession](
|
||||||
|
items=items, page=page, pages=pageCount, per_page=per_page, total=count
|
||||||
|
)
|
||||||
|
|
||||||
|
def handle_graph_execution_state_change(self, session: GraphExecutionState) -> None:
|
||||||
|
with self._conn as conn:
|
||||||
|
for node_id, result in session.results.items():
|
||||||
|
# We'll only process 'image_output' or 'latents_output'
|
||||||
|
if result["type"] not in ["image_output", "latents_output"]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# The id depends on the result type
|
||||||
|
if result["type"] == "image_output":
|
||||||
|
id = result["image"]["image_name"]
|
||||||
|
else: # 'latents_output'
|
||||||
|
id = result["latents"]["latents_name"]
|
||||||
|
|
||||||
|
# Stringify the entire result object for the data column
|
||||||
|
data = json.dumps(result)
|
||||||
|
|
||||||
|
# Insert the result into the results table, ignoring if it already exists
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT OR IGNORE INTO results (id, node_id, session_id, data)
|
||||||
|
VALUES (?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(id, node_id, session.id, data),
|
||||||
|
)
|
Reference in New Issue
Block a user