mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Testing sqlite issues with batch_manager
This commit is contained in:
parent
835d76af45
commit
e26e4740b3
@ -31,6 +31,7 @@ from ..services.processor import DefaultInvocationProcessor
|
||||
from ..services.sqlite import SqliteItemStorage
|
||||
from ..services.model_manager_service import ModelManagerService
|
||||
from ..services.batch_manager import BatchManager
|
||||
from ..services.batch_manager_storage import SqliteBatchProcessStorage
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
@ -117,7 +118,8 @@ class ApiDependencies:
|
||||
)
|
||||
)
|
||||
|
||||
batch_manager = BatchManager()
|
||||
batch_manager_storage = SqliteBatchProcessStorage(db_location)
|
||||
batch_manager = BatchManager(batch_manager_storage)
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=ModelManagerService(config, logger),
|
||||
|
@ -51,8 +51,9 @@ async def create_batch(
|
||||
batches: list[Batch] = Body(description="Batch config to apply to the given graph"),
|
||||
) -> BatchProcess:
|
||||
"""Creates and starts a new new batch process"""
|
||||
session = ApiDependencies.invoker.services.batch_manager.run_batch_process(batches, graph)
|
||||
return session
|
||||
batch_id = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph)
|
||||
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_id)
|
||||
return {"batch_id":batch_id}
|
||||
|
||||
|
||||
@session_router.delete(
|
||||
|
@ -1,45 +1,22 @@
|
||||
import networkx as nx
|
||||
import uuid
|
||||
import copy
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from itertools import product
|
||||
from pydantic import BaseModel, Field
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event
|
||||
from typing import (
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
)
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.graph import Graph, GraphExecutionState
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
|
||||
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
|
||||
|
||||
|
||||
class Batch(BaseModel):
|
||||
data: list[InvocationsUnion] = Field(description="Mapping of ")
|
||||
node_id: str = Field(description="ID of the node to batch")
|
||||
|
||||
|
||||
class BatchProcess(BaseModel):
|
||||
batch_id: Optional[str] = Field(default_factory=uuid.uuid4().__str__, description="Identifier for this batch")
|
||||
sessions: list[str] = Field(
|
||||
description="Tracker for which batch is currently being processed", default_factory=list
|
||||
)
|
||||
batches: list[Batch] = Field(
|
||||
description="List of batch configs to apply to this session",
|
||||
default_factory=list,
|
||||
)
|
||||
batch_indices: list[int] = Field(
|
||||
description="Tracker for which batch is currently being processed", default_factory=list
|
||||
)
|
||||
graph: Graph = Field(description="The graph being executed")
|
||||
from invokeai.app.services.batch_manager_storage import (
|
||||
BatchProcessStorageBase,
|
||||
Batch,
|
||||
BatchProcess,
|
||||
BatchSession,
|
||||
BatchSessionChanges,
|
||||
)
|
||||
|
||||
|
||||
class BatchManagerBase(ABC):
|
||||
@ -48,7 +25,11 @@ class BatchManagerBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def run_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcess:
|
||||
def create_batch_process(self, batches: list[Batch], graph: Graph) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def run_batch_process(self, batch_id: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -61,8 +42,13 @@ class BatchManager(BatchManagerBase):
|
||||
|
||||
__invoker: Invoker
|
||||
__batches: list[BatchProcess]
|
||||
__batch_process_storage: BatchProcessStorageBase
|
||||
|
||||
def start(self, invoker) -> None:
|
||||
def __init__(self, batch_process_storage: BatchProcessStorageBase) -> None:
|
||||
super().__init__()
|
||||
self.__batch_process_storage = batch_process_storage
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
# if we do want multithreading at some point, we could make this configurable
|
||||
self.__invoker = invoker
|
||||
self.__batches = list()
|
||||
@ -73,34 +59,28 @@ class BatchManager(BatchManagerBase):
|
||||
|
||||
match event_name:
|
||||
case "graph_execution_state_complete":
|
||||
await self.process(event)
|
||||
await self.process(event, False)
|
||||
case "invocation_error":
|
||||
await self.process(event)
|
||||
await self.process(event, True)
|
||||
|
||||
return event
|
||||
|
||||
async def process(self, event: Event):
|
||||
async def process(self, event: Event, err: bool):
|
||||
data = event[1]["data"]
|
||||
batchTarget = None
|
||||
for batch in self.__batches:
|
||||
if data["graph_execution_state_id"] in batch.sessions:
|
||||
batchTarget = batch
|
||||
break
|
||||
|
||||
if batchTarget == None:
|
||||
batch_session = self.__batch_process_storage.get_session(data["graph_execution_state_id"])
|
||||
if not batch_session:
|
||||
return
|
||||
updateSession = BatchSessionChanges(
|
||||
state='error' if err else 'completed'
|
||||
)
|
||||
batch_session = self.__batch_process_storage.update_session_state(
|
||||
batch_session.batch_id,
|
||||
batch_session.session_id,
|
||||
updateSession,
|
||||
)
|
||||
self.run_batch_process(batch_session.batch_id)
|
||||
|
||||
if sum(batchTarget.batch_indices) == 0:
|
||||
self.__batches = [batch for batch in self.__batches if batch != batchTarget]
|
||||
return
|
||||
|
||||
batchTarget.batch_indices = self._next_batch_index(batchTarget)
|
||||
ges = self._next_batch_session(batchTarget)
|
||||
batchTarget.sessions.append(ges.id)
|
||||
self.__invoker.services.graph_execution_manager.set(ges)
|
||||
self.__invoker.invoke(ges, invoke_all=True)
|
||||
|
||||
def _next_batch_session(self, batch_process: BatchProcess) -> GraphExecutionState:
|
||||
def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState:
|
||||
graph = copy.deepcopy(batch_process.graph)
|
||||
batches = batch_process.batches
|
||||
g = graph.nx_graph_flat()
|
||||
@ -109,36 +89,47 @@ class BatchManager(BatchManagerBase):
|
||||
node = graph.get_node(npath)
|
||||
(index, batch) = next(((i, b) for i, b in enumerate(batches) if b.node_id in node.id), (None, None))
|
||||
if batch:
|
||||
batch_index = batch_process.batch_indices[index]
|
||||
batch_index = batch_indices[index]
|
||||
datum = batch.data[batch_index]
|
||||
datum.id = node.id
|
||||
graph.update_node(npath, datum)
|
||||
for key in datum:
|
||||
node.__dict__[key] = datum[key]
|
||||
graph.update_node(npath, node)
|
||||
|
||||
return GraphExecutionState(graph=graph)
|
||||
|
||||
def _next_batch_index(self, batch_process: BatchProcess):
|
||||
batch_indicies = batch_process.batch_indices.copy()
|
||||
for index in range(len(batch_indicies)):
|
||||
if batch_indicies[index] > 0:
|
||||
batch_indicies[index] -= 1
|
||||
break
|
||||
return batch_indicies
|
||||
def run_batch_process(self, batch_id: str):
|
||||
created_session = self.__batch_process_storage.get_created_session(batch_id)
|
||||
ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id)
|
||||
self.__invoker.invoke(ges, invoke_all=True)
|
||||
|
||||
def run_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcess:
|
||||
batch_indices = list()
|
||||
for batch in batches:
|
||||
batch_indices.append(len(batch.data) - 1)
|
||||
def _valid_batch_config(self, batch_process: BatchProcess) -> bool:
|
||||
return True
|
||||
|
||||
def create_batch_process(self, batches: list[Batch], graph: Graph) -> str:
|
||||
batch_process = BatchProcess(
|
||||
batches=batches,
|
||||
batch_indices=batch_indices,
|
||||
graph=graph,
|
||||
)
|
||||
ges = self._next_batch_session(batch_process)
|
||||
batch_process.sessions.append(ges.id)
|
||||
self.__batches.append(batch_process)
|
||||
if not self._valid_batch_config(batch_process):
|
||||
return None
|
||||
batch_process = self.__batch_process_storage.save(batch_process)
|
||||
self._create_sessions(batch_process)
|
||||
return batch_process.batch_id
|
||||
|
||||
def _create_sessions(self, batch_process: BatchProcess):
|
||||
batch_indices = list()
|
||||
for batch in batch_process.batches:
|
||||
batch_indices.append(list(range(len(batch.data))))
|
||||
all_batch_indices = product(*batch_indices)
|
||||
for bi in all_batch_indices:
|
||||
ges = self._create_batch_session(batch_process, bi)
|
||||
self.__invoker.services.graph_execution_manager.set(ges)
|
||||
self.__invoker.invoke(ges, invoke_all=True)
|
||||
return batch_process
|
||||
batch_session = BatchSession(
|
||||
batch_id=batch_process.batch_id,
|
||||
session_id=ges.id,
|
||||
state="created"
|
||||
)
|
||||
self.__batch_process_storage.create_session(batch_session)
|
||||
|
||||
def cancel_batch_process(self, batch_process_id: str):
|
||||
self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]
|
||||
|
500
invokeai/app/services/batch_manager_storage.py
Normal file
500
invokeai/app/services/batch_manager_storage.py
Normal file
@ -0,0 +1,500 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import cast
|
||||
import uuid
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import (
|
||||
Any,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
import json
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
)
|
||||
from invokeai.app.services.graph import Graph, GraphExecutionState
|
||||
from invokeai.app.models.image import ImageField
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
|
||||
from pydantic import BaseModel, Field, Extra, parse_raw_as
|
||||
|
||||
# Build a Union over every registered invocation class so batch configs can
# target any node type. get_invocations() returns project-declared classes;
# the `type: ignore` silences the dynamically-constructed Union.
invocations = BaseInvocation.get_invocations()
InvocationsUnion = Union[invocations]  # type: ignore

# Scalar/image value types that a batch is allowed to substitute into a node field.
BatchDataType = Union[str, int, float, ImageField]
|
||||
|
||||
class Batch(BaseModel):
    # One entry per run: maps a field name on the target node to the value it
    # should take for that run of the batch.
    data: list[dict[str, BatchDataType]] = Field(description="Mapping of node field to data value")
    # The id of the graph node whose fields get substituted.
    node_id: str = Field(description="ID of the node to batch")
|
||||
|
||||
|
||||
class BatchSession(BaseModel):
    # FK to the owning BatchProcess row.
    batch_id: str = Field(description="Identifier for which batch this Index belongs to")
    # Graph execution session created for one combination of batch indices.
    session_id: str = Field(description="Session ID Created for this Batch Index")
    # Lifecycle state; persisted as TEXT in the batch_session table.
    state: Literal["created", "completed", "inprogress", "error"] = Field(
        description="Is this session created, completed, in progress, or errored?"
    )
|
||||
|
||||
|
||||
class BatchProcess(BaseModel):
    """A batch process: the graph being executed plus the batch configs applied to it."""

    # FIX: the original used `default_factory=uuid.uuid4().__str__`, which
    # evaluates uuid.uuid4() ONCE at class-definition time and then reuses that
    # single UUID's bound __str__ as the factory — so every BatchProcess created
    # without an explicit batch_id received the SAME id. The lambda defers
    # UUID generation to each instantiation.
    batch_id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4()), description="Identifier for this batch")
    batches: List[Batch] = Field(
        description="List of batch configs to apply to this session",
        default_factory=list,
    )
    graph: Graph = Field(description="The graph being executed")
|
||||
|
||||
|
||||
# extra=Extra.forbid rejects unknown keys, so callers can only patch `state`.
class BatchSessionChanges(BaseModel, extra=Extra.forbid):
    # New lifecycle state to write onto an existing batch_session row.
    state: Literal["created", "completed", "inprogress", "error"] = Field(
        description="Is this session created, completed, in progress, or errored?"
    )
|
||||
|
||||
|
||||
class BatchProcessNotFoundException(Exception):
    """Signals that no BatchProcess record exists for the requested id."""

    def __init__(self, message: str = "BatchProcess record not found"):
        super().__init__(message)
|
||||
|
||||
|
||||
class BatchProcessSaveException(Exception):
    """Signals that a BatchProcess record could not be written to storage."""

    def __init__(self, message: str = "BatchProcess record not saved"):
        super().__init__(message)
|
||||
|
||||
|
||||
class BatchProcessDeleteException(Exception):
    """Signals that a BatchProcess record could not be removed from storage."""

    def __init__(self, message: str = "BatchProcess record not deleted"):
        super().__init__(message)
|
||||
|
||||
|
||||
class BatchSessionNotFoundException(Exception):
    """Signals that no BatchSession record exists for the requested id."""

    def __init__(self, message: str = "BatchSession record not found"):
        super().__init__(message)
|
||||
|
||||
|
||||
class BatchSessionSaveException(Exception):
    """Signals that a BatchSession record could not be written to storage."""

    def __init__(self, message: str = "BatchSession record not saved"):
        super().__init__(message)
|
||||
|
||||
|
||||
class BatchSessionDeleteException(Exception):
    """Signals that a BatchSession record could not be removed from storage."""

    def __init__(self, message: str = "BatchSession record not deleted"):
        super().__init__(message)
|
||||
|
||||
|
||||
class BatchProcessStorageBase(ABC):
    """Low-level service responsible for interfacing with the Batch Process record store."""

    @abstractmethod
    def delete(self, batch_id: str) -> None:
        """Deletes a Batch Process record."""
        pass

    @abstractmethod
    def save(
        self,
        batch_process: BatchProcess,
    ) -> BatchProcess:
        """Saves a Batch Process record and returns the stored record."""
        pass

    @abstractmethod
    def get(
        self,
        batch_id: str,
    ) -> BatchProcess:
        """Gets a Batch Process record by its batch_id."""
        pass

    @abstractmethod
    def create_session(
        self,
        session: BatchSession,
    ) -> BatchSession:
        """Creates a Batch Session attached to a Batch Process."""
        pass

    @abstractmethod
    def get_session(
        self,
        session_id: str
    ) -> BatchSession:
        """Gets a Batch Session by its session_id."""
        pass

    @abstractmethod
    def get_created_session(
        self,
        batch_id: str
    ) -> BatchSession:
        """Gets a single Batch Session in the 'created' state for a given Batch Process id."""
        pass

    @abstractmethod
    def get_created_sessions(
        self,
        batch_id: str
    ) -> List[BatchSession]:
        """Gets all Batch Sessions in the 'created' state for a given Batch Process id."""
        pass

    @abstractmethod
    def update_session_state(
        self,
        batch_id: str,
        session_id: str,
        changes: BatchSessionChanges,
    ) -> BatchSession:
        """Updates the state of a Batch Session record."""
        pass
|
||||
|
||||
|
||||
class SqliteBatchProcessStorage(BatchProcessStorageBase):
    """SQLite-backed implementation of the Batch Process record store.

    A single connection/cursor pair is shared across threads
    (``check_same_thread=False``), so every public operation serializes
    database access through ``self._lock``.
    """

    _filename: str
    _conn: sqlite3.Connection
    _cursor: sqlite3.Cursor
    _lock: threading.Lock

    def __init__(self, filename: str) -> None:
        """Opens (or creates) the database at `filename` and ensures the tables exist."""
        super().__init__()
        self._filename = filename
        self._conn = sqlite3.connect(filename, check_same_thread=False)
        # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
        self._conn.row_factory = sqlite3.Row
        self._cursor = self._conn.cursor()
        self._lock = threading.Lock()

        try:
            self._lock.acquire()
            # Foreign keys are OFF by default in SQLite; required for the
            # batch_session -> batch_process ON DELETE CASCADE to work.
            self._conn.execute("PRAGMA foreign_keys = ON;")
            self._create_tables()
            self._conn.commit()
        finally:
            self._lock.release()

    def _create_tables(self) -> None:
        """Creates the `batch_process` table and `batch_session` junction table."""

        # Create the `batch_process` table.
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS batch_process (
                batch_id TEXT NOT NULL PRIMARY KEY,
                batches TEXT NOT NULL,
                graph TEXT NOT NULL,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Soft delete, currently unused
                deleted_at DATETIME
            );
            """
        )

        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_batch_process_created_at ON batch_process (created_at);
            """
        )

        # Add trigger for `updated_at`.
        # FIX: use the same millisecond-precision STRFTIME format as the column
        # defaults (the original used `current_timestamp`, which only has
        # second precision and a different format).
        self._cursor.execute(
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_batch_process_updated_at
            AFTER UPDATE
            ON batch_process FOR EACH ROW
            BEGIN
                UPDATE batch_process SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE batch_id = old.batch_id;
            END;
            """
        )

        # Create the `batch_session` junction table.
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS batch_session (
                batch_id TEXT NOT NULL,
                session_id TEXT NOT NULL,
                state TEXT NOT NULL DEFAULT('created'),
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Soft delete, currently unused
                deleted_at DATETIME,
                -- enforce one-to-many relationship between batch_process and batch_session using PK
                -- (we can extend this to many-to-many later)
                PRIMARY KEY (batch_id,session_id),
                FOREIGN KEY (batch_id) REFERENCES batch_process (batch_id) ON DELETE CASCADE
            );
            """
        )

        # Add index for batch id
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id ON batch_session (batch_id);
            """
        )

        # Add index for batch id, sorted by created_at
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id_created_at ON batch_session (batch_id,created_at);
            """
        )

        # Add trigger for `updated_at`.
        # FIX: the original WHERE clause referenced `image_name`, a column that
        # does not exist on batch_session (copy-paste from the images table),
        # which made every UPDATE on batch_session fail when the trigger fired.
        self._cursor.execute(
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_batch_session_updated_at
            AFTER UPDATE
            ON batch_session FOR EACH ROW
            BEGIN
                UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE batch_id = old.batch_id AND session_id = old.session_id;
            END;
            """
        )

    def delete(self, batch_id: str) -> None:
        """Deletes a Batch Process record; its sessions cascade via the FK.

        Raises BatchProcessDeleteException on any failure.
        """
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                DELETE FROM batch_process
                WHERE batch_id = ?;
                """,
                (batch_id,),
            )
            self._conn.commit()
        except Exception as e:
            # sqlite3.Error subclasses Exception, so one handler replaces the
            # original's two identical except blocks.
            self._conn.rollback()
            raise BatchProcessDeleteException from e
        finally:
            self._lock.release()

    def save(
        self,
        batch_process: BatchProcess,
    ) -> BatchProcess:
        """Inserts a Batch Process record (no-op if the batch_id already exists).

        Returns the record as read back from the database.
        Raises BatchProcessSaveException on any database error.
        """
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                INSERT OR IGNORE INTO batch_process (batch_id, batches, graph)
                VALUES (?, ?, ?);
                """,
                (batch_process.batch_id, json.dumps([batch.json() for batch in batch_process.batches]), batch_process.graph.json()),
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchProcessSaveException from e
        finally:
            self._lock.release()
        return self.get(batch_process.batch_id)

    def _deserialize_batch_process(self, session_dict: dict) -> BatchProcess:
        """Deserializes a batch_process row (as a dict) into a BatchProcess."""

        # Retrieve all the values, setting "reasonable" defaults if they are not present.
        batch_id = session_dict.get("batch_id", "unknown")
        batches_raw = session_dict.get("batches", "unknown")
        graph_raw = session_dict.get("graph", "unknown")

        # `batches` is stored as a JSON array of JSON strings (see save()),
        # so each element is itself parsed with parse_raw_as.
        return BatchProcess(
            batch_id=batch_id,
            batches=[parse_raw_as(Batch, batch) for batch in json.loads(batches_raw)],
            graph=parse_raw_as(Graph, graph_raw),
        )

    def get(
        self,
        batch_id: str,
    ) -> BatchProcess:
        """Gets a Batch Process record by batch_id.

        Raises BatchProcessNotFoundException if no row matches or the query fails.
        """
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT *
                FROM batch_process
                WHERE batch_id = ?;
                """,
                (batch_id,),
            )

            result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchProcessNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            raise BatchProcessNotFoundException
        return self._deserialize_batch_process(dict(result))

    def create_session(
        self,
        session: BatchSession,
    ) -> BatchSession:
        """Inserts a Batch Session row attached to a Batch Process.

        Returns the session as read back from the database.
        Raises BatchSessionSaveException on any database error.
        """
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                INSERT OR IGNORE INTO batch_session (batch_id, session_id, state)
                VALUES (?, ?, ?);
                """,
                (session.batch_id, session.session_id, session.state),
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionSaveException from e
        finally:
            self._lock.release()
        return self.get_session(session.session_id)

    def get_session(
        self,
        session_id: str
    ) -> BatchSession:
        """Gets a Batch Session by session_id.

        Raises BatchSessionNotFoundException if no row matches or the query fails.
        """
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT *
                FROM batch_session
                WHERE session_id= ?;
                """,
                (session_id,),
            )

            result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            raise BatchSessionNotFoundException
        return BatchSession(**dict(result))

    def _deserialize_batch_session(self, session_dict: dict) -> BatchSession:
        """Deserializes a batch_session row (as a dict) into a BatchSession."""

        # Retrieve all the values, setting "reasonable" defaults if they are not present.
        batch_id = session_dict.get("batch_id", "unknown")
        session_id = session_dict.get("session_id", "unknown")
        state = session_dict.get("state", "unknown")

        return BatchSession(
            batch_id=batch_id,
            session_id=session_id,
            state=state,
        )

    def get_created_session(
        self,
        batch_id: str
    ) -> BatchSession:
        """Gets one Batch Session in the 'created' state for the given batch.

        Raises BatchSessionNotFoundException if none exists or the query fails.
        """
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT *
                FROM batch_session
                WHERE batch_id = ? AND state = 'created';
                """,
                (batch_id,),
            )

            # FIX: fetchone() returns a single row or None; the original cast
            # to list[sqlite3.Row] was simply wrong.
            result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            raise BatchSessionNotFoundException
        session = self._deserialize_batch_session(dict(result))
        return session

    def get_created_sessions(
        self,
        batch_id: str
    ) -> List[BatchSession]:
        """Gets all Batch Sessions in the 'created' state for the given batch.

        Raises BatchSessionNotFoundException on a database error.
        """
        try:
            self._lock.acquire()
            # FIX: the original compared `state = created` without quotes, which
            # SQLite parses as a (nonexistent) column reference and errors out.
            self._cursor.execute(
                """--sql
                SELECT *
                FROM batch_session
                WHERE batch_id = ? AND state = 'created';
                """,
                (batch_id,),
            )

            result = cast(list[sqlite3.Row], self._cursor.fetchall())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionNotFoundException from e
        finally:
            self._lock.release()
        # Defensive: fetchall() returns a (possibly empty) list, never None.
        if result is None:
            raise BatchSessionNotFoundException
        sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
        return sessions

    def update_session_state(
        self,
        batch_id: str,
        session_id: str,
        changes: BatchSessionChanges,
    ) -> BatchSession:
        """Updates the state of a Batch Session record.

        Returns the freshly-read session.
        Raises BatchSessionSaveException on any database error.
        """
        try:
            self._lock.acquire()

            # Change the state of a batch session
            if changes.state is not None:
                # Plain string: the original used an f-string with no placeholders.
                self._cursor.execute(
                    """--sql
                    UPDATE batch_session
                    SET state = ?
                    WHERE batch_id = ? AND session_id = ?;
                    """,
                    (changes.state, batch_id, session_id),
                )

            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionSaveException from e
        finally:
            self._lock.release()
        return self.get_session(session_id)
|
@ -9,6 +9,7 @@ from .item_storage import ItemStorageABC, PaginatedResults
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
sqlite_memory = ":memory:"
|
||||
import traceback
|
||||
|
||||
|
||||
class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
@ -21,7 +22,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
|
||||
def __init__(self, filename: str, table_name: str, id_field: str = "id"):
|
||||
super().__init__()
|
||||
|
||||
self._filename = filename
|
||||
self._table_name = table_name
|
||||
self._id_field = id_field # TODO: validate that T has this field
|
||||
@ -29,6 +29,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
self._conn = sqlite3.connect(
|
||||
self._filename, check_same_thread=False
|
||||
) # TODO: figure out a better threading solution
|
||||
self._conn.set_trace_callback(print)
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
self._create_table()
|
||||
@ -54,11 +55,21 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
def set(self, item: T):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
json = item.json()
|
||||
print('-----------------locking db-----------------')
|
||||
traceback.print_stack(limit=2)
|
||||
|
||||
self._cursor.execute(
|
||||
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
|
||||
(item.json(),),
|
||||
(json,),
|
||||
)
|
||||
self._conn.commit()
|
||||
self._cursor.close()
|
||||
self._cursor = self._conn.cursor()
|
||||
print('-----------------unlocking db-----------------')
|
||||
except Exception as e:
|
||||
print("Exception!")
|
||||
print(e)
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._on_changed(item)
|
||||
@ -66,8 +77,12 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
def get(self, id: str) -> Optional[T]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
print('-----------------locking db-----------------')
|
||||
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
||||
result = self._cursor.fetchone()
|
||||
self._cursor.close()
|
||||
self._cursor = self._conn.cursor()
|
||||
print('-----------------unlocking db-----------------')
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user