Compare commits

...

1 Commits

Author SHA1 Message Date
b897ca18ce feat(nodes): wip outputs_service 2023-05-17 00:42:27 +10:00
5 changed files with 234 additions and 0 deletions

View File

@ -1,6 +1,8 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import os
from invokeai.app.models.image import ImageField
from invokeai.app.services.outputs_sqlite import OutputsSqliteItemStorage
import invokeai.backend.util.logging as logger
from typing import types
@ -76,6 +78,7 @@ class ApiDependencies:
latents=latents,
images=images,
metadata=metadata,
outputs=OutputsSqliteItemStorage(filename=db_location),
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"

View File

@ -106,6 +106,11 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
context.services.images.save(
image_type, image_name, generate_output.image, metadata
)
context.services.outputs.set(image_name, context.graph_execution_state_id)
s = context.services.outputs.get(image_name)
print(s)
return build_image_output(
image_type=image_type,
image_name=image_name,

View File

@ -1,6 +1,8 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from typing import types
from invokeai.app.services.outputs_session_storage import OutputsSessionStorageABC
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.backend import ModelManager
@ -17,6 +19,7 @@ class InvocationServices:
events: EventServiceBase
latents: LatentsStorageBase
images: ImageStorageBase
outputs: OutputsSessionStorageABC
metadata: MetadataServiceBase
queue: InvocationQueueABC
model_manager: ModelManager
@ -34,6 +37,7 @@ class InvocationServices:
logger: types.ModuleType,
latents: LatentsStorageBase,
images: ImageStorageBase,
outputs: OutputsSessionStorageABC,
metadata: MetadataServiceBase,
queue: InvocationQueueABC,
graph_library: ItemStorageABC["LibraryGraph"],
@ -46,6 +50,7 @@ class InvocationServices:
self.logger = logger
self.latents = latents
self.images = images
self.outputs = outputs
self.metadata = metadata
self.queue = queue
self.graph_library = graph_library

View File

@ -0,0 +1,59 @@
from abc import ABC, abstractmethod
from typing import Callable, Generic, TypeVar
from pydantic import BaseModel, Field
from pydantic.generics import GenericModel
class PaginatedStringResults(GenericModel):
    """One page of string results (session IDs) plus pagination metadata."""
    # NOTE(review): subclasses pydantic's GenericModel but declares no type
    # parameters — presumably adapted from a generic PaginatedResults model;
    # a plain BaseModel would likely suffice. Confirm before changing.
    #fmt: off
    items: list[str] = Field(description="Session IDs")
    page: int = Field(description="Current Page")
    pages: int = Field(description="Total number of pages")
    per_page: int = Field(description="Number of items per page")
    total: int = Field(description="Total number of items in result")
    #fmt: on
class OutputsSessionStorageABC(ABC):
    """Abstract store mapping an output (image) id to the session that produced it.

    Implementations persist (output_id, session_id) pairs and notify
    registered callbacks when an entry is changed or deleted.

    Note: the original placed the string "Base item storage class" as a dead
    no-op statement at the end of __init__; it has been replaced by this
    class-level docstring.
    """

    # Callbacks invoked with the affected output id on change / delete.
    _on_changed_callbacks: list[Callable[[str], None]]
    _on_deleted_callbacks: list[Callable[[str], None]]

    def __init__(self) -> None:
        self._on_changed_callbacks = list()
        self._on_deleted_callbacks = list()

    @abstractmethod
    def get(self, output_id: str) -> str:
        """Return the session associated with `output_id`.

        NOTE(review): the SQLite implementation returns None when no mapping
        exists — confirm and consider widening this annotation to Optional[str].
        """
        pass

    @abstractmethod
    def set(self, output_id: str, session_id: str) -> None:
        """Associate `output_id` with `session_id`."""
        pass

    @abstractmethod
    def list(self, page: int = 0, per_page: int = 10) -> PaginatedStringResults:
        """Return one page of stored sessions."""
        pass

    @abstractmethod
    def search(
        self, query: str, page: int = 0, per_page: int = 10
    ) -> PaginatedStringResults:
        """Return one page of sessions whose output id matches `query`."""
        pass

    def on_changed(self, on_changed: Callable[[str], None]) -> None:
        """Register a callback for when an item is changed"""
        self._on_changed_callbacks.append(on_changed)

    def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
        """Register a callback for when an item is deleted"""
        self._on_deleted_callbacks.append(on_deleted)

    def _on_changed(self, foreign_key_value: str) -> None:
        # Fan the change notification out to every registered listener.
        for callback in self._on_changed_callbacks:
            callback(foreign_key_value)

    def _on_deleted(self, item_id: str) -> None:
        # Fan the delete notification out to every registered listener.
        for callback in self._on_deleted_callbacks:
            callback(item_id)

View File

@ -0,0 +1,162 @@
import json
import sqlite3
from threading import Lock
from typing import Union
from invokeai.app.services.outputs_session_storage import (
OutputsSessionStorageABC,
PaginatedStringResults,
)
sqlite_memory = ":memory:"  # sentinel filename for an in-memory SQLite database
class OutputsSqliteItemStorage(OutputsSessionStorageABC):
    """SQLite-backed mapping from an output (image) id to its producing session.

    Rows live in the `outputs` table (id, session_id). Read queries join
    against a `graph_executions` table that this class never creates —
    presumably it is created by the graph-execution SqliteItemStorage in the
    same database file. TODO(review): confirm that shared-schema assumption.
    """

    _filename: str
    _conn: sqlite3.Connection
    _cursor: sqlite3.Cursor
    _lock: Lock

    def __init__(
        self,
        filename: str,
    ):
        super().__init__()
        self._filename = filename
        self._lock = Lock()
        self._conn = sqlite3.connect(
            self._filename, check_same_thread=False
        )  # TODO: figure out a better threading solution
        self._cursor = self._conn.cursor()
        self._create_table()

    def _create_table(self):
        """Create the `outputs` table and its unique id index if missing."""
        with self._lock:
            self._cursor.execute(
                """CREATE TABLE IF NOT EXISTS outputs (
id TEXT NOT NULL PRIMARY KEY,
session_id TEXT NOT NULL
);"""
            )
            self._cursor.execute(
                """CREATE UNIQUE INDEX IF NOT EXISTS outputs_id ON outputs(id);"""
            )

    def set(self, output_id: str, session_id: str):
        """Associate `output_id` with `session_id`, replacing any prior row."""
        with self._lock:
            self._cursor.execute(
                """INSERT OR REPLACE INTO outputs (id, session_id) VALUES (?, ?);""",
                (output_id, session_id),
            )
            self._conn.commit()
        self._on_changed(output_id)

    def get(self, output_id: str) -> Union[str, None]:
        """Return the serialized session for `output_id`, or None if absent."""
        with self._lock:
            self._cursor.execute(
                """
                SELECT graph_executions.item session
                FROM graph_executions
                INNER JOIN outputs ON outputs.session_id = graph_executions.id
                WHERE outputs.id = ?;
                """,
                (output_id,),
            )
            result = self._cursor.fetchone()
        if not result:
            return None
        return result[0]

    def delete(self, output_id: str):
        """Delete the mapping for `output_id` (no-op if it does not exist)."""
        with self._lock:
            # Bug fix: the original bound (str(id),) — the id BUILTIN — instead
            # of the output_id argument, so the intended row was never deleted.
            self._cursor.execute(
                """DELETE FROM outputs WHERE id = ?;""", (output_id,)
            )
            self._conn.commit()
        self._on_deleted(output_id)

    def list(self, page: int = 0, per_page: int = 10) -> PaginatedStringResults:
        """Return one page of sessions that have registered outputs."""
        with self._lock:
            self._cursor.execute(
                """
                SELECT graph_executions.item session
                FROM graph_executions
                INNER JOIN outputs ON outputs.session_id = graph_executions.id
                LIMIT ? OFFSET ?;
                """,
                (per_page, page * per_page),
            )
            items = [row[0] for row in self._cursor.fetchall()]

            self._cursor.execute(
                """
                SELECT count(*)
                FROM graph_executions
                INNER JOIN outputs ON outputs.session_id = graph_executions.id;
                """
            )
            count = self._cursor.fetchone()[0]

        page_count = self._page_count(count, per_page)
        return PaginatedStringResults(
            items=items, page=page, pages=page_count, per_page=per_page, total=count
        )

    def search(
        self, query: str, page: int = 0, per_page: int = 10
    ) -> PaginatedStringResults:
        """Return one page of sessions whose output id contains `query`."""
        like = f"%{query}%"
        with self._lock:
            self._cursor.execute(
                """
                SELECT graph_executions.item session
                FROM graph_executions
                INNER JOIN outputs ON outputs.session_id = graph_executions.id
                WHERE outputs.id LIKE ? LIMIT ? OFFSET ?;
                """,
                (like, per_page, page * per_page),
            )
            items = [row[0] for row in self._cursor.fetchall()]

            self._cursor.execute(
                """
                SELECT count(*)
                FROM graph_executions
                INNER JOIN outputs ON outputs.session_id = graph_executions.id
                WHERE outputs.id LIKE ?;
                """,
                (like,),
            )
            count = self._cursor.fetchone()[0]

        page_count = self._page_count(count, per_page)
        return PaginatedStringResults(
            items=items, page=page, pages=page_count, per_page=per_page, total=count
        )

    @staticmethod
    def _page_count(count: int, per_page: int) -> int:
        # Ceiling division. Bug fix: the original int(count / per_page) + 1
        # reported an extra page whenever count was an exact multiple of
        # per_page (e.g. 20 items / 10 per page -> 3 pages instead of 2).
        return (count + per_page - 1) // per_page