mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
161 lines
5.1 KiB
Python
161 lines
5.1 KiB
Python
|
|
||
|
import networkx as nx
|
||
|
import uuid
|
||
|
import copy
|
||
|
|
||
|
from abc import ABC, abstractmethod
|
||
|
from pydantic import BaseModel, Field
|
||
|
from fastapi_events.handlers.local import local_handler
|
||
|
from fastapi_events.typing import Event
|
||
|
from typing import (
|
||
|
Optional,
|
||
|
Union,
|
||
|
)
|
||
|
|
||
|
from invokeai.app.invocations.baseinvocation import (
|
||
|
BaseInvocation,
|
||
|
)
|
||
|
from invokeai.app.services.events import EventServiceBase
|
||
|
from invokeai.app.services.graph import Graph, GraphExecutionState
|
||
|
from invokeai.app.services.invoker import Invoker
|
||
|
|
||
|
|
||
|
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
|
||
|
class Batch(BaseModel):
|
||
|
data: list[InvocationsUnion] = Field(description="Mapping of ")
|
||
|
node_id: str = Field(description="ID of the node to batch")
|
||
|
|
||
|
|
||
|
class BatchProcess(BaseModel):
|
||
|
batch_id: Optional[str] = Field(default_factory=uuid.uuid4().__str__, description="Identifier for this batch")
|
||
|
sessions: list[str] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
|
||
|
batches: list[Batch] = Field(
|
||
|
description="List of batch configs to apply to this session",
|
||
|
default_factory=list,
|
||
|
)
|
||
|
batch_indices: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
|
||
|
graph: Graph = Field(description="The graph being executed")
|
||
|
|
||
|
|
||
|
class BatchManagerBase(ABC):
|
||
|
@abstractmethod
|
||
|
def start(
|
||
|
self,
|
||
|
invoker: Invoker
|
||
|
):
|
||
|
pass
|
||
|
|
||
|
@abstractmethod
|
||
|
def run_batch_process(
|
||
|
self,
|
||
|
batches: list[Batch],
|
||
|
graph: Graph
|
||
|
) -> BatchProcess:
|
||
|
pass
|
||
|
|
||
|
@abstractmethod
|
||
|
def cancel_batch_process(
|
||
|
self,
|
||
|
batch_process_id: str
|
||
|
):
|
||
|
pass
|
||
|
|
||
|
|
||
|
class BatchManager(BatchManagerBase):
|
||
|
"""Responsible for managing currently running and scheduled batch jobs"""
|
||
|
__invoker: Invoker
|
||
|
__batches: list[BatchProcess]
|
||
|
|
||
|
|
||
|
def start(self, invoker) -> None:
|
||
|
# if we do want multithreading at some point, we could make this configurable
|
||
|
self.__invoker = invoker
|
||
|
self.__batches = list()
|
||
|
local_handler.register(
|
||
|
event_name=EventServiceBase.session_event, _func=self.on_event
|
||
|
)
|
||
|
|
||
|
async def on_event(self, event: Event):
|
||
|
event_name = event[1]["event"]
|
||
|
|
||
|
match event_name:
|
||
|
case "graph_execution_state_complete":
|
||
|
await self.process(event)
|
||
|
case "invocation_error":
|
||
|
await self.process(event)
|
||
|
|
||
|
return event
|
||
|
|
||
|
async def process(self, event: Event):
|
||
|
data = event[1]["data"]
|
||
|
batchTarget = None
|
||
|
for batch in self.__batches:
|
||
|
if data['graph_execution_state_id'] in batch.sessions:
|
||
|
batchTarget = batch
|
||
|
break
|
||
|
|
||
|
if batchTarget == None:
|
||
|
return
|
||
|
|
||
|
if sum(batchTarget.batch_indices) == 0:
|
||
|
self.__batches = [batch for batch in self.__batches if batch != batchTarget]
|
||
|
return
|
||
|
|
||
|
batchTarget.batch_indices = self._next_batch_index(batchTarget)
|
||
|
ges = self._next_batch_session(batchTarget)
|
||
|
batchTarget.sessions.append(ges.id)
|
||
|
self.__invoker.services.graph_execution_manager.set(ges)
|
||
|
self.__invoker.invoke(ges, invoke_all=True)
|
||
|
|
||
|
def _next_batch_session(self, batch_process: BatchProcess) -> GraphExecutionState:
|
||
|
graph = copy.deepcopy(batch_process.graph)
|
||
|
batches = batch_process.batches
|
||
|
g = graph.nx_graph_flat()
|
||
|
sorted_nodes = nx.topological_sort(g)
|
||
|
for npath in sorted_nodes:
|
||
|
node = graph.get_node(npath)
|
||
|
(index, batch) = next(((i,b) for i,b in enumerate(batches) if b.node_id in node.id), (None, None))
|
||
|
if batch:
|
||
|
batch_index = batch_process.batch_indices[index]
|
||
|
datum = batch.data[batch_index]
|
||
|
datum.id = node.id
|
||
|
graph.update_node(npath, datum)
|
||
|
|
||
|
return GraphExecutionState(graph=graph)
|
||
|
|
||
|
|
||
|
def _next_batch_index(self, batch_process: BatchProcess):
|
||
|
batch_indicies = batch_process.batch_indices.copy()
|
||
|
for index in range(len(batch_indicies)):
|
||
|
if batch_indicies[index] > 0:
|
||
|
batch_indicies[index] -= 1
|
||
|
break
|
||
|
return batch_indicies
|
||
|
|
||
|
|
||
|
def run_batch_process(
|
||
|
self,
|
||
|
batches: list[Batch],
|
||
|
graph: Graph
|
||
|
) -> BatchProcess:
|
||
|
batch_indices = list()
|
||
|
for batch in batches:
|
||
|
batch_indices.append(len(batch.data)-1)
|
||
|
batch_process = BatchProcess(
|
||
|
batches = batches,
|
||
|
batch_indices = batch_indices,
|
||
|
graph = graph,
|
||
|
)
|
||
|
ges = self._next_batch_session(batch_process)
|
||
|
batch_process.sessions.append(ges.id)
|
||
|
self.__batches.append(batch_process)
|
||
|
self.__invoker.services.graph_execution_manager.set(ges)
|
||
|
self.__invoker.invoke(ges, invoke_all=True)
|
||
|
return batch_process
|
||
|
|
||
|
def cancel_batch_process(
|
||
|
self,
|
||
|
batch_process_id: str
|
||
|
):
|
||
|
self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]
|