Switch sqlite clients to only use one connection

This commit is contained in:
Brandon Rising 2023-08-15 21:46:24 -04:00
parent 15e7ca1baa
commit abf09fc8fa
7 changed files with 32 additions and 34 deletions

View File

@ -2,6 +2,7 @@
from typing import Optional
from logging import Logger
import sqlite3
from invokeai.app.services.board_image_record_storage import (
SqliteBoardImageRecordStorage,
)
@ -74,18 +75,22 @@ class ApiDependencies:
db_path.parent.mkdir(parents=True, exist_ok=True)
db_location = str(db_path)
db_conn = sqlite3.connect(
db_location, check_same_thread=False
) # TODO: figure out a better threading solution
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
conn=db_conn, table_name="graph_executions"
)
urls = LocalUrlService()
image_record_storage = SqliteImageRecordStorage(db_location)
image_record_storage = SqliteImageRecordStorage(conn=db_conn)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
board_record_storage = SqliteBoardRecordStorage(db_location)
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)
boards = BoardService(
services=BoardServiceDependencies(
@ -119,7 +124,7 @@ class ApiDependencies:
)
)
batch_manager_storage = SqliteBatchProcessStorage(db_location)
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
batch_manager = BatchManager(batch_manager_storage)
services = InvocationServices(
@ -131,7 +136,7 @@ class ApiDependencies:
boards=boards,
board_images=board_images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
configuration=config,

View File

@ -5,6 +5,7 @@ import re
import shlex
import sys
import time
import sqlite3
from typing import Union, get_type_hints, Optional
from pydantic import BaseModel, ValidationError
@ -257,19 +258,23 @@ def invoke_cli():
db_location = config.db_path
db_location.parent.mkdir(parents=True, exist_ok=True)
db_conn = sqlite3.connect(
db_location, check_same_thread=False
) # TODO: figure out a better threading solution
logger.info(f'InvokeAI database location is "{db_location}"')
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
conn=db_conn, table_name="graph_executions"
)
urls = LocalUrlService()
image_record_storage = SqliteImageRecordStorage(db_location)
image_record_storage = SqliteImageRecordStorage(conn=db_conn)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
board_record_storage = SqliteBoardRecordStorage(db_location)
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)
boards = BoardService(
services=BoardServiceDependencies(
@ -303,7 +308,7 @@ def invoke_cli():
)
)
batch_manager_storage = SqliteBatchProcessStorage(db_location)
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
batch_manager = BatchManager(batch_manager_storage)
services = InvocationServices(
@ -315,7 +320,7 @@ def invoke_cli():
batch_manager=batch_manager,
board_images=board_images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
performance_statistics=InvocationStatsService(graph_execution_manager),

View File

@ -177,15 +177,13 @@ class BatchProcessStorageBase(ABC):
class SqliteBatchProcessStorage(BatchProcessStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
def __init__(self, conn: sqlite3.Connection) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
self._conn = conn
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()

View File

@ -56,15 +56,13 @@ class BoardImageRecordStorageBase(ABC):
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
def __init__(self, conn: sqlite3.Connection) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
self._conn = conn
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()

View File

@ -90,15 +90,13 @@ class BoardRecordStorageBase(ABC):
class SqliteBoardRecordStorage(BoardRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
def __init__(self, conn: sqlite3.Connection) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
self._conn = conn
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()

View File

@ -150,15 +150,13 @@ class ImageRecordStorageBase(ABC):
class SqliteImageRecordStorage(ImageRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
def __init__(self, conn: sqlite3.Connection) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
self._conn = conn
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()

View File

@ -12,23 +12,19 @@ sqlite_memory = ":memory:"
class SqliteItemStorage(ItemStorageABC, Generic[T]):
_filename: str
_table_name: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_id_field: str
_lock: Lock
def __init__(self, filename: str, table_name: str, id_field: str = "id"):
def __init__(self, conn: sqlite3.Connection, table_name: str, id_field: str = "id"):
super().__init__()
self._filename = filename
self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field
self._lock = Lock()
self._conn = sqlite3.connect(
self._filename, check_same_thread=False
) # TODO: figure out a better threading solution
self._conn = conn
self._cursor = self._conn.cursor()
self._create_table()