From e751f7d815060a1a66fb79467b320eeb785d0a34 Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Thu, 10 Aug 2023 14:09:00 -0400
Subject: [PATCH] More testing

---
 invokeai/app/api/dependencies.py              | 15 +++++++------
 invokeai/app/cli_app.py                       |  6 +++++-
 invokeai/app/services/batch_manager.py        |  6 +++++-
 .../app/services/batch_manager_storage.py     | 21 ++++++++++---------
 .../services/board_image_record_storage.py    |  4 ++--
 invokeai/app/services/board_record_storage.py |  4 ++--
 invokeai/app/services/image_record_storage.py |  4 ++--
 invokeai/app/services/sqlite.py               | 21 ++++---------------
 8 files changed, 40 insertions(+), 41 deletions(-)

diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py
index 83fc8d9e11..d64610b717 100644
--- a/invokeai/app/api/dependencies.py
+++ b/invokeai/app/api/dependencies.py
@@ -2,6 +2,7 @@
 
 from typing import Optional
 from logging import Logger
+import threading
 import os
 from invokeai.app.services.board_image_record_storage import (
     SqliteBoardImageRecordStorage,
@@ -65,6 +66,8 @@ class ApiDependencies:
         logger.info(f"Root directory = {str(config.root_path)}")
         logger.debug(f"Internet connectivity is {config.internet_available}")
 
+        lock = threading.Lock()
+
         events = FastAPIEventService(event_handler_id)
 
         output_folder = config.output_path
@@ -74,17 +77,17 @@ class ApiDependencies:
         db_location.parent.mkdir(parents=True, exist_ok=True)
 
         graph_execution_manager = SqliteItemStorage[GraphExecutionState](
-            filename=db_location, table_name="graph_executions"
+            filename=db_location, table_name="graph_executions", lock=lock
         )
 
         urls = LocalUrlService()
-        image_record_storage = SqliteImageRecordStorage(db_location)
+        image_record_storage = SqliteImageRecordStorage(db_location, lock=lock)
         image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
         names = SimpleNameService()
         latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
-        board_record_storage = SqliteBoardRecordStorage(db_location)
-        board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
+        board_record_storage = SqliteBoardRecordStorage(db_location, lock=lock)
+        board_image_record_storage = SqliteBoardImageRecordStorage(db_location, lock=lock)
 
         boards = BoardService(
             services=BoardServiceDependencies(
@@ -118,7 +121,7 @@ class ApiDependencies:
             )
         )
 
-        batch_manager_storage = SqliteBatchProcessStorage(db_location)
+        batch_manager_storage = SqliteBatchProcessStorage(db_location, lock=lock)
         batch_manager = BatchManager(batch_manager_storage)
 
         services = InvocationServices(
@@ -130,7 +133,7 @@ class ApiDependencies:
             boards=boards,
             board_images=board_images,
             queue=MemoryInvocationQueue(),
-            graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
+            graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs", lock=lock),
             graph_execution_manager=graph_execution_manager,
             processor=DefaultInvocationProcessor(),
             configuration=config,
diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py
index 39b22e9dfb..92c255eb45 100644
--- a/invokeai/app/cli_app.py
+++ b/invokeai/app/cli_app.py
@@ -38,6 +38,7 @@ from invokeai.app.services.images import ImageService, ImageServiceDependencies
 from invokeai.app.services.resource_name import SimpleNameService
 from invokeai.app.services.urls import LocalUrlService
 from invokeai.app.services.batch_manager import BatchManager
+from invokeai.app.services.batch_manager_storage import SqliteBatchProcessStorage
 
 from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
 from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
@@ -301,13 +302,16 @@ def invoke_cli():
         )
     )
 
+    batch_manager_storage = SqliteBatchProcessStorage(db_location)
+    batch_manager = BatchManager(batch_manager_storage)
+
     services = InvocationServices(
         model_manager=model_manager,
         events=events,
         latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
         images=images,
         boards=boards,
-        batch_manager=BatchManager(),
+        batch_manager=batch_manager,
         board_images=board_images,
         queue=MemoryInvocationQueue(),
         graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
diff --git a/invokeai/app/services/batch_manager.py b/invokeai/app/services/batch_manager.py
index eba5a8d676..a5c92bb1fc 100644
--- a/invokeai/app/services/batch_manager.py
+++ b/invokeai/app/services/batch_manager.py
@@ -12,6 +12,7 @@ from invokeai.app.services.graph import Graph, GraphExecutionState
 from invokeai.app.services.invoker import Invoker
 from invokeai.app.services.batch_manager_storage import (
     BatchProcessStorageBase,
+    BatchSessionNotFoundException,
     Batch,
     BatchProcess,
     BatchSession,
@@ -98,7 +99,10 @@ class BatchManager(BatchManagerBase):
         return GraphExecutionState(graph=graph)
 
     def run_batch_process(self, batch_id: str):
-        created_session = self.__batch_process_storage.get_created_session(batch_id)
+        try:
+            created_session = self.__batch_process_storage.get_created_session(batch_id)
+        except BatchSessionNotFoundException:
+            return
         ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id)
         self.__invoker.invoke(ges, invoke_all=True)
 
diff --git a/invokeai/app/services/batch_manager_storage.py b/invokeai/app/services/batch_manager_storage.py
index 44fb5d928f..bc8804d21f 100644
--- a/invokeai/app/services/batch_manager_storage.py
+++ b/invokeai/app/services/batch_manager_storage.py
@@ -15,9 +15,8 @@
 import json
 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
 )
-from invokeai.app.services.graph import Graph, GraphExecutionState
+from invokeai.app.services.graph import Graph
 from invokeai.app.models.image import ImageField
-from invokeai.app.services.image_record_storage import OffsetPaginatedResults
 
 from pydantic import BaseModel, Field, Extra, parse_raw_as
@@ -169,14 +168,14 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
     _cursor: sqlite3.Cursor
     _lock: threading.Lock
 
-    def __init__(self, filename: str) -> None:
+    def __init__(self, filename: str, lock: threading.Lock = threading.Lock()) -> None:
         super().__init__()
         self._filename = filename
         self._conn = sqlite3.connect(filename, check_same_thread=False)
         # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
         self._conn.row_factory = sqlite3.Row
         self._cursor = self._conn.cursor()
-        self._lock = threading.Lock()
+        self._lock = lock
 
         try:
             self._lock.acquire()
@@ -231,7 +230,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
             CREATE TABLE IF NOT EXISTS batch_session (
                 batch_id TEXT NOT NULL,
                 session_id TEXT NOT NULL,
-                state TEXT NOT NULL DEFAULT('created'),
+                state TEXT NOT NULL,
                 created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                 -- updated via trigger
                 updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
@@ -267,7 +266,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
             ON batch_session FOR EACH ROW
             BEGIN
                 UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
-                    WHERE batch_id = old.batch_id AND image_name = old.image_name;
+                    WHERE batch_id = old.batch_id AND session_id = old.session_id;
             END;
             """
         )
@@ -298,12 +297,13 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
     ) -> BatchProcess:
         try:
             self._lock.acquire()
+            batches = [batch.json() for batch in batch_process.batches]
             self._cursor.execute(
                 """--sql
                 INSERT OR IGNORE INTO batch_process (batch_id, batches, graph)
                 VALUES (?, ?, ?);
                 """,
-                (batch_process.batch_id, json.dumps([batch.json() for batch in batch_process.batches]), batch_process.graph.json()),
+                (batch_process.batch_id, json.dumps(batches), batch_process.graph.json()),
             )
             self._conn.commit()
         except sqlite3.Error as e:
@@ -321,10 +321,11 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
         batch_id = session_dict.get("batch_id", "unknown")
         batches_raw = session_dict.get("batches", "unknown")
         graph_raw = session_dict.get("graph", "unknown")
-
+        batches = json.loads(batches_raw)
+        batches = [parse_raw_as(Batch, batch) for batch in batches]
         return BatchProcess(
             batch_id=batch_id,
-            batches=[parse_raw_as(Batch, batch) for batch in json.loads(batches_raw)],
+            batches=batches,
             graph=parse_raw_as(Graph, graph_raw),
         )
 
@@ -398,7 +399,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
             self._lock.release()
         if result is None:
             raise BatchSessionNotFoundException
-        return BatchSession(**dict(result))
+        return self._deserialize_batch_session(dict(result))
 
     def _deserialize_batch_session(self, session_dict: dict) -> BatchSession:
         """Deserializes a batch session."""
diff --git a/invokeai/app/services/board_image_record_storage.py b/invokeai/app/services/board_image_record_storage.py
index f0007c8cef..c4c058422b 100644
--- a/invokeai/app/services/board_image_record_storage.py
+++ b/invokeai/app/services/board_image_record_storage.py
@@ -62,14 +62,14 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
     _cursor: sqlite3.Cursor
     _lock: threading.Lock
 
-    def __init__(self, filename: str) -> None:
+    def __init__(self, filename: str, lock: threading.Lock = threading.Lock()) -> None:
         super().__init__()
         self._filename = filename
         self._conn = sqlite3.connect(filename, check_same_thread=False)
         # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
         self._conn.row_factory = sqlite3.Row
         self._cursor = self._conn.cursor()
-        self._lock = threading.Lock()
+        self._lock = lock
 
         try:
             self._lock.acquire()
diff --git a/invokeai/app/services/board_record_storage.py b/invokeai/app/services/board_record_storage.py
index 2fad7b0ab3..2626799836 100644
--- a/invokeai/app/services/board_record_storage.py
+++ b/invokeai/app/services/board_record_storage.py
@@ -95,14 +95,14 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
     _cursor: sqlite3.Cursor
     _lock: threading.Lock
 
-    def __init__(self, filename: str) -> None:
+    def __init__(self, filename: str, lock: threading.Lock = threading.Lock()) -> None:
         super().__init__()
         self._filename = filename
         self._conn = sqlite3.connect(filename, check_same_thread=False)
         # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
         self._conn.row_factory = sqlite3.Row
         self._cursor = self._conn.cursor()
-        self._lock = threading.Lock()
+        self._lock = lock
 
         try:
             self._lock.acquire()
diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py
index 8c274ab8f9..3dab08bcd7 100644
--- a/invokeai/app/services/image_record_storage.py
+++ b/invokeai/app/services/image_record_storage.py
@@ -155,14 +155,14 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
     _cursor: sqlite3.Cursor
     _lock: threading.Lock
 
-    def __init__(self, filename: str) -> None:
+    def __init__(self, filename: str, lock: threading.Lock = threading.Lock()) -> None:
         super().__init__()
         self._filename = filename
         self._conn = sqlite3.connect(filename, check_same_thread=False)
         # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
         self._conn.row_factory = sqlite3.Row
         self._cursor = self._conn.cursor()
-        self._lock = threading.Lock()
+        self._lock = lock
 
         try:
             self._lock.acquire()
diff --git a/invokeai/app/services/sqlite.py b/invokeai/app/services/sqlite.py
index 4f63ffb368..f7a4cadbfe 100644
--- a/invokeai/app/services/sqlite.py
+++ b/invokeai/app/services/sqlite.py
@@ -20,16 +20,16 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
     _id_field: str
     _lock: Lock
 
-    def __init__(self, filename: str, table_name: str, id_field: str = "id"):
+    def __init__(self, filename: str, table_name: str, id_field: str = "id", lock: Lock = Lock()):
         super().__init__()
 
         self._filename = filename
         self._table_name = table_name
         self._id_field = id_field  # TODO: validate that T has this field
-        self._lock = Lock()
+        self._lock = lock
         self._conn = sqlite3.connect(
             self._filename, check_same_thread=False
         )  # TODO: figure out a better threading solution
-        self._conn.set_trace_callback(print)
+        self._conn.execute('pragma journal_mode=wal')
         self._cursor = self._conn.cursor()
         self._create_table()
@@ -55,21 +55,12 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
 
     def set(self, item: T):
         try:
             self._lock.acquire()
-            json = item.json()
-            print('-----------------locking db-----------------')
-            traceback.print_stack(limit=2)
             self._cursor.execute(
                 f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
-                (json,),
+                (item.json(),),
             )
             self._conn.commit()
-            self._cursor.close()
-            self._cursor = self._conn.cursor()
-            print('-----------------unlocking db-----------------')
-        except Exception as e:
-            print("Exception!")
-            print(e)
         finally:
             self._lock.release()
         self._on_changed(item)
@@ -77,12 +68,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
 
     def get(self, id: str) -> Optional[T]:
         try:
             self._lock.acquire()
-            print('-----------------locking db-----------------')
             self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
             result = self._cursor.fetchone()
-            self._cursor.close()
-            self._cursor = self._conn.cursor()
-            print('-----------------unlocking db-----------------')
         finally:
             self._lock.release()
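
Editor's note: the changes above pass a single threading.Lock into every SQLite-backed service that uses the same database file, rather than letting each service create and honor only its own lock. The snippet below is an illustrative sketch, not part of the patch; it assumes the constructor signatures introduced in this diff and uses a hypothetical database path.

    import threading

    from invokeai.app.services.graph import GraphExecutionState
    from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
    from invokeai.app.services.sqlite import SqliteItemStorage

    db_location = "invokeai.db"  # hypothetical path, for illustration only

    # One lock shared by all services that touch the same SQLite file, so
    # their reads and writes are serialized against each other.
    lock = threading.Lock()

    graph_execution_manager = SqliteItemStorage[GraphExecutionState](
        filename=db_location, table_name="graph_executions", lock=lock
    )
    image_record_storage = SqliteImageRecordStorage(db_location, lock=lock)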