diff --git a/invokeai/app/services/batch_manager.py b/invokeai/app/services/batch_manager.py index 9f7927687c..790dd7bb48 100644 --- a/invokeai/app/services/batch_manager.py +++ b/invokeai/app/services/batch_manager.py @@ -94,7 +94,7 @@ class BatchManager(BatchManagerBase): async def _process(self, event: Event, err: bool) -> None: data = event[1]["data"] try: - batch_session = self.__batch_process_storage.get_session(data["graph_execution_state_id"]) + batch_session = self.__batch_process_storage.get_session_by_session_id(data["graph_execution_state_id"]) except BatchSessionNotFoundException: return None changes = BatchSessionChanges(state="error" if err else "completed") @@ -130,10 +130,18 @@ class BatchManager(BatchManagerBase): def run_batch_process(self, batch_id: str) -> None: self.__batch_process_storage.start(batch_id) try: - created_session = self.__batch_process_storage.get_created_session(batch_id) + next_session = self.__batch_process_storage.get_next_session(batch_id) except BatchSessionNotFoundException: return - ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id) + batch_process = self.__batch_process_storage.get(batch_id) + ges = self._create_batch_session(batch_process=batch_process, batch_indices=tuple(next_session.batch_index)) + ges.id = next_session.session_id + self.__invoker.services.graph_execution_manager.set(ges) + self.__batch_process_storage.update_session_state( + batch_id=next_session.batch_id, + session_id=next_session.session_id, + changes=BatchSessionChanges(state="in_progress"), + ) self.__invoker.invoke(ges, invoke_all=True) def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse: @@ -150,25 +158,20 @@ class BatchManager(BatchManagerBase): def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]: batch_indices = list() - sessions = list() + sessions_to_create: list[BatchSession] = list() for batchdata in batch_process.batch.data: batch_indices.append(list(range(len(batchdata[0].items)))) all_batch_indices = product(*batch_indices) for bi in all_batch_indices: for _ in range(batch_process.batch.runs): - ges = self._create_batch_session(batch_process, bi) - self.__invoker.services.graph_execution_manager.set(ges) - batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created") - sessions.append(self.__batch_process_storage.create_session(batch_session)) - if not sessions: - ges = GraphExecutionState(graph=batch_process.graph) - self.__invoker.services.graph_execution_manager.set(ges) - batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created") - sessions.append(self.__batch_process_storage.create_session(batch_session)) - return sessions + sessions_to_create.append(BatchSession(batch_id=batch_process.batch_id, batch_index=list(bi))) + if not sessions_to_create: + sessions_to_create.append(BatchSession(batch_id=batch_process.batch_id, batch_index=list(bi))) + created_sessions = self.__batch_process_storage.create_sessions(sessions_to_create) + return created_sessions def get_sessions(self, batch_id: str) -> list[BatchSession]: - return self.__batch_process_storage.get_sessions(batch_id) + return self.__batch_process_storage.get_sessions_by_batch_id(batch_id) def get_batch(self, batch_id: str) -> BatchProcess: return self.__batch_process_storage.get(batch_id) @@ -177,7 +180,7 @@ class BatchManager(BatchManagerBase): bps = self.__batch_process_storage.get_all() res = list() for bp in bps: - sessions = self.__batch_process_storage.get_sessions(bp.batch_id) + sessions = self.__batch_process_storage.get_sessions_by_batch_id(bp.batch_id) res.append( BatchProcessResponse( batch_id=bp.batch_id, @@ -190,7 +193,7 @@ class BatchManager(BatchManagerBase): bps = self.__batch_process_storage.get_incomplete() res = list() for bp in bps: - sessions = self.__batch_process_storage.get_sessions(bp.batch_id) + sessions = self.__batch_process_storage.get_sessions_by_batch_id(bp.batch_id) res.append( BatchProcessResponse( batch_id=bp.batch_id, diff --git a/invokeai/app/services/batch_manager_storage.py b/invokeai/app/services/batch_manager_storage.py index d23b1446f2..870da454f1 100644 --- a/invokeai/app/services/batch_manager_storage.py +++ b/invokeai/app/services/batch_manager_storage.py @@ -1,3 +1,4 @@ +import json import sqlite3 import threading import uuid @@ -71,17 +72,29 @@ class Batch(BaseModel): return v -class BatchSession(BaseModel): - batch_id: str = Field(description="The Batch to which this BatchSession is attached.") - session_id: str = Field(description="The Session to which this BatchSession is attached.") - state: Literal["created", "completed", "inprogress", "error"] = Field(description="The state of this BatchSession") - - def uuid_string(): res = uuid.uuid4() return str(res) +BATCH_SESSION_STATE = Literal["uninitialized", "in_progress", "completed", "error"] + + +class BatchSession(BaseModel): + batch_id: str = Field(defaultdescription="The Batch to which this BatchSession is attached.") + session_id: str = Field( + default_factory=uuid_string, description="The Session to which this BatchSession is attached." + ) + batch_index: list[int] = Field(description="The index of the batch to be run in this BatchSession.") + state: BATCH_SESSION_STATE = Field(default="uninitialized", description="The state of this BatchSession") + + @validator("batch_index", pre=True) + def parse_str_to_list(cls, v: Union[str, list[int]]): + if isinstance(v, str): + return json.loads(v) + return v + + class BatchProcess(BaseModel): batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch.") batch: Batch = Field(description="The Batch to apply to this session.") @@ -90,7 +103,7 @@ class BatchProcess(BaseModel): class BatchSessionChanges(BaseModel, extra=Extra.forbid): - state: Literal["created", "completed", "inprogress", "error"] = Field(description="The state of this BatchSession") + state: BATCH_SESSION_STATE = Field(description="The state of this BatchSession") class BatchProcessNotFoundException(Exception): @@ -198,23 +211,31 @@ class BatchProcessStorageBase(ABC): pass @abstractmethod - def get_session(self, session_id: str) -> BatchSession: + def create_sessions( + self, + sessions: list[BatchSession], + ) -> list[BatchSession]: + """Creates many BatchSessions attached to a BatchProcess.""" + pass + + @abstractmethod + def get_session_by_session_id(self, session_id: str) -> BatchSession: """Gets a BatchSession by session_id""" pass @abstractmethod - def get_sessions(self, batch_id: str) -> List[BatchSession]: + def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]: """Gets all BatchSession's for a given BatchProcess id.""" pass @abstractmethod - def get_created_session(self, batch_id: str) -> BatchSession: - """Gets the latest BatchSession with state `created`, for a given BatchProcess id.""" + def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]: + """Gets all BatchSession's for a given list of session ids.""" pass @abstractmethod - def get_created_sessions(self, batch_id: str) -> List[BatchSession]: - """Gets all BatchSession's with state `created`, for a given BatchProcess id.""" + def get_next_session(self, batch_id: str) -> BatchSession: + """Gets the next BatchSession with state `uninitialized`, for a given BatchProcess id.""" pass @abstractmethod @@ -295,6 +316,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): CREATE TABLE IF NOT EXISTS batch_session ( batch_id TEXT NOT NULL, session_id TEXT NOT NULL, + batch_index TEXT NOT NULL, state TEXT NOT NULL, created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger @@ -453,7 +475,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): ( SELECT batch_id FROM batch_session bs - WHERE state = 'created' + WHERE state IN ('uninitialized', 'in_progress') ); """ ) @@ -518,10 +540,10 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): self._lock.acquire() self._cursor.execute( """--sql - INSERT OR IGNORE INTO batch_session (batch_id, session_id, state) - VALUES (?, ?, ?); + INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index) + VALUES (?, ?, ?, ?); """, - (session.batch_id, session.session_id, session.state), + (session.batch_id, session.session_id, session.state, json.dumps(session.batch_index)), ) self._conn.commit() except sqlite3.Error as e: @@ -529,9 +551,34 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): raise BatchSessionSaveException from e finally: self._lock.release() - return self.get_session(session.session_id) + return self.get_session_by_session_id(session.session_id) - def get_session(self, session_id: str) -> BatchSession: + def create_sessions( + self, + sessions: list[BatchSession], + ) -> list[BatchSession]: + try: + self._lock.acquire() + session_data = [ + (session.batch_id, session.session_id, session.state, json.dumps(session.batch_index)) + for session in sessions + ] + self._cursor.executemany( + """--sql + INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index) + VALUES (?, ?, ?, ?); + """, + session_data, + ) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise BatchSessionSaveException from e + finally: + self._lock.release() + return self.get_sessions_by_session_ids([session.session_id for session in sessions]) + + def get_session_by_session_id(self, session_id: str) -> BatchSession: try: self._lock.acquire() self._cursor.execute( @@ -558,14 +605,14 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): return BatchSession.parse_obj(session_dict) - def get_created_session(self, batch_id: str) -> BatchSession: + def get_next_session(self, batch_id: str) -> BatchSession: try: self._lock.acquire() self._cursor.execute( """--sql SELECT * FROM batch_session - WHERE batch_id = ? AND state = 'created'; + WHERE batch_id = ? AND state = 'uninitialized'; """, (batch_id,), ) @@ -581,14 +628,14 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): session = self._deserialize_batch_session(dict(result)) return session - def get_created_sessions(self, batch_id: str) -> List[BatchSession]: + def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]: try: self._lock.acquire() self._cursor.execute( """--sql SELECT * FROM batch_session - WHERE batch_id = ? AND state = created; + WHERE batch_id = ?; """, (batch_id,), ) @@ -604,16 +651,17 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result)) return sessions - def get_sessions(self, batch_id: str) -> List[BatchSession]: + def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]: try: self._lock.acquire() + placeholders = ",".join("?" * len(session_ids)) self._cursor.execute( - """--sql - SELECT * - FROM batch_session - WHERE batch_id = ?; + f"""--sql + SELECT * FROM batch_session + WHERE session_id + IN ({placeholders}) """, - (batch_id,), + tuple(session_ids), ) result = cast(list[sqlite3.Row], self._cursor.fetchall()) @@ -652,4 +700,4 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase): raise BatchSessionSaveException from e finally: self._lock.release() - return self.get_session(session_id) + return self.get_session_by_session_id(session_id)