Go back to 1 lock per table

This commit is contained in:
Brandon Rising 2023-08-10 14:26:22 -04:00
parent e751f7d815
commit 280ac15da2
6 changed files with 21 additions and 18 deletions

View File

@ -77,17 +77,17 @@ class ApiDependencies:
db_location.parent.mkdir(parents=True, exist_ok=True) db_location.parent.mkdir(parents=True, exist_ok=True)
graph_execution_manager = SqliteItemStorage[GraphExecutionState]( graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions", lock=lock filename=db_location, table_name="graph_executions"
) )
urls = LocalUrlService() urls = LocalUrlService()
image_record_storage = SqliteImageRecordStorage(db_location, lock=lock) image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images") image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService() names = SimpleNameService()
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")) latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
board_record_storage = SqliteBoardRecordStorage(db_location, lock=lock) board_record_storage = SqliteBoardRecordStorage(db_location)
board_image_record_storage = SqliteBoardImageRecordStorage(db_location, lock=lock) board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
boards = BoardService( boards = BoardService(
services=BoardServiceDependencies( services=BoardServiceDependencies(
@ -121,7 +121,7 @@ class ApiDependencies:
) )
) )
batch_manager_storage = SqliteBatchProcessStorage(db_location, lock=lock) batch_manager_storage = SqliteBatchProcessStorage(db_location)
batch_manager = BatchManager(batch_manager_storage) batch_manager = BatchManager(batch_manager_storage)
services = InvocationServices( services = InvocationServices(
@ -133,7 +133,7 @@ class ApiDependencies:
boards=boards, boards=boards,
board_images=board_images, board_images=board_images,
queue=MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs", lock=lock), graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(), processor=DefaultInvocationProcessor(),
configuration=config, configuration=config,

View File

@ -38,8 +38,12 @@ class BatchSession(BaseModel):
) )
def uuid_string():
res = uuid.uuid4()
return str(res)
class BatchProcess(BaseModel): class BatchProcess(BaseModel):
batch_id: Optional[str] = Field(default_factory=uuid.uuid4().__str__, description="Identifier for this batch") batch_id: Optional[str] = Field(default_factory=uuid_string, description="Identifier for this batch")
batches: List[Batch] = Field( batches: List[Batch] = Field(
description="List of batch configs to apply to this session", description="List of batch configs to apply to this session",
default_factory=list, default_factory=list,
@ -168,14 +172,14 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
_cursor: sqlite3.Cursor _cursor: sqlite3.Cursor
_lock: threading.Lock _lock: threading.Lock
def __init__(self, filename: str, lock: threading.Lock = threading.Lock()) -> None: def __init__(self, filename: str) -> None:
super().__init__() super().__init__()
self._filename = filename self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False) self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!) # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._lock = lock self._lock = threading.Lock()
try: try:
self._lock.acquire() self._lock.acquire()

View File

@ -62,14 +62,14 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
_cursor: sqlite3.Cursor _cursor: sqlite3.Cursor
_lock: threading.Lock _lock: threading.Lock
def __init__(self, filename: str, lock: threading.Lock = threading.Lock()) -> None: def __init__(self, filename: str) -> None:
super().__init__() super().__init__()
self._filename = filename self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False) self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!) # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._lock = lock self._lock = threading.Lock()
try: try:
self._lock.acquire() self._lock.acquire()

View File

@ -95,14 +95,14 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
_cursor: sqlite3.Cursor _cursor: sqlite3.Cursor
_lock: threading.Lock _lock: threading.Lock
def __init__(self, filename: str, lock: threading.Lock = threading.Lock()) -> None: def __init__(self, filename: str) -> None:
super().__init__() super().__init__()
self._filename = filename self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False) self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!) # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._lock = lock self._lock = threading.Lock()
try: try:
self._lock.acquire() self._lock.acquire()

View File

@ -155,14 +155,14 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
_cursor: sqlite3.Cursor _cursor: sqlite3.Cursor
_lock: threading.Lock _lock: threading.Lock
def __init__(self, filename: str, lock: threading.Lock = threading.Lock()) -> None: def __init__(self, filename: str) -> None:
super().__init__() super().__init__()
self._filename = filename self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False) self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!) # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._lock = lock self._lock = threading.Lock()
try: try:
self._lock.acquire() self._lock.acquire()

View File

@ -9,7 +9,6 @@ from .item_storage import ItemStorageABC, PaginatedResults
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel)
sqlite_memory = ":memory:" sqlite_memory = ":memory:"
import traceback
class SqliteItemStorage(ItemStorageABC, Generic[T]): class SqliteItemStorage(ItemStorageABC, Generic[T]):
@ -20,12 +19,12 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
_id_field: str _id_field: str
_lock: Lock _lock: Lock
def __init__(self, filename: str, table_name: str, id_field: str = "id", lock: Lock = Lock()): def __init__(self, filename: str, table_name: str, id_field: str = "id"):
super().__init__() super().__init__()
self._filename = filename self._filename = filename
self._table_name = table_name self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field self._id_field = id_field # TODO: validate that T has this field
self._lock = lock self._lock = Lock()
self._conn = sqlite3.connect( self._conn = sqlite3.connect(
self._filename, check_same_thread=False self._filename, check_same_thread=False
) # TODO: figure out a better threading solution ) # TODO: figure out a better threading solution