Add runs field for running the same batch multiple times

This commit is contained in:
Brandon Rising 2023-08-18 13:41:07 -04:00
parent 99e03fe92e
commit 0282f46c71
2 changed files with 17 additions and 9 deletions

View File

@ -133,15 +133,16 @@ class BatchManager(BatchManagerBase):
batch_indices.append(list(range(len(batchdata[0].items))))
all_batch_indices = product(*batch_indices)
for bi in all_batch_indices:
ges = self._create_batch_session(batch_process, bi)
self.__invoker.services.graph_execution_manager.set(ges)
batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created")
sessions.append(self.__batch_process_storage.create_session(batch_session))
if not sessions:
ges = GraphExecutionState(graph=batch_process.graph)
self.__invoker.services.graph_execution_manager.set(ges)
batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created")
sessions.append(self.__batch_process_storage.create_session(batch_session))
for _ in range(batch_process.batch.runs):
ges = self._create_batch_session(batch_process, bi)
self.__invoker.services.graph_execution_manager.set(ges)
batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created")
sessions.append(self.__batch_process_storage.create_session(batch_session))
if not sessions:
ges = GraphExecutionState(graph=batch_process.graph)
self.__invoker.services.graph_execution_manager.set(ges)
batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created")
sessions.append(self.__batch_process_storage.create_session(batch_session))
return sessions
def cancel_batch_process(self, batch_process_id: str) -> None:

View File

@ -35,6 +35,13 @@ class Batch(BaseModel):
"""
data: list[list[BatchData]] = Field(default_factory=list, description="The list of batch data collections.")
runs: int = Field(default=1, description="Int stating how many times to iterate through all possible batch indices")
@validator("runs")
def validate_positive_runs(cls, r: int):
if r < 1:
raise ValueError("runs must be a positive integer")
return r
@validator("data")
def validate_len(cls, v: list[list[BatchData]]):