mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
More testing
This commit is contained in:
parent
e26e4740b3
commit
e751f7d815
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
|
import threading
|
||||||
import os
|
import os
|
||||||
from invokeai.app.services.board_image_record_storage import (
|
from invokeai.app.services.board_image_record_storage import (
|
||||||
SqliteBoardImageRecordStorage,
|
SqliteBoardImageRecordStorage,
|
||||||
@ -65,6 +66,8 @@ class ApiDependencies:
|
|||||||
logger.info(f"Root directory = {str(config.root_path)}")
|
logger.info(f"Root directory = {str(config.root_path)}")
|
||||||
logger.debug(f"Internet connectivity is {config.internet_available}")
|
logger.debug(f"Internet connectivity is {config.internet_available}")
|
||||||
|
|
||||||
|
lock = threading.Lock()
|
||||||
|
|
||||||
events = FastAPIEventService(event_handler_id)
|
events = FastAPIEventService(event_handler_id)
|
||||||
|
|
||||||
output_folder = config.output_path
|
output_folder = config.output_path
|
||||||
@ -74,17 +77,17 @@ class ApiDependencies:
|
|||||||
db_location.parent.mkdir(parents=True, exist_ok=True)
|
db_location.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||||
filename=db_location, table_name="graph_executions"
|
filename=db_location, table_name="graph_executions", lock=lock
|
||||||
)
|
)
|
||||||
|
|
||||||
urls = LocalUrlService()
|
urls = LocalUrlService()
|
||||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
image_record_storage = SqliteImageRecordStorage(db_location, lock=lock)
|
||||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
names = SimpleNameService()
|
names = SimpleNameService()
|
||||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||||
|
|
||||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
board_record_storage = SqliteBoardRecordStorage(db_location, lock=lock)
|
||||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
board_image_record_storage = SqliteBoardImageRecordStorage(db_location, lock=lock)
|
||||||
|
|
||||||
boards = BoardService(
|
boards = BoardService(
|
||||||
services=BoardServiceDependencies(
|
services=BoardServiceDependencies(
|
||||||
@ -118,7 +121,7 @@ class ApiDependencies:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
batch_manager_storage = SqliteBatchProcessStorage(db_location)
|
batch_manager_storage = SqliteBatchProcessStorage(db_location, lock=lock)
|
||||||
batch_manager = BatchManager(batch_manager_storage)
|
batch_manager = BatchManager(batch_manager_storage)
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
@ -130,7 +133,7 @@ class ApiDependencies:
|
|||||||
boards=boards,
|
boards=boards,
|
||||||
board_images=board_images,
|
board_images=board_images,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs", lock=lock),
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
configuration=config,
|
configuration=config,
|
||||||
|
@ -38,6 +38,7 @@ from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
|||||||
from invokeai.app.services.resource_name import SimpleNameService
|
from invokeai.app.services.resource_name import SimpleNameService
|
||||||
from invokeai.app.services.urls import LocalUrlService
|
from invokeai.app.services.urls import LocalUrlService
|
||||||
from invokeai.app.services.batch_manager import BatchManager
|
from invokeai.app.services.batch_manager import BatchManager
|
||||||
|
from invokeai.app.services.batch_manager_storage import SqliteBatchProcessStorage
|
||||||
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
|
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
|
||||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
|
|
||||||
@ -301,13 +302,16 @@ def invoke_cli():
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
batch_manager_storage = SqliteBatchProcessStorage(db_location)
|
||||||
|
batch_manager = BatchManager(batch_manager_storage)
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
model_manager=model_manager,
|
model_manager=model_manager,
|
||||||
events=events,
|
events=events,
|
||||||
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
|
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
|
||||||
images=images,
|
images=images,
|
||||||
boards=boards,
|
boards=boards,
|
||||||
batch_manager=BatchManager(),
|
batch_manager=batch_manager,
|
||||||
board_images=board_images,
|
board_images=board_images,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
||||||
|
@ -12,6 +12,7 @@ from invokeai.app.services.graph import Graph, GraphExecutionState
|
|||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.batch_manager_storage import (
|
from invokeai.app.services.batch_manager_storage import (
|
||||||
BatchProcessStorageBase,
|
BatchProcessStorageBase,
|
||||||
|
BatchSessionNotFoundException,
|
||||||
Batch,
|
Batch,
|
||||||
BatchProcess,
|
BatchProcess,
|
||||||
BatchSession,
|
BatchSession,
|
||||||
@ -98,7 +99,10 @@ class BatchManager(BatchManagerBase):
|
|||||||
return GraphExecutionState(graph=graph)
|
return GraphExecutionState(graph=graph)
|
||||||
|
|
||||||
def run_batch_process(self, batch_id: str):
|
def run_batch_process(self, batch_id: str):
|
||||||
created_session = self.__batch_process_storage.get_created_session(batch_id)
|
try:
|
||||||
|
created_session = self.__batch_process_storage.get_created_session(batch_id)
|
||||||
|
except BatchSessionNotFoundException:
|
||||||
|
return
|
||||||
ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id)
|
ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id)
|
||||||
self.__invoker.invoke(ges, invoke_all=True)
|
self.__invoker.invoke(ges, invoke_all=True)
|
||||||
|
|
||||||
|
@ -15,9 +15,8 @@ import json
|
|||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.graph import Graph, GraphExecutionState
|
from invokeai.app.services.graph import Graph
|
||||||
from invokeai.app.models.image import ImageField
|
from invokeai.app.models.image import ImageField
|
||||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, Extra, parse_raw_as
|
from pydantic import BaseModel, Field, Extra, parse_raw_as
|
||||||
|
|
||||||
@ -169,14 +168,14 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, filename: str) -> None:
|
def __init__(self, filename: str, lock: threading.Lock = threading.Lock()) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._filename = filename
|
self._filename = filename
|
||||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
self._conn.row_factory = sqlite3.Row
|
self._conn.row_factory = sqlite3.Row
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
self._lock = threading.Lock()
|
self._lock = lock
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
@ -231,7 +230,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
CREATE TABLE IF NOT EXISTS batch_session (
|
CREATE TABLE IF NOT EXISTS batch_session (
|
||||||
batch_id TEXT NOT NULL,
|
batch_id TEXT NOT NULL,
|
||||||
session_id TEXT NOT NULL,
|
session_id TEXT NOT NULL,
|
||||||
state TEXT NOT NULL DEFAULT('created'),
|
state TEXT NOT NULL,
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
-- updated via trigger
|
-- updated via trigger
|
||||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
@ -267,7 +266,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
ON batch_session FOR EACH ROW
|
ON batch_session FOR EACH ROW
|
||||||
BEGIN
|
BEGIN
|
||||||
UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||||
WHERE batch_id = old.batch_id AND image_name = old.image_name;
|
WHERE batch_id = old.batch_id AND session_id = old.session_id;
|
||||||
END;
|
END;
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
@ -298,12 +297,13 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
) -> BatchProcess:
|
) -> BatchProcess:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
batches = [batch.json() for batch in batch_process.batches]
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
INSERT OR IGNORE INTO batch_process (batch_id, batches, graph)
|
INSERT OR IGNORE INTO batch_process (batch_id, batches, graph)
|
||||||
VALUES (?, ?, ?);
|
VALUES (?, ?, ?);
|
||||||
""",
|
""",
|
||||||
(batch_process.batch_id, json.dumps([batch.json() for batch in batch_process.batches]), batch_process.graph.json()),
|
(batch_process.batch_id, json.dumps(batches), batch_process.graph.json()),
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
@ -321,10 +321,11 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
batch_id = session_dict.get("batch_id", "unknown")
|
batch_id = session_dict.get("batch_id", "unknown")
|
||||||
batches_raw = session_dict.get("batches", "unknown")
|
batches_raw = session_dict.get("batches", "unknown")
|
||||||
graph_raw = session_dict.get("graph", "unknown")
|
graph_raw = session_dict.get("graph", "unknown")
|
||||||
|
batches = json.loads(batches_raw)
|
||||||
|
batches = [parse_raw_as(Batch, batch) for batch in batches]
|
||||||
return BatchProcess(
|
return BatchProcess(
|
||||||
batch_id=batch_id,
|
batch_id=batch_id,
|
||||||
batches=[parse_raw_as(Batch, batch) for batch in json.loads(batches_raw)],
|
batches=batches,
|
||||||
graph=parse_raw_as(Graph, graph_raw),
|
graph=parse_raw_as(Graph, graph_raw),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -398,7 +399,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
self._lock.release()
|
self._lock.release()
|
||||||
if result is None:
|
if result is None:
|
||||||
raise BatchSessionNotFoundException
|
raise BatchSessionNotFoundException
|
||||||
return BatchSession(**dict(result))
|
return self._deserialize_batch_session(dict(result))
|
||||||
|
|
||||||
def _deserialize_batch_session(self, session_dict: dict) -> BatchSession:
|
def _deserialize_batch_session(self, session_dict: dict) -> BatchSession:
|
||||||
"""Deserializes a batch session."""
|
"""Deserializes a batch session."""
|
||||||
|
@ -62,14 +62,14 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
|||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, filename: str) -> None:
|
def __init__(self, filename: str, lock: threading.Lock = threading.Lock()) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._filename = filename
|
self._filename = filename
|
||||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
self._conn.row_factory = sqlite3.Row
|
self._conn.row_factory = sqlite3.Row
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
self._lock = threading.Lock()
|
self._lock = lock
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
@ -95,14 +95,14 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
|||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, filename: str) -> None:
|
def __init__(self, filename: str, lock: threading.Lock = threading.Lock()) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._filename = filename
|
self._filename = filename
|
||||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
self._conn.row_factory = sqlite3.Row
|
self._conn.row_factory = sqlite3.Row
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
self._lock = threading.Lock()
|
self._lock = lock
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
@ -155,14 +155,14 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, filename: str) -> None:
|
def __init__(self, filename: str, lock: threading.Lock = threading.Lock()) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._filename = filename
|
self._filename = filename
|
||||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
self._conn.row_factory = sqlite3.Row
|
self._conn.row_factory = sqlite3.Row
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
self._lock = threading.Lock()
|
self._lock = lock
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
@ -20,16 +20,16 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
_id_field: str
|
_id_field: str
|
||||||
_lock: Lock
|
_lock: Lock
|
||||||
|
|
||||||
def __init__(self, filename: str, table_name: str, id_field: str = "id"):
|
def __init__(self, filename: str, table_name: str, id_field: str = "id", lock: Lock = Lock()):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._filename = filename
|
self._filename = filename
|
||||||
self._table_name = table_name
|
self._table_name = table_name
|
||||||
self._id_field = id_field # TODO: validate that T has this field
|
self._id_field = id_field # TODO: validate that T has this field
|
||||||
self._lock = Lock()
|
self._lock = lock
|
||||||
self._conn = sqlite3.connect(
|
self._conn = sqlite3.connect(
|
||||||
self._filename, check_same_thread=False
|
self._filename, check_same_thread=False
|
||||||
) # TODO: figure out a better threading solution
|
) # TODO: figure out a better threading solution
|
||||||
self._conn.set_trace_callback(print)
|
self._conn.execute('pragma journal_mode=wal')
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
|
||||||
self._create_table()
|
self._create_table()
|
||||||
@ -55,21 +55,12 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
def set(self, item: T):
|
def set(self, item: T):
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
json = item.json()
|
|
||||||
print('-----------------locking db-----------------')
|
|
||||||
traceback.print_stack(limit=2)
|
|
||||||
|
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
|
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
|
||||||
(json,),
|
(item.json(),),
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
self._cursor.close()
|
|
||||||
self._cursor = self._conn.cursor()
|
|
||||||
print('-----------------unlocking db-----------------')
|
|
||||||
except Exception as e:
|
|
||||||
print("Exception!")
|
|
||||||
print(e)
|
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
self._on_changed(item)
|
self._on_changed(item)
|
||||||
@ -77,12 +68,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
def get(self, id: str) -> Optional[T]:
|
def get(self, id: str) -> Optional[T]:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
print('-----------------locking db-----------------')
|
|
||||||
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
||||||
result = self._cursor.fetchone()
|
result = self._cursor.fetchone()
|
||||||
self._cursor.close()
|
|
||||||
self._cursor = self._conn.cursor()
|
|
||||||
print('-----------------unlocking db-----------------')
|
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user