diff --git a/invokeai/app/services/batch_manager.py b/invokeai/app/services/batch_manager.py index 790dd7bb48..6d4136c8f0 100644 --- a/invokeai/app/services/batch_manager.py +++ b/invokeai/app/services/batch_manager.py @@ -1,22 +1,24 @@ -import networkx as nx - from abc import ABC, abstractmethod from itertools import product -from pydantic import BaseModel, Field +from typing import Optional +from uuid import uuid4 + +import networkx as nx from fastapi_events.handlers.local import local_handler from fastapi_events.typing import Event +from pydantic import BaseModel, Field +from invokeai.app.services.batch_manager_storage import ( + Batch, + BatchProcess, + BatchProcessStorageBase, + BatchSession, + BatchSessionChanges, + BatchSessionNotFoundException, +) from invokeai.app.services.events import EventServiceBase from invokeai.app.services.graph import Graph, GraphExecutionState from invokeai.app.services.invoker import Invoker -from invokeai.app.services.batch_manager_storage import ( - BatchProcessStorageBase, - BatchSessionNotFoundException, - Batch, - BatchProcess, - BatchSession, - BatchSessionChanges, -) class BatchProcessResponse(BaseModel): @@ -107,7 +109,9 @@ class BatchManager(BatchManagerBase): if not batch_process.canceled: self.run_batch_process(batch_process.batch_id) - def _create_batch_session(self, batch_process: BatchProcess, batch_indices: tuple[int]) -> GraphExecutionState: + def _create_graph_execution_state( + self, batch_process: BatchProcess, batch_indices: tuple[int, ...] + ) -> GraphExecutionState: graph = batch_process.graph.copy(deep=True) batch = batch_process.batch g = graph.nx_graph_flat() @@ -129,12 +133,31 @@ class BatchManager(BatchManagerBase): def run_batch_process(self, batch_id: str) -> None: self.__batch_process_storage.start(batch_id) - try: - next_session = self.__batch_process_storage.get_next_session(batch_id) - except BatchSessionNotFoundException: - return batch_process = self.__batch_process_storage.get(batch_id) - ges = self._create_batch_session(batch_process=batch_process, batch_indices=tuple(next_session.batch_index)) + next_batch_index = self._get_batch_index_tuple(batch_process) + if next_batch_index is None: + # finished with current run + if batch_process.current_run >= (batch_process.batch.runs - 1): + # finished with all runs + return + batch_process.current_batch_index = 0 + batch_process.current_run += 1 + next_batch_index = self._get_batch_index_tuple(batch_process) + if next_batch_index is None: + # shouldn't happen; satisfy types + return + # remember to increment the batch index + batch_process.current_batch_index += 1 + self.__batch_process_storage.save(batch_process) + ges = self._create_graph_execution_state(batch_process=batch_process, batch_indices=next_batch_index) + next_session = self.__batch_process_storage.create_session( + BatchSession( + batch_id=batch_id, + session_id=str(uuid4()), + state="uninitialized", + batch_index=batch_process.current_batch_index, + ) + ) ges.id = next_session.session_id self.__invoker.services.graph_execution_manager.set(ges) self.__batch_process_storage.update_session_state( @@ -150,26 +173,11 @@ class BatchManager(BatchManagerBase): graph=graph, ) batch_process = self.__batch_process_storage.save(batch_process) - sessions = self._create_sessions(batch_process) return BatchProcessResponse( batch_id=batch_process.batch_id, - session_ids=[session.session_id for session in sessions], + session_ids=[], ) - def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]: - batch_indices = list() - sessions_to_create: list[BatchSession] = list() - for batchdata in batch_process.batch.data: - batch_indices.append(list(range(len(batchdata[0].items)))) - all_batch_indices = product(*batch_indices) - for bi in all_batch_indices: - for _ in range(batch_process.batch.runs): - sessions_to_create.append(BatchSession(batch_id=batch_process.batch_id, batch_index=list(bi))) - if not sessions_to_create: - sessions_to_create.append(BatchSession(batch_id=batch_process.batch_id, batch_index=list(bi))) - created_sessions = self.__batch_process_storage.create_sessions(sessions_to_create) - return created_sessions - def get_sessions(self, batch_id: str) -> list[BatchSession]: return self.__batch_process_storage.get_sessions_by_batch_id(batch_id) @@ -204,3 +212,12 @@ class BatchManager(BatchManagerBase): def cancel_batch_process(self, batch_process_id: str) -> None: self.__batch_process_storage.cancel(batch_process_id) + + def _get_batch_index_tuple(self, batch_process: BatchProcess) -> Optional[tuple[int, ...]]: + batch_indices = list() + for batchdata in batch_process.batch.data: + batch_indices.append(list(range(len(batchdata[0].items)))) + try: + return list(product(*batch_indices))[batch_process.current_batch_index] + except IndexError: + return None diff --git a/invokeai/app/services/batch_manager_storage.py b/invokeai/app/services/batch_manager_storage.py index 870da454f1..71b788aecb 100644 --- a/invokeai/app/services/batch_manager_storage.py +++ b/invokeai/app/services/batch_manager_storage.py @@ -1,4 +1,3 @@ -import json import sqlite3 import threading import uuid @@ -85,19 +84,15 @@ class BatchSession(BaseModel): session_id: str = Field( default_factory=uuid_string, description="The Session to which this BatchSession is attached." ) - batch_index: list[int] = Field(description="The index of the batch to be run in this BatchSession.") + batch_index: int = Field(description="The index of this batch session in its parent batch process") state: BATCH_SESSION_STATE = Field(default="uninitialized", description="The state of this BatchSession") - @validator("batch_index", pre=True) - def parse_str_to_list(cls, v: Union[str, list[int]]): - if isinstance(v, str): - return json.loads(v) - return v - class BatchProcess(BaseModel): batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch.") batch: Batch = Field(description="The Batch to apply to this session.") + current_batch_index: int = Field(default=0, description="The last executed batch index") + current_run: int = Field(default=0, description="The current run of the batch") canceled: bool = Field(description="Whether or not to run sessions from this batch.", default=False) graph: Graph = Field(description="The graph into which batch data will be inserted before being executed.") @@ -279,8 +274,10 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): """--sql CREATE TABLE IF NOT EXISTS batch_process ( batch_id TEXT NOT NULL PRIMARY KEY, - batches TEXT NOT NULL, + batch TEXT NOT NULL, graph TEXT NOT NULL, + current_batch_index NUMBER NOT NULL, + current_run NUMBER NOT NULL, canceled BOOLEAN NOT NULL DEFAULT(0), created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- Updated via trigger @@ -316,8 +313,8 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): CREATE TABLE IF NOT EXISTS batch_session ( batch_id TEXT NOT NULL, session_id TEXT NOT NULL, - batch_index TEXT NOT NULL, state TEXT NOT NULL, + batch_index NUMBER NOT NULL, created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), @@ -386,10 +383,16 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): self._lock.acquire() self._cursor.execute( """--sql - INSERT OR IGNORE INTO batch_process (batch_id, batches, graph) - VALUES (?, ?, ?); + INSERT OR REPLACE INTO batch_process (batch_id, batch, graph, current_batch_index, current_run) + VALUES (?, ?, ?, ?, ?); """, - (batch_process.batch_id, batch_process.batch.json(), batch_process.graph.json()), + ( + batch_process.batch_id, + batch_process.batch.json(), + batch_process.graph.json(), + batch_process.current_batch_index, + batch_process.current_run, + ), ) self._conn.commit() except sqlite3.Error as e: @@ -405,13 +408,17 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): # Retrieve all the values, setting "reasonable" defaults if they are not present. batch_id = session_dict.get("batch_id", "unknown") - batch_raw = session_dict.get("batches", "unknown") + batch_raw = session_dict.get("batch", "unknown") graph_raw = session_dict.get("graph", "unknown") + current_batch_index = session_dict.get("current_batch_index", 0) + current_run = session_dict.get("current_run", 0) canceled = session_dict.get("canceled", 0) return BatchProcess( batch_id=batch_id, batch=parse_raw_as(Batch, batch_raw), graph=parse_raw_as(Graph, graph_raw), + current_batch_index=current_batch_index, + current_run=current_run, canceled=canceled == 1, ) @@ -543,7 +550,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index) VALUES (?, ?, ?, ?); """, - (session.batch_id, session.session_id, session.state, json.dumps(session.batch_index)), + (session.batch_id, session.session_id, session.state, session.batch_index), ) self._conn.commit() except sqlite3.Error as e: @@ -559,10 +566,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): ) -> list[BatchSession]: try: self._lock.acquire() - session_data = [ - (session.batch_id, session.session_id, session.state, json.dumps(session.batch_index)) - for session in sessions - ] + session_data = [(session.batch_id, session.session_id, session.state) for session in sessions] self._cursor.executemany( """--sql INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)