diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index df51ca8581..87c6f7eb66 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -2,6 +2,7 @@ from typing import Optional from logging import Logger +import sqlite3 from invokeai.app.services.board_image_record_storage import ( SqliteBoardImageRecordStorage, ) @@ -74,18 +75,22 @@ class ApiDependencies: db_path.parent.mkdir(parents=True, exist_ok=True) db_location = str(db_path) + db_conn = sqlite3.connect( + db_location, check_same_thread=False + ) # TODO: figure out a better threading solution + graph_execution_manager = SqliteItemStorage[GraphExecutionState]( - filename=db_location, table_name="graph_executions" + conn=db_conn, table_name="graph_executions" ) urls = LocalUrlService() - image_record_storage = SqliteImageRecordStorage(db_location) + image_record_storage = SqliteImageRecordStorage(conn=db_conn) image_file_storage = DiskImageFileStorage(f"{output_folder}/images") names = SimpleNameService() latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")) - board_record_storage = SqliteBoardRecordStorage(db_location) - board_image_record_storage = SqliteBoardImageRecordStorage(db_location) + board_record_storage = SqliteBoardRecordStorage(conn=db_conn) + board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn) boards = BoardService( services=BoardServiceDependencies( @@ -119,7 +124,7 @@ class ApiDependencies: ) ) - batch_manager_storage = SqliteBatchProcessStorage(db_location) + batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn) batch_manager = BatchManager(batch_manager_storage) services = InvocationServices( @@ -131,7 +136,7 @@ class ApiDependencies: boards=boards, board_images=board_images, queue=MemoryInvocationQueue(), - graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"), + graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"), graph_execution_manager=graph_execution_manager, processor=DefaultInvocationProcessor(), configuration=config, diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index a0b681e2d4..63404b0b74 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -5,6 +5,7 @@ import re import shlex import sys import time +import sqlite3 from typing import Union, get_type_hints, Optional from pydantic import BaseModel, ValidationError @@ -257,19 +258,23 @@ def invoke_cli(): db_location = config.db_path db_location.parent.mkdir(parents=True, exist_ok=True) + + db_conn = sqlite3.connect( + db_location, check_same_thread=False + ) # TODO: figure out a better threading solution logger.info(f'InvokeAI database location is "{db_location}"') graph_execution_manager = SqliteItemStorage[GraphExecutionState]( - filename=db_location, table_name="graph_executions" + conn=db_conn, table_name="graph_executions" ) urls = LocalUrlService() - image_record_storage = SqliteImageRecordStorage(db_location) + image_record_storage = SqliteImageRecordStorage(conn=db_conn) image_file_storage = DiskImageFileStorage(f"{output_folder}/images") names = SimpleNameService() - board_record_storage = SqliteBoardRecordStorage(db_location) - board_image_record_storage = SqliteBoardImageRecordStorage(db_location) + board_record_storage = SqliteBoardRecordStorage(conn=db_conn) + board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn) boards = BoardService( services=BoardServiceDependencies( @@ -303,7 +308,7 @@ def invoke_cli(): ) ) - batch_manager_storage = SqliteBatchProcessStorage(db_location) + batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn) batch_manager = BatchManager(batch_manager_storage) services = InvocationServices( @@ -315,7 +320,7 @@ def invoke_cli(): batch_manager=batch_manager, board_images=board_images, queue=MemoryInvocationQueue(), - graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"), + graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"), graph_execution_manager=graph_execution_manager, processor=DefaultInvocationProcessor(), performance_statistics=InvocationStatsService(graph_execution_manager), diff --git a/invokeai/app/services/batch_manager_storage.py b/invokeai/app/services/batch_manager_storage.py index 7d16a73ebd..80411c1e5c 100644 --- a/invokeai/app/services/batch_manager_storage.py +++ b/invokeai/app/services/batch_manager_storage.py @@ -177,15 +177,13 @@ class BatchProcessStorageBase(ABC): class SqliteBatchProcessStorage(BatchProcessStorageBase): - _filename: str _conn: sqlite3.Connection _cursor: sqlite3.Cursor _lock: threading.Lock - def __init__(self, filename: str) -> None: + def __init__(self, conn: sqlite3.Connection) -> None: super().__init__() - self._filename = filename - self._conn = sqlite3.connect(filename, check_same_thread=False) + self._conn = conn # Enable row factory to get rows as dictionaries (must be done before making the cursor!) self._conn.row_factory = sqlite3.Row self._cursor = self._conn.cursor() diff --git a/invokeai/app/services/board_image_record_storage.py b/invokeai/app/services/board_image_record_storage.py index 03badf9866..3b97f48a62 100644 --- a/invokeai/app/services/board_image_record_storage.py +++ b/invokeai/app/services/board_image_record_storage.py @@ -56,15 +56,13 @@ class BoardImageRecordStorageBase(ABC): class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase): - _filename: str _conn: sqlite3.Connection _cursor: sqlite3.Cursor _lock: threading.Lock - def __init__(self, filename: str) -> None: + def __init__(self, conn: sqlite3.Connection) -> None: super().__init__() - self._filename = filename - self._conn = sqlite3.connect(filename, check_same_thread=False) + self._conn = conn # Enable row factory to get rows as dictionaries (must be done before making the cursor!) self._conn.row_factory = sqlite3.Row self._cursor = self._conn.cursor() diff --git a/invokeai/app/services/board_record_storage.py b/invokeai/app/services/board_record_storage.py index 2fad7b0ab3..b36a6bba05 100644 --- a/invokeai/app/services/board_record_storage.py +++ b/invokeai/app/services/board_record_storage.py @@ -90,15 +90,13 @@ class BoardRecordStorageBase(ABC): class SqliteBoardRecordStorage(BoardRecordStorageBase): - _filename: str _conn: sqlite3.Connection _cursor: sqlite3.Cursor _lock: threading.Lock - def __init__(self, filename: str) -> None: + def __init__(self, conn: sqlite3.Connection) -> None: super().__init__() - self._filename = filename - self._conn = sqlite3.connect(filename, check_same_thread=False) + self._conn = conn # Enable row factory to get rows as dictionaries (must be done before making the cursor!) self._conn.row_factory = sqlite3.Row self._cursor = self._conn.cursor() diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index 8c274ab8f9..c781911aa5 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -150,15 +150,13 @@ class ImageRecordStorageBase(ABC): class SqliteImageRecordStorage(ImageRecordStorageBase): - _filename: str _conn: sqlite3.Connection _cursor: sqlite3.Cursor _lock: threading.Lock - def __init__(self, filename: str) -> None: + def __init__(self, conn: sqlite3.Connection) -> None: super().__init__() - self._filename = filename - self._conn = sqlite3.connect(filename, check_same_thread=False) + self._conn = conn # Enable row factory to get rows as dictionaries (must be done before making the cursor!) self._conn.row_factory = sqlite3.Row self._cursor = self._conn.cursor() diff --git a/invokeai/app/services/sqlite.py b/invokeai/app/services/sqlite.py index 855f3f1939..76439b20c9 100644 --- a/invokeai/app/services/sqlite.py +++ b/invokeai/app/services/sqlite.py @@ -12,23 +12,19 @@ sqlite_memory = ":memory:" class SqliteItemStorage(ItemStorageABC, Generic[T]): - _filename: str _table_name: str _conn: sqlite3.Connection _cursor: sqlite3.Cursor _id_field: str _lock: Lock - def __init__(self, filename: str, table_name: str, id_field: str = "id"): + def __init__(self, conn: sqlite3.Connection, table_name: str, id_field: str = "id"): super().__init__() - self._filename = filename self._table_name = table_name self._id_field = id_field # TODO: validate that T has this field self._lock = Lock() - self._conn = sqlite3.connect( - self._filename, check_same_thread=False - ) # TODO: figure out a better threading solution + self._conn = conn self._cursor = self._conn.cursor() self._create_table()