Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
fix(backend): fix typings in batch_manager.py
- `batch_indices` is `tuple[int]`, not `list[int]`
- explicit `None` return values
parent 8cf9bd47b2
commit 5bec64d65b
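For background on the second bullet, here is a minimal sketch (a hypothetical `Worker` class, not code from this repository) of why spelling out `-> None` matters to a type checker: an unannotated method is treated as untyped, so its result is `Any` and misuse goes unnoticed, while an explicit `-> None` lets mypy or pyright flag callers that try to use the return value.

```python
# Illustrative only; "Worker" is a hypothetical stand-in, not an InvokeAI class.
class Worker:
    def start(self):  # no annotation: a checker treats the return value as Any
        pass

    def stop(self) -> None:  # explicit: the checker knows there is no result
        pass


w = Worker()
a = w.start()  # a is Any to the checker, so "a + 1" would slip through
b = w.stop()   # b is None; "b + 1" is reported as an error before runtime
```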
@@ -1,3 +1,4 @@
 from typing import Optional, Union
 import networkx as nx
 import copy
@@ -27,7 +28,7 @@ class BatchProcessResponse(BaseModel):
 
 class BatchManagerBase(ABC):
     @abstractmethod
-    def start(self, invoker: Invoker):
+    def start(self, invoker: Invoker) -> None:
         pass
 
     @abstractmethod
@@ -35,11 +36,11 @@ class BatchManagerBase(ABC):
         pass
 
     @abstractmethod
-    def run_batch_process(self, batch_id: str):
+    def run_batch_process(self, batch_id: str) -> None:
         pass
 
     @abstractmethod
-    def cancel_batch_process(self, batch_process_id: str):
+    def cancel_batch_process(self, batch_process_id: str) -> None:
         pass
 
 
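The abstract interface above, with every method annotated, in miniature (illustrative names, assuming a mypy/pyright-style checker; this is not the real `BatchManagerBase`):

```python
from abc import ABC, abstractmethod


class ManagerBase(ABC):
    """Hypothetical stand-in for an abstract service interface."""

    @abstractmethod
    def run_job(self, job_id: str) -> None:
        """Start processing the given job; implementations return nothing."""
        ...

    @abstractmethod
    def cancel_job(self, job_id: str) -> None:
        """Stop processing the given job."""
        ...


class PrintingManager(ManagerBase):
    def run_job(self, job_id: str) -> None:
        print(f"running {job_id}")

    def cancel_job(self, job_id: str) -> None:
        print(f"cancelling {job_id}")


PrintingManager().run_job("demo")
```

Because the base methods carry full signatures, an incompatible override (for example changing the type of `job_id`) is reported against the declared interface rather than surfacing at runtime.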
@@ -68,11 +69,11 @@ class BatchManager(BatchManagerBase):
 
         return event
 
-    async def process(self, event: Event, err: bool):
+    async def process(self, event: Event, err: bool) -> None:
         data = event[1]["data"]
         batch_session = self.__batch_process_storage.get_session(data["graph_execution_state_id"])
         if not batch_session:
-            return
+            return None
         updateSession = BatchSessionChanges(state="error" if err else "completed")
         batch_session = self.__batch_process_storage.update_session_state(
             batch_session.batch_id,
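On the `return` → `return None` change: inside a function annotated `-> None`, the two forms are identical at runtime; the explicit spelling just states that the early exit deliberately produces no value. A small sketch of the same early-return shape (hypothetical event dictionary, not the real `Event` type):

```python
import asyncio
from typing import Any


async def process(event: dict[str, Any], err: bool) -> None:
    # Hypothetical payload shape, for illustration only.
    session_id = event.get("graph_execution_state_id")
    if session_id is None:
        return None  # same effect as a bare "return"
    state = "error" if err else "completed"
    print(f"session {session_id} marked {state}")


asyncio.run(process({"graph_execution_state_id": "abc123"}, err=False))
```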
@@ -83,7 +84,7 @@ class BatchManager(BatchManagerBase):
         if not batch_process.canceled:
             self.run_batch_process(batch_process.batch_id)
 
-    def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState:
+    def _create_batch_session(self, batch_process: BatchProcess, batch_indices: tuple[int]) -> GraphExecutionState:
         graph = batch_process.graph.copy(deep=True)
         batch = batch_process.batch
         g = graph.nx_graph_flat()
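One hedged caveat on the `batch_indices: tuple[int]` annotation (an observation about Python typing in general, not a claim about this code's call sites): `tuple[int]` describes a tuple of exactly one `int`, while a tuple of any length is written `tuple[int, ...]`. A quick illustration of how a checker reads the different spellings:

```python
# Illustrative only; run through mypy or pyright to see the distinction.
one: tuple[int] = (3,)             # OK: exactly one element
many: tuple[int, ...] = (1, 2, 3)  # OK: any number of ints
# bad: tuple[int] = (1, 2, 3)      # rejected by a checker: too many elements
numbers: list[int] = [1, 2, 3]     # mutable; not interchangeable with tuples
```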
@@ -103,7 +104,7 @@ class BatchManager(BatchManagerBase):
 
         return GraphExecutionState(graph=graph)
 
-    def run_batch_process(self, batch_id: str):
+    def run_batch_process(self, batch_id: str) -> None:
         self.__batch_process_storage.start(batch_id)
         try:
             created_session = self.__batch_process_storage.get_created_session(batch_id)
@@ -137,5 +138,5 @@ class BatchManager(BatchManagerBase):
         sessions.append(self.__batch_process_storage.create_session(batch_session))
         return sessions
 
-    def cancel_batch_process(self, batch_process_id: str):
+    def cancel_batch_process(self, batch_process_id: str) -> None:
         self.__batch_process_storage.cancel(batch_process_id)
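With the concrete methods annotated as well, the manager's whole surface reports concrete types. A tiny check one could run (hypothetical class; `typing.reveal_type` needs Python 3.11+ at runtime, and mypy reports the same type statically):

```python
from typing import reveal_type


class Cancellable:
    """Hypothetical stand-in with the same shape as the annotated method."""

    def cancel_batch_process(self, batch_process_id: str) -> None:
        print(f"cancelled {batch_process_id}")


result = Cancellable().cancel_batch_process("proc-1")
reveal_type(result)  # mypy: "None"; at runtime: Runtime type is 'NoneType'
```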