Add runs field for running the same batch multiple times

This commit is contained in:
Brandon Rising 2023-08-18 13:41:07 -04:00
parent 99e03fe92e
commit 0282f46c71
2 changed files with 17 additions and 9 deletions

View File

@ -133,6 +133,7 @@ class BatchManager(BatchManagerBase):
batch_indices.append(list(range(len(batchdata[0].items)))) batch_indices.append(list(range(len(batchdata[0].items))))
all_batch_indices = product(*batch_indices) all_batch_indices = product(*batch_indices)
for bi in all_batch_indices: for bi in all_batch_indices:
for _ in range(batch_process.batch.runs):
ges = self._create_batch_session(batch_process, bi) ges = self._create_batch_session(batch_process, bi)
self.__invoker.services.graph_execution_manager.set(ges) self.__invoker.services.graph_execution_manager.set(ges)
batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created") batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created")

View File

@ -35,6 +35,13 @@ class Batch(BaseModel):
""" """
data: list[list[BatchData]] = Field(default_factory=list, description="The list of batch data collections.") data: list[list[BatchData]] = Field(default_factory=list, description="The list of batch data collections.")
runs: int = Field(default=1, description="Int stating how many times to iterate through all possible batch indices")
@validator("runs")
def validate_positive_runs(cls, r: int):
if r < 1:
raise ValueError("runs must be a positive integer")
return r
@validator("data") @validator("data")
def validate_len(cls, v: list[list[BatchData]]): def validate_len(cls, v: list[list[BatchData]]):