feat(batches): defer ges creation until execution

This improves the overall responsiveness of the system substantially, but does make each iteration *slightly* slower, distributing the up-front cost across the batch.

Two main changes:

1. Create BatchSessions immediately, but do not create a whole graph execution state until the batch is executed.
BatchSessions are created with a `session_id` that does not yet exist in the sessions database.
The default state is changed to `"uninitialized"` to better represent this.

Results: Time to create 5000 batches reduced from over 30s to 2.5s

2. Use `executemany()` to insert sessions in bulk, then retrieve the list of created sessions with a single query.
Results: time to create 5000 batches reduced from 2.5s to under 0.5s

Other changes:

- set BatchSession state to `"in_progress"` just before `invoke()` is called
- rename a few methods to accommodate the new behaviour
- remove unused `BatchProcessStorage.get_created_sessions()` method
This commit is contained in:
psychedelicious
2023-08-21 21:44:33 +10:00
parent 50816432dc
commit 88ae19a768
2 changed files with 98 additions and 47 deletions

View File

@ -1,3 +1,4 @@
import json
import sqlite3
import threading
import uuid
@ -71,17 +72,29 @@ class Batch(BaseModel):
return v
class BatchSession(BaseModel):
batch_id: str = Field(description="The Batch to which this BatchSession is attached.")
session_id: str = Field(description="The Session to which this BatchSession is attached.")
state: Literal["created", "completed", "inprogress", "error"] = Field(description="The state of this BatchSession")
def uuid_string():
    """Return a newly generated random UUID (``uuid.uuid4``) in its string form."""
    return str(uuid.uuid4())
# The set of states a BatchSession can be in; used as the type of both
# BatchSession.state and BatchSessionChanges.state.
BATCH_SESSION_STATE = Literal["uninitialized", "in_progress", "completed", "error"]
class BatchSession(BaseModel):
    """A single session belonging to a batch, together with its lifecycle state.

    ``batch_id`` and ``batch_index`` are required. ``session_id`` defaults to a
    fresh UUID string and ``state`` defaults to ``"uninitialized"``.
    """

    # Required: no default is supplied, so callers must pass the owning batch's id.
    batch_id: str = Field(description="The Batch to which this BatchSession is attached.")
    # Generated per instance via uuid_string() when not provided.
    session_id: str = Field(
        default_factory=uuid_string, description="The Session to which this BatchSession is attached."
    )
    batch_index: list[int] = Field(description="The index of the batch to be run in this BatchSession.")
    state: BATCH_SESSION_STATE = Field(default="uninitialized", description="The state of this BatchSession")

    @validator("batch_index", pre=True)
    def parse_str_to_list(cls, v: Union[str, list[int]]):
        """Decode ``batch_index`` when it is supplied as a JSON string.

        The SQLite storage writes this field with ``json.dumps`` into a TEXT
        column, so a value coming back from storage may be JSON text rather
        than a list. Non-string values are passed through unchanged.
        """
        if isinstance(v, str):
            return json.loads(v)
        return v
class BatchProcess(BaseModel):
batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch.")
batch: Batch = Field(description="The Batch to apply to this session.")
@ -90,7 +103,7 @@ class BatchProcess(BaseModel):
class BatchSessionChanges(BaseModel, extra=Extra.forbid):
state: Literal["created", "completed", "inprogress", "error"] = Field(description="The state of this BatchSession")
state: BATCH_SESSION_STATE = Field(description="The state of this BatchSession")
class BatchProcessNotFoundException(Exception):
@ -198,23 +211,31 @@ class BatchProcessStorageBase(ABC):
pass
@abstractmethod
def get_session(self, session_id: str) -> BatchSession:
def create_sessions(
    self,
    sessions: list[BatchSession],
) -> list[BatchSession]:
    """Create several BatchSessions attached to a BatchProcess in one call.

    Implementations take the BatchSessions to store and return a list of BatchSessions.
    """
@abstractmethod
def get_session_by_session_id(self, session_id: str) -> BatchSession:
    """Look up and return the BatchSession that has the given ``session_id``."""
@abstractmethod
def get_sessions(self, batch_id: str) -> List[BatchSession]:
def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]:
    """Return every BatchSession that belongs to the BatchProcess with id ``batch_id``."""
@abstractmethod
def get_created_session(self, batch_id: str) -> BatchSession:
"""Gets the latest BatchSession with state `created`, for a given BatchProcess id."""
def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]:
    """Return the BatchSessions whose ``session_id`` appears in ``session_ids``."""
@abstractmethod
def get_created_sessions(self, batch_id: str) -> List[BatchSession]:
"""Gets all BatchSession's with state `created`, for a given BatchProcess id."""
def get_next_session(self, batch_id: str) -> BatchSession:
    """Return a BatchSession of the given BatchProcess that is still in state `uninitialized`."""
@abstractmethod
@ -295,6 +316,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
CREATE TABLE IF NOT EXISTS batch_session (
batch_id TEXT NOT NULL,
session_id TEXT NOT NULL,
batch_index TEXT NOT NULL,
state TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
@ -453,7 +475,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
(
SELECT batch_id
FROM batch_session bs
WHERE state = 'created'
WHERE state IN ('uninitialized', 'in_progress')
);
"""
)
@ -518,10 +540,10 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state)
VALUES (?, ?, ?);
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
VALUES (?, ?, ?, ?);
""",
(session.batch_id, session.session_id, session.state),
(session.batch_id, session.session_id, session.state, json.dumps(session.batch_index)),
)
self._conn.commit()
except sqlite3.Error as e:
@ -529,9 +551,34 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
raise BatchSessionSaveException from e
finally:
self._lock.release()
return self.get_session(session.session_id)
return self.get_session_by_session_id(session.session_id)
def get_session(self, session_id: str) -> BatchSession:
def create_sessions(
    self,
    sessions: list[BatchSession],
) -> list[BatchSession]:
    """Insert many BatchSessions with a single ``executemany`` and return the stored rows.

    Rows are written with ``INSERT OR IGNORE``, so a row that conflicts with an
    existing one is skipped rather than raising.

    Args:
        sessions: The BatchSessions to persist.

    Returns:
        Whatever ``get_sessions_by_session_ids`` yields for the ids of ``sessions``.

    Raises:
        BatchSessionSaveException: If SQLite reports an error. The transaction
            is rolled back before the exception is raised.
    """
    # Serialise before taking the lock; batch_index is stored as JSON text.
    session_data = [
        (session.batch_id, session.session_id, session.state, json.dumps(session.batch_index))
        for session in sessions
    ]
    # Using the lock as a context manager guarantees it is released only if it
    # was actually acquired (acquiring inside a try/finally does not).
    # NOTE(review): assumes self._lock is a threading lock — confirm.
    with self._lock:
        try:
            self._cursor.executemany(
                """--sql
                INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
                VALUES (?, ?, ?, ?);
                """,
                session_data,
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionSaveException from e
    # Read back outside the lock: the getter does its own locking.
    return self.get_sessions_by_session_ids([session.session_id for session in sessions])
def get_session_by_session_id(self, session_id: str) -> BatchSession:
try:
self._lock.acquire()
self._cursor.execute(
@ -558,14 +605,14 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
return BatchSession.parse_obj(session_dict)
def get_created_session(self, batch_id: str) -> BatchSession:
def get_next_session(self, batch_id: str) -> BatchSession:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM batch_session
WHERE batch_id = ? AND state = 'created';
WHERE batch_id = ? AND state = 'uninitialized';
""",
(batch_id,),
)
@ -581,14 +628,14 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
session = self._deserialize_batch_session(dict(result))
return session
def get_created_sessions(self, batch_id: str) -> List[BatchSession]:
def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM batch_session
WHERE batch_id = ? AND state = created;
WHERE batch_id = ?;
""",
(batch_id,),
)
@ -604,16 +651,17 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
return sessions
def get_sessions(self, batch_id: str) -> List[BatchSession]:
def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]:
try:
self._lock.acquire()
placeholders = ",".join("?" * len(session_ids))
self._cursor.execute(
"""--sql
SELECT *
FROM batch_session
WHERE batch_id = ?;
f"""--sql
SELECT * FROM batch_session
WHERE session_id
IN ({placeholders})
""",
(batch_id,),
tuple(session_ids),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
@ -652,4 +700,4 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
raise BatchSessionSaveException from e
finally:
self._lock.release()
return self.get_session(session_id)
return self.get_session_by_session_id(session_id)