feat(batches): defer ges creation until execution

This improves the overall responsiveness of the system substantially, but does make each iteration *slightly* slower, distributing the up-front cost across the batch.

Two main changes:

1. Create BatchSessions immediately, but do not create a whole graph execution state until the batch is executed.
BatchSessions are created with a `session_id` that does not yet exist in the sessions database.
The default state is changed to `"uninitialized"` to better represent this.

Results: Time to create 5000 batches reduced from over 30s to 2.5s

2. Use `executemany()` to insert sessions in bulk, then retrieve the list of created sessions with a single query.
Results: time to create 5000 batches reduced from 2.5s to under 0.5s

Other changes:

- set BatchSession state to `"in_progress"` just before `invoke()` is called
- rename a few methods to accommodate the new behaviour
- remove unused `BatchProcessStorage.get_created_sessions()` method
This commit is contained in:
psychedelicious 2023-08-21 21:44:33 +10:00
parent 50816432dc
commit 88ae19a768
2 changed files with 98 additions and 47 deletions

View File

@ -94,7 +94,7 @@ class BatchManager(BatchManagerBase):
async def _process(self, event: Event, err: bool) -> None:
data = event[1]["data"]
try:
batch_session = self.__batch_process_storage.get_session(data["graph_execution_state_id"])
batch_session = self.__batch_process_storage.get_session_by_session_id(data["graph_execution_state_id"])
except BatchSessionNotFoundException:
return None
changes = BatchSessionChanges(state="error" if err else "completed")
@ -130,10 +130,18 @@ class BatchManager(BatchManagerBase):
def run_batch_process(self, batch_id: str) -> None:
self.__batch_process_storage.start(batch_id)
try:
created_session = self.__batch_process_storage.get_created_session(batch_id)
next_session = self.__batch_process_storage.get_next_session(batch_id)
except BatchSessionNotFoundException:
return
ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id)
batch_process = self.__batch_process_storage.get(batch_id)
ges = self._create_batch_session(batch_process=batch_process, batch_indices=tuple(next_session.batch_index))
ges.id = next_session.session_id
self.__invoker.services.graph_execution_manager.set(ges)
self.__batch_process_storage.update_session_state(
batch_id=next_session.batch_id,
session_id=next_session.session_id,
changes=BatchSessionChanges(state="in_progress"),
)
self.__invoker.invoke(ges, invoke_all=True)
def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
@ -150,25 +158,20 @@ class BatchManager(BatchManagerBase):
def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]:
batch_indices = list()
sessions = list()
sessions_to_create: list[BatchSession] = list()
for batchdata in batch_process.batch.data:
batch_indices.append(list(range(len(batchdata[0].items))))
all_batch_indices = product(*batch_indices)
for bi in all_batch_indices:
for _ in range(batch_process.batch.runs):
ges = self._create_batch_session(batch_process, bi)
self.__invoker.services.graph_execution_manager.set(ges)
batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created")
sessions.append(self.__batch_process_storage.create_session(batch_session))
if not sessions:
ges = GraphExecutionState(graph=batch_process.graph)
self.__invoker.services.graph_execution_manager.set(ges)
batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created")
sessions.append(self.__batch_process_storage.create_session(batch_session))
return sessions
sessions_to_create.append(BatchSession(batch_id=batch_process.batch_id, batch_index=list(bi)))
if not sessions_to_create:
sessions_to_create.append(BatchSession(batch_id=batch_process.batch_id, batch_index=list(bi)))
created_sessions = self.__batch_process_storage.create_sessions(sessions_to_create)
return created_sessions
def get_sessions(self, batch_id: str) -> list[BatchSession]:
return self.__batch_process_storage.get_sessions(batch_id)
return self.__batch_process_storage.get_sessions_by_batch_id(batch_id)
def get_batch(self, batch_id: str) -> BatchProcess:
return self.__batch_process_storage.get(batch_id)
@ -177,7 +180,7 @@ class BatchManager(BatchManagerBase):
bps = self.__batch_process_storage.get_all()
res = list()
for bp in bps:
sessions = self.__batch_process_storage.get_sessions(bp.batch_id)
sessions = self.__batch_process_storage.get_sessions_by_batch_id(bp.batch_id)
res.append(
BatchProcessResponse(
batch_id=bp.batch_id,
@ -190,7 +193,7 @@ class BatchManager(BatchManagerBase):
bps = self.__batch_process_storage.get_incomplete()
res = list()
for bp in bps:
sessions = self.__batch_process_storage.get_sessions(bp.batch_id)
sessions = self.__batch_process_storage.get_sessions_by_batch_id(bp.batch_id)
res.append(
BatchProcessResponse(
batch_id=bp.batch_id,

View File

@ -1,3 +1,4 @@
import json
import sqlite3
import threading
import uuid
@ -71,17 +72,29 @@ class Batch(BaseModel):
return v
class BatchSession(BaseModel):
batch_id: str = Field(description="The Batch to which this BatchSession is attached.")
session_id: str = Field(description="The Session to which this BatchSession is attached.")
state: Literal["created", "completed", "inprogress", "error"] = Field(description="The state of this BatchSession")
def uuid_string():
    """Return a freshly generated random (version 4) UUID as a string."""
    return str(uuid.uuid4())
BATCH_SESSION_STATE = Literal["uninitialized", "in_progress", "completed", "error"]
class BatchSession(BaseModel):
    """One planned run of a batch, identified by its own `session_id`.

    A `session_id` is generated with `uuid_string` when none is supplied, and
    the state defaults to `"uninitialized"`.
    """

    # Fixed: the keyword was misspelled `defaultdescription`, which dropped the
    # field's description.
    batch_id: str = Field(description="The Batch to which this BatchSession is attached.")
    session_id: str = Field(
        default_factory=uuid_string, description="The Session to which this BatchSession is attached."
    )
    batch_index: list[int] = Field(description="The index of the batch to be run in this BatchSession.")
    state: BATCH_SESSION_STATE = Field(default="uninitialized", description="The state of this BatchSession")

    @validator("batch_index", pre=True)
    def parse_str_to_list(cls, v: Union[str, list[int]]):
        """Decode `batch_index` from a JSON string; pass any other value through unchanged."""
        if isinstance(v, str):
            return json.loads(v)
        return v
class BatchProcess(BaseModel):
batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch.")
batch: Batch = Field(description="The Batch to apply to this session.")
@ -90,7 +103,7 @@ class BatchProcess(BaseModel):
class BatchSessionChanges(BaseModel, extra=Extra.forbid):
state: Literal["created", "completed", "inprogress", "error"] = Field(description="The state of this BatchSession")
state: BATCH_SESSION_STATE = Field(description="The state of this BatchSession")
class BatchProcessNotFoundException(Exception):
@ -198,23 +211,31 @@ class BatchProcessStorageBase(ABC):
pass
@abstractmethod
def get_session(self, session_id: str) -> BatchSession:
def create_sessions(
self,
sessions: list[BatchSession],
) -> list[BatchSession]:
"""Creates many BatchSessions attached to a BatchProcess."""
pass
@abstractmethod
def get_session_by_session_id(self, session_id: str) -> BatchSession:
"""Gets a BatchSession by session_id"""
pass
@abstractmethod
def get_sessions(self, batch_id: str) -> List[BatchSession]:
def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]:
"""Gets all BatchSession's for a given BatchProcess id."""
pass
@abstractmethod
def get_created_session(self, batch_id: str) -> BatchSession:
"""Gets the latest BatchSession with state `created`, for a given BatchProcess id."""
def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]:
"""Gets all BatchSession's for a given list of session ids."""
pass
@abstractmethod
def get_created_sessions(self, batch_id: str) -> List[BatchSession]:
"""Gets all BatchSession's with state `created`, for a given BatchProcess id."""
def get_next_session(self, batch_id: str) -> BatchSession:
"""Gets the next BatchSession with state `uninitialized`, for a given BatchProcess id."""
pass
@abstractmethod
@ -295,6 +316,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
CREATE TABLE IF NOT EXISTS batch_session (
batch_id TEXT NOT NULL,
session_id TEXT NOT NULL,
batch_index TEXT NOT NULL,
state TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
@ -453,7 +475,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
(
SELECT batch_id
FROM batch_session bs
WHERE state = 'created'
WHERE state IN ('uninitialized', 'in_progress')
);
"""
)
@ -518,10 +540,10 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state)
VALUES (?, ?, ?);
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
VALUES (?, ?, ?, ?);
""",
(session.batch_id, session.session_id, session.state),
(session.batch_id, session.session_id, session.state, json.dumps(session.batch_index)),
)
self._conn.commit()
except sqlite3.Error as e:
@ -529,9 +551,34 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
raise BatchSessionSaveException from e
finally:
self._lock.release()
return self.get_session(session.session_id)
return self.get_session_by_session_id(session.session_id)
def get_session(self, session_id: str) -> BatchSession:
def create_sessions(
    self,
    sessions: list[BatchSession],
) -> list[BatchSession]:
    """Insert many BatchSessions with a single `executemany()` and return the stored rows.

    Each session's `batch_index` is serialized to JSON text for storage.
    Rows are written with `INSERT OR IGNORE`.

    Args:
        sessions: The BatchSessions to persist.

    Returns:
        The result of `get_sessions_by_session_ids()` for the given sessions' ids.

    Raises:
        BatchSessionSaveException: If SQLite reports an error; the transaction
            is rolled back first.
    """
    rows = [
        (session.batch_id, session.session_id, session.state, json.dumps(session.batch_index))
        for session in sessions
    ]
    # The context manager releases the lock only if it was actually acquired.
    with self._lock:
        try:
            self._cursor.executemany(
                """--sql
                INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
                VALUES (?, ?, ?, ?);
                """,
                rows,
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionSaveException from e
    # Called after the lock is released: the lookup acquires `self._lock` itself.
    return self.get_sessions_by_session_ids([session.session_id for session in sessions])
def get_session_by_session_id(self, session_id: str) -> BatchSession:
try:
self._lock.acquire()
self._cursor.execute(
@ -558,14 +605,14 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
return BatchSession.parse_obj(session_dict)
def get_created_session(self, batch_id: str) -> BatchSession:
def get_next_session(self, batch_id: str) -> BatchSession:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM batch_session
WHERE batch_id = ? AND state = 'created';
WHERE batch_id = ? AND state = 'uninitialized';
""",
(batch_id,),
)
@ -581,14 +628,14 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
session = self._deserialize_batch_session(dict(result))
return session
def get_created_sessions(self, batch_id: str) -> List[BatchSession]:
def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM batch_session
WHERE batch_id = ? AND state = created;
WHERE batch_id = ?;
""",
(batch_id,),
)
@ -604,16 +651,17 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
return sessions
def get_sessions(self, batch_id: str) -> List[BatchSession]:
def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]:
try:
self._lock.acquire()
placeholders = ",".join("?" * len(session_ids))
self._cursor.execute(
"""--sql
SELECT *
FROM batch_session
WHERE batch_id = ?;
f"""--sql
SELECT * FROM batch_session
WHERE session_id
IN ({placeholders})
""",
(batch_id,),
tuple(session_ids),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
@ -652,4 +700,4 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
raise BatchSessionSaveException from e
finally:
self._lock.release()
return self.get_session(session_id)
return self.get_session_by_session_id(session_id)