mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
1 Commits
release/ad
...
feat/laten
Author | SHA1 | Date | |
---|---|---|---|
b897ca18ce |
@ -1,6 +1,8 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import os
|
||||
from invokeai.app.models.image import ImageField
|
||||
from invokeai.app.services.outputs_sqlite import OutputsSqliteItemStorage
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from typing import types
|
||||
@ -76,6 +78,7 @@ class ApiDependencies:
|
||||
latents=latents,
|
||||
images=images,
|
||||
metadata=metadata,
|
||||
outputs=OutputsSqliteItemStorage(filename=db_location),
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
|
@ -106,6 +106,11 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
context.services.images.save(
|
||||
image_type, image_name, generate_output.image, metadata
|
||||
)
|
||||
|
||||
context.services.outputs.set(image_name, context.graph_execution_state_id)
|
||||
|
||||
s = context.services.outputs.get(image_name)
|
||||
print(s)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
|
@ -1,6 +1,8 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
|
||||
from typing import types
|
||||
|
||||
from invokeai.app.services.outputs_session_storage import OutputsSessionStorageABC
|
||||
from invokeai.app.services.metadata import MetadataServiceBase
|
||||
from invokeai.backend import ModelManager
|
||||
|
||||
@ -17,6 +19,7 @@ class InvocationServices:
|
||||
events: EventServiceBase
|
||||
latents: LatentsStorageBase
|
||||
images: ImageStorageBase
|
||||
outputs: OutputsSessionStorageABC
|
||||
metadata: MetadataServiceBase
|
||||
queue: InvocationQueueABC
|
||||
model_manager: ModelManager
|
||||
@ -34,6 +37,7 @@ class InvocationServices:
|
||||
logger: types.ModuleType,
|
||||
latents: LatentsStorageBase,
|
||||
images: ImageStorageBase,
|
||||
outputs: OutputsSessionStorageABC,
|
||||
metadata: MetadataServiceBase,
|
||||
queue: InvocationQueueABC,
|
||||
graph_library: ItemStorageABC["LibraryGraph"],
|
||||
@ -46,6 +50,7 @@ class InvocationServices:
|
||||
self.logger = logger
|
||||
self.latents = latents
|
||||
self.images = images
|
||||
self.outputs = outputs
|
||||
self.metadata = metadata
|
||||
self.queue = queue
|
||||
self.graph_library = graph_library
|
||||
|
59
invokeai/app/services/outputs_session_storage.py
Normal file
59
invokeai/app/services/outputs_session_storage.py
Normal file
@ -0,0 +1,59 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.generics import GenericModel
|
||||
|
||||
class PaginatedStringResults(GenericModel):
|
||||
"""Paginated results"""
|
||||
#fmt: off
|
||||
items: list[str] = Field(description="Session IDs")
|
||||
page: int = Field(description="Current Page")
|
||||
pages: int = Field(description="Total number of pages")
|
||||
per_page: int = Field(description="Number of items per page")
|
||||
total: int = Field(description="Total number of items in result")
|
||||
#fmt: on
|
||||
|
||||
class OutputsSessionStorageABC(ABC):
|
||||
_on_changed_callbacks: list[Callable[[str], None]]
|
||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._on_changed_callbacks = list()
|
||||
self._on_deleted_callbacks = list()
|
||||
|
||||
"""Base item storage class"""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, output_id: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, output_id: str, session_id: str) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedStringResults:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self, query: str, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedStringResults:
|
||||
pass
|
||||
|
||||
def on_changed(self, on_changed: Callable[[str], None]) -> None:
|
||||
"""Register a callback for when an item is changed"""
|
||||
self._on_changed_callbacks.append(on_changed)
|
||||
|
||||
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||
"""Register a callback for when an item is deleted"""
|
||||
self._on_deleted_callbacks.append(on_deleted)
|
||||
|
||||
def _on_changed(self, foreign_key_value: str) -> None:
|
||||
for callback in self._on_changed_callbacks:
|
||||
callback(foreign_key_value)
|
||||
|
||||
def _on_deleted(self, item_id: str) -> None:
|
||||
for callback in self._on_deleted_callbacks:
|
||||
callback(item_id)
|
162
invokeai/app/services/outputs_sqlite.py
Normal file
162
invokeai/app/services/outputs_sqlite.py
Normal file
@ -0,0 +1,162 @@
|
||||
import json
|
||||
import sqlite3
|
||||
from threading import Lock
|
||||
from typing import Union
|
||||
|
||||
from invokeai.app.services.outputs_session_storage import (
|
||||
OutputsSessionStorageABC,
|
||||
PaginatedStringResults,
|
||||
)
|
||||
|
||||
sqlite_memory = ":memory:"
|
||||
|
||||
|
||||
class OutputsSqliteItemStorage(OutputsSessionStorageABC):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: Lock
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filename: str,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self._filename = filename
|
||||
self._lock = Lock()
|
||||
|
||||
self._conn = sqlite3.connect(
|
||||
self._filename, check_same_thread=False
|
||||
) # TODO: figure out a better threading solution
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
self._create_table()
|
||||
|
||||
def _create_table(self):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""CREATE TABLE IF NOT EXISTS outputs (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
session_id TEXT NOT NULL
|
||||
);"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
f"""CREATE UNIQUE INDEX IF NOT EXISTS outputs_id ON outputs(id);"""
|
||||
)
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def set(self, output_id: str, session_id: str):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""INSERT OR REPLACE INTO outputs (id, session_id) VALUES (?, ?);""",
|
||||
(output_id, session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._on_changed(output_id)
|
||||
|
||||
def get(self, output_id: str) -> Union[str, None]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""
|
||||
SELECT graph_executions.item session
|
||||
FROM graph_executions
|
||||
INNER JOIN outputs ON outputs.session_id = graph_executions.id
|
||||
WHERE outputs.id = ?;
|
||||
""",
|
||||
(output_id,),
|
||||
)
|
||||
result = self._cursor.fetchone()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if not result:
|
||||
return None
|
||||
|
||||
return result[0]
|
||||
|
||||
def delete(self, output_id: str):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""DELETE FROM outputs WHERE id = ?;""", (str(id),)
|
||||
)
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._on_deleted(output_id)
|
||||
|
||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedStringResults:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""
|
||||
SELECT graph_executions.item session
|
||||
FROM graph_executions
|
||||
INNER JOIN outputs ON outputs.session_id = graph_executions.id
|
||||
LIMIT ? OFFSET ?;
|
||||
""",
|
||||
(per_page, page * per_page),
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
items = list(map(lambda r: r[0], result))
|
||||
|
||||
self._cursor.execute(
|
||||
f"""
|
||||
SELECT count(*)
|
||||
FROM graph_executions
|
||||
INNER JOIN outputs ON outputs.session_id = graph_executions.id;
|
||||
""")
|
||||
count = self._cursor.fetchone()[0]
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedStringResults(
|
||||
items=items, page=page, pages=pageCount, per_page=per_page, total=count
|
||||
)
|
||||
|
||||
def search(
|
||||
self, query: str, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedStringResults:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""
|
||||
SELECT graph_executions.item session
|
||||
FROM graph_executions
|
||||
INNER JOIN outputs ON outputs.session_id = graph_executions.id
|
||||
WHERE outputs.id LIKE ? LIMIT ? OFFSET ?;
|
||||
""",
|
||||
(f"%{query}%", per_page, page * per_page),
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
items = list(map(lambda r: r[0], result))
|
||||
|
||||
self._cursor.execute(
|
||||
f"""
|
||||
SELECT count(*)
|
||||
FROM graph_executions
|
||||
INNER JOIN outputs ON outputs.session_id = graph_executions.id
|
||||
WHERE outputs.id LIKE ?;
|
||||
""",
|
||||
(f"%{query}%",),
|
||||
)
|
||||
count = self._cursor.fetchone()[0]
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedStringResults(
|
||||
items=items, page=page, pages=pageCount, per_page=per_page, total=count
|
||||
)
|
Reference in New Issue
Block a user