diff --git a/invokeai/app/services/batch_manager.py b/invokeai/app/services/batch_manager.py index 956d5ff610..64d9ce8649 100644 --- a/invokeai/app/services/batch_manager.py +++ b/invokeai/app/services/batch_manager.py @@ -1,3 +1,4 @@ +from typing import Optional, Union import networkx as nx import copy @@ -27,7 +28,7 @@ class BatchProcessResponse(BaseModel): class BatchManagerBase(ABC): @abstractmethod - def start(self, invoker: Invoker): + def start(self, invoker: Invoker) -> None: pass @abstractmethod @@ -35,11 +36,11 @@ class BatchManagerBase(ABC): pass @abstractmethod - def run_batch_process(self, batch_id: str): + def run_batch_process(self, batch_id: str) -> None: pass @abstractmethod - def cancel_batch_process(self, batch_process_id: str): + def cancel_batch_process(self, batch_process_id: str) -> None: pass @@ -68,11 +69,11 @@ class BatchManager(BatchManagerBase): return event - async def process(self, event: Event, err: bool): + async def process(self, event: Event, err: bool) -> None: data = event[1]["data"] batch_session = self.__batch_process_storage.get_session(data["graph_execution_state_id"]) if not batch_session: - return + return None updateSession = BatchSessionChanges(state="error" if err else "completed") batch_session = self.__batch_process_storage.update_session_state( batch_session.batch_id, @@ -83,7 +84,7 @@ class BatchManager(BatchManagerBase): if not batch_process.canceled: self.run_batch_process(batch_process.batch_id) - def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState: + def _create_batch_session(self, batch_process: BatchProcess, batch_indices: tuple[int]) -> GraphExecutionState: graph = batch_process.graph.copy(deep=True) batch = batch_process.batch g = graph.nx_graph_flat() @@ -103,7 +104,7 @@ class BatchManager(BatchManagerBase): return GraphExecutionState(graph=graph) - def run_batch_process(self, batch_id: str): + def run_batch_process(self, batch_id: str) -> None: self.__batch_process_storage.start(batch_id) try: created_session = self.__batch_process_storage.get_created_session(batch_id) @@ -137,5 +138,5 @@ class BatchManager(BatchManagerBase): sessions.append(self.__batch_process_storage.create_session(batch_session)) return sessions - def cancel_batch_process(self, batch_process_id: str): + def cancel_batch_process(self, batch_process_id: str) -> None: self.__batch_process_storage.cancel(batch_process_id)