mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add runs field for running the same batch multiple times
This commit is contained in:
parent
99e03fe92e
commit
0282f46c71
@ -133,15 +133,16 @@ class BatchManager(BatchManagerBase):
|
||||
batch_indices.append(list(range(len(batchdata[0].items))))
|
||||
all_batch_indices = product(*batch_indices)
|
||||
for bi in all_batch_indices:
|
||||
ges = self._create_batch_session(batch_process, bi)
|
||||
self.__invoker.services.graph_execution_manager.set(ges)
|
||||
batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created")
|
||||
sessions.append(self.__batch_process_storage.create_session(batch_session))
|
||||
if not sessions:
|
||||
ges = GraphExecutionState(graph=batch_process.graph)
|
||||
self.__invoker.services.graph_execution_manager.set(ges)
|
||||
batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created")
|
||||
sessions.append(self.__batch_process_storage.create_session(batch_session))
|
||||
for _ in range(batch_process.batch.runs):
|
||||
ges = self._create_batch_session(batch_process, bi)
|
||||
self.__invoker.services.graph_execution_manager.set(ges)
|
||||
batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created")
|
||||
sessions.append(self.__batch_process_storage.create_session(batch_session))
|
||||
if not sessions:
|
||||
ges = GraphExecutionState(graph=batch_process.graph)
|
||||
self.__invoker.services.graph_execution_manager.set(ges)
|
||||
batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created")
|
||||
sessions.append(self.__batch_process_storage.create_session(batch_session))
|
||||
return sessions
|
||||
|
||||
def cancel_batch_process(self, batch_process_id: str) -> None:
|
||||
|
@ -35,6 +35,13 @@ class Batch(BaseModel):
|
||||
"""
|
||||
|
||||
data: list[list[BatchData]] = Field(default_factory=list, description="The list of batch data collections.")
|
||||
runs: int = Field(default=1, description="Int stating how many times to iterate through all possible batch indices")
|
||||
|
||||
@validator("runs")
|
||||
def validate_positive_runs(cls, r: int):
|
||||
if r < 1:
|
||||
raise ValueError("runs must be a positive integer")
|
||||
return r
|
||||
|
||||
@validator("data")
|
||||
def validate_len(cls, v: list[list[BatchData]]):
|
||||
|
Loading…
Reference in New Issue
Block a user