mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Switch sqlite clients to only use one connection
This commit is contained in:
parent
15e7ca1baa
commit
abf09fc8fa
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
|
import sqlite3
|
||||||
from invokeai.app.services.board_image_record_storage import (
|
from invokeai.app.services.board_image_record_storage import (
|
||||||
SqliteBoardImageRecordStorage,
|
SqliteBoardImageRecordStorage,
|
||||||
)
|
)
|
||||||
@ -74,18 +75,22 @@ class ApiDependencies:
|
|||||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
db_location = str(db_path)
|
db_location = str(db_path)
|
||||||
|
|
||||||
|
db_conn = sqlite3.connect(
|
||||||
|
db_location, check_same_thread=False
|
||||||
|
) # TODO: figure out a better threading solution
|
||||||
|
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||||
filename=db_location, table_name="graph_executions"
|
conn=db_conn, table_name="graph_executions"
|
||||||
)
|
)
|
||||||
|
|
||||||
urls = LocalUrlService()
|
urls = LocalUrlService()
|
||||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
image_record_storage = SqliteImageRecordStorage(conn=db_conn)
|
||||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
names = SimpleNameService()
|
names = SimpleNameService()
|
||||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||||
|
|
||||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
|
||||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)
|
||||||
|
|
||||||
boards = BoardService(
|
boards = BoardService(
|
||||||
services=BoardServiceDependencies(
|
services=BoardServiceDependencies(
|
||||||
@ -119,7 +124,7 @@ class ApiDependencies:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
batch_manager_storage = SqliteBatchProcessStorage(db_location)
|
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
|
||||||
batch_manager = BatchManager(batch_manager_storage)
|
batch_manager = BatchManager(batch_manager_storage)
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
@ -131,7 +136,7 @@ class ApiDependencies:
|
|||||||
boards=boards,
|
boards=boards,
|
||||||
board_images=board_images,
|
board_images=board_images,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
configuration=config,
|
configuration=config,
|
||||||
|
@ -5,6 +5,7 @@ import re
|
|||||||
import shlex
|
import shlex
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import sqlite3
|
||||||
from typing import Union, get_type_hints, Optional
|
from typing import Union, get_type_hints, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
@ -257,19 +258,23 @@ def invoke_cli():
|
|||||||
db_location = config.db_path
|
db_location = config.db_path
|
||||||
db_location.parent.mkdir(parents=True, exist_ok=True)
|
db_location.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
db_conn = sqlite3.connect(
|
||||||
|
db_location, check_same_thread=False
|
||||||
|
) # TODO: figure out a better threading solution
|
||||||
logger.info(f'InvokeAI database location is "{db_location}"')
|
logger.info(f'InvokeAI database location is "{db_location}"')
|
||||||
|
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||||
filename=db_location, table_name="graph_executions"
|
conn=db_conn, table_name="graph_executions"
|
||||||
)
|
)
|
||||||
|
|
||||||
urls = LocalUrlService()
|
urls = LocalUrlService()
|
||||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
image_record_storage = SqliteImageRecordStorage(conn=db_conn)
|
||||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
names = SimpleNameService()
|
names = SimpleNameService()
|
||||||
|
|
||||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
|
||||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)
|
||||||
|
|
||||||
boards = BoardService(
|
boards = BoardService(
|
||||||
services=BoardServiceDependencies(
|
services=BoardServiceDependencies(
|
||||||
@ -303,7 +308,7 @@ def invoke_cli():
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
batch_manager_storage = SqliteBatchProcessStorage(db_location)
|
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
|
||||||
batch_manager = BatchManager(batch_manager_storage)
|
batch_manager = BatchManager(batch_manager_storage)
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
@ -315,7 +320,7 @@ def invoke_cli():
|
|||||||
batch_manager=batch_manager,
|
batch_manager=batch_manager,
|
||||||
board_images=board_images,
|
board_images=board_images,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||||
|
@ -177,15 +177,13 @@ class BatchProcessStorageBase(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||||
_filename: str
|
|
||||||
_conn: sqlite3.Connection
|
_conn: sqlite3.Connection
|
||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, filename: str) -> None:
|
def __init__(self, conn: sqlite3.Connection) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._filename = filename
|
self._conn = conn
|
||||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
self._conn.row_factory = sqlite3.Row
|
self._conn.row_factory = sqlite3.Row
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
@ -56,15 +56,13 @@ class BoardImageRecordStorageBase(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||||
_filename: str
|
|
||||||
_conn: sqlite3.Connection
|
_conn: sqlite3.Connection
|
||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, filename: str) -> None:
|
def __init__(self, conn: sqlite3.Connection) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._filename = filename
|
self._conn = conn
|
||||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
self._conn.row_factory = sqlite3.Row
|
self._conn.row_factory = sqlite3.Row
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
@ -90,15 +90,13 @@ class BoardRecordStorageBase(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||||
_filename: str
|
|
||||||
_conn: sqlite3.Connection
|
_conn: sqlite3.Connection
|
||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, filename: str) -> None:
|
def __init__(self, conn: sqlite3.Connection) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._filename = filename
|
self._conn = conn
|
||||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
self._conn.row_factory = sqlite3.Row
|
self._conn.row_factory = sqlite3.Row
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
@ -150,15 +150,13 @@ class ImageRecordStorageBase(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||||
_filename: str
|
|
||||||
_conn: sqlite3.Connection
|
_conn: sqlite3.Connection
|
||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, filename: str) -> None:
|
def __init__(self, conn: sqlite3.Connection) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._filename = filename
|
self._conn = conn
|
||||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
self._conn.row_factory = sqlite3.Row
|
self._conn.row_factory = sqlite3.Row
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
@ -12,23 +12,19 @@ sqlite_memory = ":memory:"
|
|||||||
|
|
||||||
|
|
||||||
class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||||
_filename: str
|
|
||||||
_table_name: str
|
_table_name: str
|
||||||
_conn: sqlite3.Connection
|
_conn: sqlite3.Connection
|
||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_id_field: str
|
_id_field: str
|
||||||
_lock: Lock
|
_lock: Lock
|
||||||
|
|
||||||
def __init__(self, filename: str, table_name: str, id_field: str = "id"):
|
def __init__(self, conn: sqlite3.Connection, table_name: str, id_field: str = "id"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self._filename = filename
|
|
||||||
self._table_name = table_name
|
self._table_name = table_name
|
||||||
self._id_field = id_field # TODO: validate that T has this field
|
self._id_field = id_field # TODO: validate that T has this field
|
||||||
self._lock = Lock()
|
self._lock = Lock()
|
||||||
self._conn = sqlite3.connect(
|
self._conn = conn
|
||||||
self._filename, check_same_thread=False
|
|
||||||
) # TODO: figure out a better threading solution
|
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
|
||||||
self._create_table()
|
self._create_table()
|
||||||
|
Loading…
Reference in New Issue
Block a user