diff --git a/invokeai/app/services/batch_manager.py b/invokeai/app/services/batch_manager.py index 6d4136c8f0..790dd7bb48 100644 --- a/invokeai/app/services/batch_manager.py +++ b/invokeai/app/services/batch_manager.py @@ -1,24 +1,22 @@ +import networkx as nx + from abc import ABC, abstractmethod from itertools import product -from typing import Optional -from uuid import uuid4 - -import networkx as nx +from pydantic import BaseModel, Field from fastapi_events.handlers.local import local_handler from fastapi_events.typing import Event -from pydantic import BaseModel, Field -from invokeai.app.services.batch_manager_storage import ( - Batch, - BatchProcess, - BatchProcessStorageBase, - BatchSession, - BatchSessionChanges, - BatchSessionNotFoundException, -) from invokeai.app.services.events import EventServiceBase from invokeai.app.services.graph import Graph, GraphExecutionState from invokeai.app.services.invoker import Invoker +from invokeai.app.services.batch_manager_storage import ( + BatchProcessStorageBase, + BatchSessionNotFoundException, + Batch, + BatchProcess, + BatchSession, + BatchSessionChanges, +) class BatchProcessResponse(BaseModel): @@ -109,9 +107,7 @@ class BatchManager(BatchManagerBase): if not batch_process.canceled: self.run_batch_process(batch_process.batch_id) - def _create_graph_execution_state( - self, batch_process: BatchProcess, batch_indices: tuple[int, ...] - ) -> GraphExecutionState: + def _create_batch_session(self, batch_process: BatchProcess, batch_indices: tuple[int]) -> GraphExecutionState: graph = batch_process.graph.copy(deep=True) batch = batch_process.batch g = graph.nx_graph_flat() @@ -133,31 +129,12 @@ class BatchManager(BatchManagerBase): def run_batch_process(self, batch_id: str) -> None: self.__batch_process_storage.start(batch_id) + try: + next_session = self.__batch_process_storage.get_next_session(batch_id) + except BatchSessionNotFoundException: + return batch_process = self.__batch_process_storage.get(batch_id) - next_batch_index = self._get_batch_index_tuple(batch_process) - if next_batch_index is None: - # finished with current run - if batch_process.current_run >= (batch_process.batch.runs - 1): - # finished with all runs - return - batch_process.current_batch_index = 0 - batch_process.current_run += 1 - next_batch_index = self._get_batch_index_tuple(batch_process) - if next_batch_index is None: - # shouldn't happen; satisfy types - return - # remember to increment the batch index - batch_process.current_batch_index += 1 - self.__batch_process_storage.save(batch_process) - ges = self._create_graph_execution_state(batch_process=batch_process, batch_indices=next_batch_index) - next_session = self.__batch_process_storage.create_session( - BatchSession( - batch_id=batch_id, - session_id=str(uuid4()), - state="uninitialized", - batch_index=batch_process.current_batch_index, - ) - ) + ges = self._create_batch_session(batch_process=batch_process, batch_indices=tuple(next_session.batch_index)) ges.id = next_session.session_id self.__invoker.services.graph_execution_manager.set(ges) self.__batch_process_storage.update_session_state( @@ -173,11 +150,26 @@ class BatchManager(BatchManagerBase): graph=graph, ) batch_process = self.__batch_process_storage.save(batch_process) + sessions = self._create_sessions(batch_process) return BatchProcessResponse( batch_id=batch_process.batch_id, - session_ids=[], + session_ids=[session.session_id for session in sessions], ) + def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]: + batch_indices = list() + sessions_to_create: list[BatchSession] = list() + for batchdata in batch_process.batch.data: + batch_indices.append(list(range(len(batchdata[0].items)))) + all_batch_indices = product(*batch_indices) + for bi in all_batch_indices: + for _ in range(batch_process.batch.runs): + sessions_to_create.append(BatchSession(batch_id=batch_process.batch_id, batch_index=list(bi))) + if not sessions_to_create: + sessions_to_create.append(BatchSession(batch_id=batch_process.batch_id, batch_index=list(bi))) + created_sessions = self.__batch_process_storage.create_sessions(sessions_to_create) + return created_sessions + def get_sessions(self, batch_id: str) -> list[BatchSession]: return self.__batch_process_storage.get_sessions_by_batch_id(batch_id) @@ -212,12 +204,3 @@ class BatchManager(BatchManagerBase): def cancel_batch_process(self, batch_process_id: str) -> None: self.__batch_process_storage.cancel(batch_process_id) - - def _get_batch_index_tuple(self, batch_process: BatchProcess) -> Optional[tuple[int, ...]]: - batch_indices = list() - for batchdata in batch_process.batch.data: - batch_indices.append(list(range(len(batchdata[0].items)))) - try: - return list(product(*batch_indices))[batch_process.current_batch_index] - except IndexError: - return None diff --git a/invokeai/app/services/batch_manager_storage.py b/invokeai/app/services/batch_manager_storage.py index 71b788aecb..870da454f1 100644 --- a/invokeai/app/services/batch_manager_storage.py +++ b/invokeai/app/services/batch_manager_storage.py @@ -1,3 +1,4 @@ +import json import sqlite3 import threading import uuid @@ -84,15 +85,19 @@ class BatchSession(BaseModel): session_id: str = Field( default_factory=uuid_string, description="The Session to which this BatchSession is attached." ) - batch_index: int = Field(description="The index of this batch session in its parent batch process") + batch_index: list[int] = Field(description="The index of the batch to be run in this BatchSession.") state: BATCH_SESSION_STATE = Field(default="uninitialized", description="The state of this BatchSession") + @validator("batch_index", pre=True) + def parse_str_to_list(cls, v: Union[str, list[int]]): + if isinstance(v, str): + return json.loads(v) + return v + class BatchProcess(BaseModel): batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch.") batch: Batch = Field(description="The Batch to apply to this session.") - current_batch_index: int = Field(default=0, description="The last executed batch index") - current_run: int = Field(default=0, description="The current run of the batch") canceled: bool = Field(description="Whether or not to run sessions from this batch.", default=False) graph: Graph = Field(description="The graph into which batch data will be inserted before being executed.") @@ -274,10 +279,8 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): """--sql CREATE TABLE IF NOT EXISTS batch_process ( batch_id TEXT NOT NULL PRIMARY KEY, - batch TEXT NOT NULL, + batches TEXT NOT NULL, graph TEXT NOT NULL, - current_batch_index NUMBER NOT NULL, - current_run NUMBER NOT NULL, canceled BOOLEAN NOT NULL DEFAULT(0), created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- Updated via trigger @@ -313,8 +316,8 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): CREATE TABLE IF NOT EXISTS batch_session ( batch_id TEXT NOT NULL, session_id TEXT NOT NULL, + batch_index TEXT NOT NULL, state TEXT NOT NULL, - batch_index NUMBER NOT NULL, created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), @@ -383,16 +386,10 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): self._lock.acquire() self._cursor.execute( """--sql - INSERT OR REPLACE INTO batch_process (batch_id, batch, graph, current_batch_index, current_run) - VALUES (?, ?, ?, ?, ?); + INSERT OR IGNORE INTO batch_process (batch_id, batches, graph) + VALUES (?, ?, ?); """, - ( - batch_process.batch_id, - batch_process.batch.json(), - batch_process.graph.json(), - batch_process.current_batch_index, - batch_process.current_run, - ), + (batch_process.batch_id, batch_process.batch.json(), batch_process.graph.json()), ) self._conn.commit() except sqlite3.Error as e: @@ -408,17 +405,13 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): # Retrieve all the values, setting "reasonable" defaults if they are not present. batch_id = session_dict.get("batch_id", "unknown") - batch_raw = session_dict.get("batch", "unknown") + batch_raw = session_dict.get("batches", "unknown") graph_raw = session_dict.get("graph", "unknown") - current_batch_index = session_dict.get("current_batch_index", 0) - current_run = session_dict.get("current_run", 0) canceled = session_dict.get("canceled", 0) return BatchProcess( batch_id=batch_id, batch=parse_raw_as(Batch, batch_raw), graph=parse_raw_as(Graph, graph_raw), - current_batch_index=current_batch_index, - current_run=current_run, canceled=canceled == 1, ) @@ -550,7 +543,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index) VALUES (?, ?, ?, ?); """, - (session.batch_id, session.session_id, session.state, session.batch_index), + (session.batch_id, session.session_id, session.state, json.dumps(session.batch_index)), ) self._conn.commit() except sqlite3.Error as e: @@ -566,7 +559,10 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): ) -> list[BatchSession]: try: self._lock.acquire() - session_data = [(session.batch_id, session.session_id, session.state) for session in sessions] + session_data = [ + (session.batch_id, session.session_id, session.state, json.dumps(session.batch_index)) + for session in sessions + ] self._cursor.executemany( """--sql INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)