InvokeAI/invokeai/app/services/batch_manager.py

145 lines
4.9 KiB
Python
Raw Normal View History

2023-07-31 19:45:35 +00:00
import networkx as nx
import uuid
import copy
from abc import ABC, abstractmethod
from pydantic import BaseModel, Field
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event
from typing import (
Optional,
Union,
)
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
)
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import Graph, GraphExecutionState
from invokeai.app.services.invoker import Invoker
# Union type built from BaseInvocation.get_invocations(); used below as the element
# type of Batch.data so that a batch item may be any of those invocation classes.
# NOTE(review): presumably get_invocations() returns a tuple of every registered
# invocation subclass (Union[...] of a tuple expands it) — confirm in baseinvocation.
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
2023-08-01 20:41:40 +00:00
2023-07-31 19:45:35 +00:00
class Batch(BaseModel):
    """One batch axis: alternative invocations for a single node of the graph.

    For each session of a batch process, one entry of ``data`` is chosen and
    substituted into the graph in place of the node identified by ``node_id``.
    """

    data: list[InvocationsUnion] = Field(
        description="Invocations to substitute, one per session, for the node identified by node_id"
    )
    node_id: str = Field(description="ID of the node to batch")
class BatchProcess(BaseModel):
    """Tracking state for one running batch process.

    Holds the template graph, the batches applied to it, the current index into
    each batch's data, and the ids of the sessions started so far.
    """

    # A real factory is needed so every instance gets its own id: passing
    # ``uuid.uuid4().__str__`` evaluated uuid4() once at import time and then
    # returned that same string for every BatchProcess.
    batch_id: Optional[str] = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Identifier for this batch",
    )
    sessions: list[str] = Field(
        description="IDs of the graph execution states started for this batch process",
        default_factory=list,
    )
    batches: list[Batch] = Field(
        description="List of batch configs to apply to this session",
        default_factory=list,
    )
    batch_indices: list[int] = Field(
        description="Current index into each batch's data, one entry per batch",
        default_factory=list,
    )
    graph: Graph = Field(description="The graph being executed")
class BatchManagerBase(ABC):
    """Abstract interface for a service that runs a graph once per batch item."""

    @abstractmethod
    def start(self, invoker: Invoker):
        """Attach the manager to ``invoker`` so it can create and invoke sessions."""

    @abstractmethod
    def run_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcess:
        """Begin a batch process applying ``batches`` to ``graph`` and return its state."""

    @abstractmethod
    def cancel_batch_process(self, batch_process_id: str):
        """Stop the batch process whose id is ``batch_process_id``."""
class BatchManager(BatchManagerBase):
    """Responsible for managing currently running and scheduled batch jobs.

    A batch process pairs a template graph with one or more ``Batch`` objects.
    Each session substitutes one item from every batch into a copy of the graph.
    Whenever a session owned by a tracked process completes (or reports an
    invocation error), the next combination is launched. Indices count down from
    each batch's last item to zero, the first batch varying fastest, so every
    combination of one item per batch is run exactly once. The process is dropped
    once all indices are zero.
    """

    __invoker: Invoker
    __batches: list[BatchProcess]

    def start(self, invoker) -> None:
        # if we do want multithreading at some point, we could make this configurable
        self.__invoker = invoker
        self.__batches = list()
        # Session events are what drive each batch process forward (see on_event).
        local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event)

    async def on_event(self, event: Event):
        """Advance the owning batch process when a session finishes or errors.

        Returns the event unchanged.
        """
        event_name = event[1]["event"]
        # NOTE(review): if an errored session also emits
        # "graph_execution_state_complete", its batch would advance twice for that
        # one session — confirm against the invoker's event sequence.
        if event_name in ("graph_execution_state_complete", "invocation_error"):
            await self.process(event)
        return event

    async def process(self, event: Event):
        """Launch the next session of the batch process that owns the event's session.

        Events for sessions this manager did not start are ignored. When every
        batch index has reached zero, the finished process stops being tracked.
        """
        data = event[1]["data"]
        session_id = data["graph_execution_state_id"]
        batch_process = next(
            (candidate for candidate in self.__batches if session_id in candidate.sessions),
            None,
        )
        if batch_process is None:
            return

        if all(index == 0 for index in batch_process.batch_indices):
            # Every combination has been run; drop exactly this process (by identity).
            self.__batches = [bp for bp in self.__batches if bp is not batch_process]
            return

        batch_process.batch_indices = self._next_batch_index(batch_process)
        ges = self._next_batch_session(batch_process)
        self._launch_session(batch_process, ges)

    def _launch_session(self, batch_process: BatchProcess, ges: GraphExecutionState) -> None:
        """Record ``ges`` as belonging to ``batch_process``, store it, and invoke it."""
        batch_process.sessions.append(ges.id)
        self.__invoker.services.graph_execution_manager.set(ges)
        self.__invoker.invoke(ges, invoke_all=True)

    def _next_batch_session(self, batch_process: BatchProcess) -> GraphExecutionState:
        """Build a session whose graph has each batched node replaced by its current item.

        The template graph and the selected batch item are deep-copied, so neither
        the caller's objects nor other sessions' graphs are mutated or shared.
        """
        graph = copy.deepcopy(batch_process.graph)
        batches = batch_process.batches
        for npath in nx.topological_sort(graph.nx_graph_flat()):
            node = graph.get_node(npath)
            # Exact id match: a substring test would also hit e.g. "10" for node "1".
            found = next(
                ((position, batch) for position, batch in enumerate(batches) if batch.node_id == node.id),
                None,
            )
            if found is None:
                continue
            position, batch = found
            datum = copy.deepcopy(batch.data[batch_process.batch_indices[position]])
            datum.id = node.id
            graph.update_node(npath, datum)
        return GraphExecutionState(graph=graph)

    def _next_batch_index(self, batch_process: BatchProcess) -> list[int]:
        """Return the indices for the next combination, odometer style.

        The first non-zero index is decremented and every index before it (all of
        which are zero) is reset to its batch's last item, so all combinations are
        visited. All-zero indices are returned unchanged (the terminal state).
        """
        indices = batch_process.batch_indices.copy()
        for position, value in enumerate(indices):
            if value > 0:
                indices[position] = value - 1
                for lower in range(position):
                    indices[lower] = len(batch_process.batches[lower].data) - 1
                break
        return indices

    def run_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcess:
        """Create a batch process for ``graph`` and ``batches`` and launch its first session.

        Raises:
            ValueError: if any batch has no data. Its index would start at -1, and
                the process could never reach the all-zero finished state.
        """
        empty = [batch.node_id for batch in batches if not batch.data]
        if empty:
            raise ValueError(f"Batches for node(s) {empty} have no data")

        batch_process = BatchProcess(
            batches=batches,
            batch_indices=[len(batch.data) - 1 for batch in batches],
            graph=graph,
        )
        ges = self._next_batch_session(batch_process)
        # Register before invoking so the session's completion event finds its owner.
        self.__batches.append(batch_process)
        self._launch_session(batch_process, ges)
        return batch_process

    def cancel_batch_process(self, batch_process_id: str):
        """Stop tracking the batch process whose ``batch_id`` is ``batch_process_id``.

        Sessions already started are not cancelled here. Their later events are
        ignored by ``process`` because no tracked process owns them any more.
        """
        self.__batches = [bp for bp in self.__batches if bp.batch_id != batch_process_id]