Switch sqlite clients to only use one connection

This commit is contained in:
Brandon Rising 2023-08-15 21:46:24 -04:00
parent 15e7ca1baa
commit abf09fc8fa
7 changed files with 32 additions and 34 deletions

View File

@ -2,6 +2,7 @@
from typing import Optional from typing import Optional
from logging import Logger from logging import Logger
import sqlite3
from invokeai.app.services.board_image_record_storage import ( from invokeai.app.services.board_image_record_storage import (
SqliteBoardImageRecordStorage, SqliteBoardImageRecordStorage,
) )
@ -74,18 +75,22 @@ class ApiDependencies:
db_path.parent.mkdir(parents=True, exist_ok=True) db_path.parent.mkdir(parents=True, exist_ok=True)
db_location = str(db_path) db_location = str(db_path)
db_conn = sqlite3.connect(
db_location, check_same_thread=False
) # TODO: figure out a better threading solution
graph_execution_manager = SqliteItemStorage[GraphExecutionState]( graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions" conn=db_conn, table_name="graph_executions"
) )
urls = LocalUrlService() urls = LocalUrlService()
image_record_storage = SqliteImageRecordStorage(db_location) image_record_storage = SqliteImageRecordStorage(conn=db_conn)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images") image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService() names = SimpleNameService()
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")) latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
board_record_storage = SqliteBoardRecordStorage(db_location) board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
board_image_record_storage = SqliteBoardImageRecordStorage(db_location) board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)
boards = BoardService( boards = BoardService(
services=BoardServiceDependencies( services=BoardServiceDependencies(
@ -119,7 +124,7 @@ class ApiDependencies:
) )
) )
batch_manager_storage = SqliteBatchProcessStorage(db_location) batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
batch_manager = BatchManager(batch_manager_storage) batch_manager = BatchManager(batch_manager_storage)
services = InvocationServices( services = InvocationServices(
@ -131,7 +136,7 @@ class ApiDependencies:
boards=boards, boards=boards,
board_images=board_images, board_images=board_images,
queue=MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"), graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(), processor=DefaultInvocationProcessor(),
configuration=config, configuration=config,

View File

@ -5,6 +5,7 @@ import re
import shlex import shlex
import sys import sys
import time import time
import sqlite3
from typing import Union, get_type_hints, Optional from typing import Union, get_type_hints, Optional
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
@ -257,19 +258,23 @@ def invoke_cli():
db_location = config.db_path db_location = config.db_path
db_location.parent.mkdir(parents=True, exist_ok=True) db_location.parent.mkdir(parents=True, exist_ok=True)
db_conn = sqlite3.connect(
db_location, check_same_thread=False
) # TODO: figure out a better threading solution
logger.info(f'InvokeAI database location is "{db_location}"') logger.info(f'InvokeAI database location is "{db_location}"')
graph_execution_manager = SqliteItemStorage[GraphExecutionState]( graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions" conn=db_conn, table_name="graph_executions"
) )
urls = LocalUrlService() urls = LocalUrlService()
image_record_storage = SqliteImageRecordStorage(db_location) image_record_storage = SqliteImageRecordStorage(conn=db_conn)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images") image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService() names = SimpleNameService()
board_record_storage = SqliteBoardRecordStorage(db_location) board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
board_image_record_storage = SqliteBoardImageRecordStorage(db_location) board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)
boards = BoardService( boards = BoardService(
services=BoardServiceDependencies( services=BoardServiceDependencies(
@ -303,7 +308,7 @@ def invoke_cli():
) )
) )
batch_manager_storage = SqliteBatchProcessStorage(db_location) batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
batch_manager = BatchManager(batch_manager_storage) batch_manager = BatchManager(batch_manager_storage)
services = InvocationServices( services = InvocationServices(
@ -315,7 +320,7 @@ def invoke_cli():
batch_manager=batch_manager, batch_manager=batch_manager,
board_images=board_images, board_images=board_images,
queue=MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"), graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(), processor=DefaultInvocationProcessor(),
performance_statistics=InvocationStatsService(graph_execution_manager), performance_statistics=InvocationStatsService(graph_execution_manager),

View File

@ -177,15 +177,13 @@ class BatchProcessStorageBase(ABC):
class SqliteBatchProcessStorage(BatchProcessStorageBase): class SqliteBatchProcessStorage(BatchProcessStorageBase):
_filename: str
_conn: sqlite3.Connection _conn: sqlite3.Connection
_cursor: sqlite3.Cursor _cursor: sqlite3.Cursor
_lock: threading.Lock _lock: threading.Lock
def __init__(self, filename: str) -> None: def __init__(self, conn: sqlite3.Connection) -> None:
super().__init__() super().__init__()
self._filename = filename self._conn = conn
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!) # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()

View File

@ -56,15 +56,13 @@ class BoardImageRecordStorageBase(ABC):
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase): class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
_filename: str
_conn: sqlite3.Connection _conn: sqlite3.Connection
_cursor: sqlite3.Cursor _cursor: sqlite3.Cursor
_lock: threading.Lock _lock: threading.Lock
def __init__(self, filename: str) -> None: def __init__(self, conn: sqlite3.Connection) -> None:
super().__init__() super().__init__()
self._filename = filename self._conn = conn
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!) # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()

View File

@ -90,15 +90,13 @@ class BoardRecordStorageBase(ABC):
class SqliteBoardRecordStorage(BoardRecordStorageBase): class SqliteBoardRecordStorage(BoardRecordStorageBase):
_filename: str
_conn: sqlite3.Connection _conn: sqlite3.Connection
_cursor: sqlite3.Cursor _cursor: sqlite3.Cursor
_lock: threading.Lock _lock: threading.Lock
def __init__(self, filename: str) -> None: def __init__(self, conn: sqlite3.Connection) -> None:
super().__init__() super().__init__()
self._filename = filename self._conn = conn
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!) # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()

View File

@ -150,15 +150,13 @@ class ImageRecordStorageBase(ABC):
class SqliteImageRecordStorage(ImageRecordStorageBase): class SqliteImageRecordStorage(ImageRecordStorageBase):
_filename: str
_conn: sqlite3.Connection _conn: sqlite3.Connection
_cursor: sqlite3.Cursor _cursor: sqlite3.Cursor
_lock: threading.Lock _lock: threading.Lock
def __init__(self, filename: str) -> None: def __init__(self, conn: sqlite3.Connection) -> None:
super().__init__() super().__init__()
self._filename = filename self._conn = conn
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!) # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()

View File

@ -12,23 +12,19 @@ sqlite_memory = ":memory:"
class SqliteItemStorage(ItemStorageABC, Generic[T]): class SqliteItemStorage(ItemStorageABC, Generic[T]):
_filename: str
_table_name: str _table_name: str
_conn: sqlite3.Connection _conn: sqlite3.Connection
_cursor: sqlite3.Cursor _cursor: sqlite3.Cursor
_id_field: str _id_field: str
_lock: Lock _lock: Lock
def __init__(self, filename: str, table_name: str, id_field: str = "id"): def __init__(self, conn: sqlite3.Connection, table_name: str, id_field: str = "id"):
super().__init__() super().__init__()
self._filename = filename
self._table_name = table_name self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field self._id_field = id_field # TODO: validate that T has this field
self._lock = Lock() self._lock = Lock()
self._conn = sqlite3.connect( self._conn = conn
self._filename, check_same_thread=False
) # TODO: figure out a better threading solution
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._create_table() self._create_table()