diff --git a/invokeai/app/services/batch_manager.py b/invokeai/app/services/batch_manager.py index a2f8c30c42..016f1a8ba6 100644 --- a/invokeai/app/services/batch_manager.py +++ b/invokeai/app/services/batch_manager.py @@ -133,15 +133,16 @@ class BatchManager(BatchManagerBase): batch_indices.append(list(range(len(batchdata[0].items)))) all_batch_indices = product(*batch_indices) for bi in all_batch_indices: - ges = self._create_batch_session(batch_process, bi) - self.__invoker.services.graph_execution_manager.set(ges) - batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created") - sessions.append(self.__batch_process_storage.create_session(batch_session)) - if not sessions: - ges = GraphExecutionState(graph=batch_process.graph) - self.__invoker.services.graph_execution_manager.set(ges) - batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created") - sessions.append(self.__batch_process_storage.create_session(batch_session)) + for _ in range(batch_process.batch.runs): + ges = self._create_batch_session(batch_process, bi) + self.__invoker.services.graph_execution_manager.set(ges) + batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created") + sessions.append(self.__batch_process_storage.create_session(batch_session)) + if not sessions: + ges = GraphExecutionState(graph=batch_process.graph) + self.__invoker.services.graph_execution_manager.set(ges) + batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created") + sessions.append(self.__batch_process_storage.create_session(batch_session)) return sessions def cancel_batch_process(self, batch_process_id: str) -> None: diff --git a/invokeai/app/services/batch_manager_storage.py b/invokeai/app/services/batch_manager_storage.py index 471836de9c..1e0f0916be 100644 --- a/invokeai/app/services/batch_manager_storage.py +++ b/invokeai/app/services/batch_manager_storage.py @@ -35,6 +35,13 @@ class Batch(BaseModel): """ data: list[list[BatchData]] = Field(default_factory=list, description="The list of batch data collections.") + runs: int = Field(default=1, description="Int stating how many times to iterate through all possible batch indices") + + @validator("runs") + def validate_positive_runs(cls, r: int): + if r < 1: + raise ValueError("runs must be a positive integer") + return r @validator("data") def validate_len(cls, v: list[list[BatchData]]):