InvokeAI/invokeai/app/services/batch_manager.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

146 lines
5.4 KiB
Python
Raw Normal View History

2023-07-31 19:45:35 +00:00
import networkx as nx
import copy
from abc import ABC, abstractmethod
from itertools import product
2023-07-31 19:45:35 +00:00
from pydantic import BaseModel, Field
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import Graph, GraphExecutionState
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.batch_manager_storage import (
BatchProcessStorageBase,
2023-08-10 18:09:00 +00:00
BatchSessionNotFoundException,
Batch,
BatchProcess,
BatchSession,
BatchSessionChanges,
)
2023-07-31 19:45:35 +00:00
2023-08-14 15:01:31 +00:00
2023-08-11 15:45:27 +00:00
class BatchProcessResponse(BaseModel):
    """Response returned when a batch process is created: the new batch's ID
    plus the graph-execution session IDs pre-created for it."""

    batch_id: str = Field(description="ID for the batch")
    session_ids: list[str] = Field(description="List of session IDs created for this batch")
2023-07-31 19:45:35 +00:00
class BatchManagerBase(ABC):
    """Abstract interface for a service that creates, runs, and cancels batch processes."""

    @abstractmethod
    def start(self, invoker: Invoker):
        """Attach the manager to the invoker; called once at service startup."""
        pass

    @abstractmethod
    def create_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcessResponse:
        """Create and persist a batch process over the given graph and return its ID and session IDs."""
        pass

    @abstractmethod
    def run_batch_process(self, batch_id: str):
        """Run (or continue) the sessions belonging to the given batch."""
        pass

    @abstractmethod
    def cancel_batch_process(self, batch_process_id: str):
        """Cancel the given batch process."""
        pass
class BatchManager(BatchManagerBase):
"""Responsible for managing currently running and scheduled batch jobs"""
2023-08-01 20:41:40 +00:00
2023-07-31 19:45:35 +00:00
__invoker: Invoker
__batch_process_storage: BatchProcessStorageBase
def __init__(self, batch_process_storage: BatchProcessStorageBase) -> None:
super().__init__()
self.__batch_process_storage = batch_process_storage
2023-07-31 19:45:35 +00:00
def start(self, invoker: Invoker) -> None:
2023-07-31 19:45:35 +00:00
self.__invoker = invoker
2023-08-01 20:41:40 +00:00
local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event)
2023-07-31 19:45:35 +00:00
async def on_event(self, event: Event):
event_name = event[1]["event"]
match event_name:
case "graph_execution_state_complete":
await self.process(event, False)
2023-07-31 19:45:35 +00:00
case "invocation_error":
await self.process(event, True)
2023-07-31 19:45:35 +00:00
return event
2023-08-01 20:41:40 +00:00
async def process(self, event: Event, err: bool):
2023-07-31 19:45:35 +00:00
data = event[1]["data"]
batch_session = self.__batch_process_storage.get_session(data["graph_execution_state_id"])
if not batch_session:
2023-07-31 19:45:35 +00:00
return
2023-08-14 15:01:31 +00:00
updateSession = BatchSessionChanges(state="error" if err else "completed")
batch_session = self.__batch_process_storage.update_session_state(
batch_session.batch_id,
batch_session.session_id,
updateSession,
)
2023-08-11 19:52:49 +00:00
batch_process = self.__batch_process_storage.get(batch_session.batch_id)
if not batch_process.canceled:
self.run_batch_process(batch_process.batch_id)
2023-08-01 20:41:40 +00:00
def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState:
2023-08-15 20:28:47 +00:00
graph = batch_process.graph.copy(deep=True)
2023-07-31 19:45:35 +00:00
batches = batch_process.batches
g = graph.nx_graph_flat()
sorted_nodes = nx.topological_sort(g)
for npath in sorted_nodes:
node = graph.get_node(npath)
2023-08-01 20:41:40 +00:00
(index, batch) = next(((i, b) for i, b in enumerate(batches) if b.node_id in node.id), (None, None))
2023-07-31 19:45:35 +00:00
if batch:
batch_index = batch_indices[index]
2023-07-31 19:45:35 +00:00
datum = batch.data[batch_index]
for key in datum:
node.__dict__[key] = datum[key]
graph.update_node(npath, node)
2023-07-31 19:45:35 +00:00
2023-08-01 20:41:40 +00:00
return GraphExecutionState(graph=graph)
2023-07-31 19:45:35 +00:00
def run_batch_process(self, batch_id: str):
2023-08-15 20:28:47 +00:00
self.__batch_process_storage.start(batch_id)
2023-08-14 15:01:31 +00:00
try:
2023-08-10 18:09:00 +00:00
created_session = self.__batch_process_storage.get_created_session(batch_id)
except BatchSessionNotFoundException:
return
ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id)
self.__invoker.invoke(ges, invoke_all=True)
2023-08-14 15:01:31 +00:00
def _valid_batch_config(self, batch_process: BatchProcess) -> bool:
2023-08-11 15:45:27 +00:00
# TODO: Check that the node_ids in the batches are unique
# TODO: Validate data types are correct for each batch data
return True
2023-08-14 15:01:31 +00:00
2023-08-11 15:45:27 +00:00
def create_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcessResponse:
2023-07-31 19:45:35 +00:00
batch_process = BatchProcess(
2023-08-01 20:41:40 +00:00
batches=batches,
graph=graph,
2023-07-31 19:45:35 +00:00
)
if not self._valid_batch_config(batch_process):
return None
batch_process = self.__batch_process_storage.save(batch_process)
2023-08-11 15:45:27 +00:00
sessions = self._create_sessions(batch_process)
return BatchProcessResponse(
batch_id=batch_process.batch_id,
session_ids=[session.session_id for session in sessions],
)
2023-08-14 15:01:31 +00:00
2023-08-11 15:45:27 +00:00
def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]:
batch_indices = list()
2023-08-11 15:45:27 +00:00
sessions = list()
for batch in batch_process.batches:
batch_indices.append(list(range(len(batch.data))))
all_batch_indices = product(*batch_indices)
for bi in all_batch_indices:
ges = self._create_batch_session(batch_process, bi)
self.__invoker.services.graph_execution_manager.set(ges)
2023-08-14 15:01:31 +00:00
batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created")
2023-08-11 15:45:27 +00:00
sessions.append(self.__batch_process_storage.create_session(batch_session))
return sessions
2023-07-31 19:45:35 +00:00
2023-08-01 20:41:40 +00:00
def cancel_batch_process(self, batch_process_id: str):
2023-08-11 19:52:49 +00:00
self.__batch_process_storage.cancel(batch_process_id)