From 55b921818de2711b4a7e742d1d6b40e6ba26c7c2 Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Mon, 31 Jul 2023 15:45:35 -0400
Subject: [PATCH] Create batch manager
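
Moves batch handling out of Graph and GraphExecutionState into a new
BatchManager service. Clients submit a batch through a dedicated
"/batch" route on the sessions router; the manager deep-copies the
graph, substitutes the configured node values for the current batch
indices, runs one session per combination, and schedules the next
session whenever a session completes or errors. The route returns the
created BatchProcess.

Illustrative request, not part of this change. It assumes a local
server at 127.0.0.1:9090, the sessions router mounted at
/api/v1/sessions, and a graph containing a "noise" node; the ids,
seeds, and port below are placeholders:

    import requests  # any HTTP client works; requests is illustrative

    # A one-node graph whose "noise" node is swapped out per run.
    graph = {
        "id": "g1",
        "nodes": {"noise": {"id": "noise", "type": "noise", "seed": 0}},
        "edges": [],
    }

    # Each Batch names a node_id plus the node values to substitute,
    # one per run (see the Batch model in batch_manager.py).
    batches = [
        {
            "node_id": "noise",
            "data": [
                {"id": "noise", "type": "noise", "seed": 1},
                {"id": "noise", "type": "noise", "seed": 2},
            ],
        }
    ]

    resp = requests.post(
        "http://127.0.0.1:9090/api/v1/sessions/batch",
        json={"graph": graph, "batches": batches},
    )
    print(resp.json()["batch_id"])  # identifier of the new batch process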
---
 invokeai/app/api/dependencies.py             |   4 +
 invokeai/app/api/routers/sessions.py         |  23 ++-
 invokeai/app/services/batch_manager.py       | 160 +++++++++++++++++++
 invokeai/app/services/graph.py               |  32 +---
 invokeai/app/services/invocation_services.py |   4 +
 invokeai/app/services/invoker.py             |   5 +-
 invokeai/app/services/processor.py           |  11 +-
 7 files changed, 193 insertions(+), 46 deletions(-)
 create mode 100644 invokeai/app/services/batch_manager.py

diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py
index a186daedf5..75a0cc55e9 100644
--- a/invokeai/app/api/dependencies.py
+++ b/invokeai/app/api/dependencies.py
@@ -29,6 +29,7 @@ from ..services.invoker import Invoker
 from ..services.processor import DefaultInvocationProcessor
 from ..services.sqlite import SqliteItemStorage
 from ..services.model_manager_service import ModelManagerService
+from ..services.batch_manager import BatchManager
 from .events import FastAPIEventService


@@ -115,11 +116,14 @@ class ApiDependencies:
             )
         )

+        batch_manager = BatchManager()
+
         services = InvocationServices(
             model_manager=ModelManagerService(config, logger),
             events=events,
             latents=latents,
             images=images,
+            batch_manager=batch_manager,
             boards=boards,
             board_images=board_images,
             queue=MemoryInvocationQueue(),
diff --git a/invokeai/app/api/routers/sessions.py b/invokeai/app/api/routers/sessions.py
index fba2e6d596..6b72d6b83b 100644
--- a/invokeai/app/api/routers/sessions.py
+++ b/invokeai/app/api/routers/sessions.py
@@ -15,6 +15,7 @@ from ...services.graph import (
     GraphExecutionState,
     NodeAlreadyExecutedError,
 )
+from ...services.batch_manager import Batch, BatchProcess
 from ...services.item_storage import PaginatedResults
 from ..dependencies import ApiDependencies

@@ -33,12 +34,24 @@ async def create_session(
     graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with")
 ) -> GraphExecutionState:
     """Creates a new session, optionally initializing it with an invocation graph"""
+    session = ApiDependencies.invoker.create_execution_state(graph)
+    return session
-    batch_indices = list()
-    if graph.batches:
-        for batch in graph.batches:
-            batch_indices.append(len(batch.data)-1)
-    session = ApiDependencies.invoker.create_execution_state(graph, batch_indices)
+
+
+@session_router.post(
+    "/batch",
+    operation_id="create_batch",
+    responses={
+        200: {"model": BatchProcess},
+        400: {"description": "Invalid json"},
+    },
+)
+async def create_batch(
+    graph: Optional[Graph] = Body(default=None, description="The graph to execute the batches on"),
+    batches: list[Batch] = Body(description="Batch configs to apply to the given graph"),
+) -> BatchProcess:
+    """Creates a batch process that applies each batch config to the given graph"""
+    session = ApiDependencies.invoker.services.batch_manager.run_batch_process(batches, graph)
     return session
diff --git a/invokeai/app/services/batch_manager.py b/invokeai/app/services/batch_manager.py
new file mode 100644
index 0000000000..2607964a27
--- /dev/null
+++ b/invokeai/app/services/batch_manager.py
@@ -0,0 +1,160 @@
+
+import networkx as nx
+import uuid
+import copy
+
+from abc import ABC, abstractmethod
+from pydantic import BaseModel, Field
+from fastapi_events.handlers.local import local_handler
+from fastapi_events.typing import Event
+from typing import (
+    Optional,
+    Union,
+)
+
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+)
+from invokeai.app.services.events import EventServiceBase
+from invokeai.app.services.graph import Graph, GraphExecutionState
+from invokeai.app.services.invoker import Invoker
+
+
+InvocationsUnion = Union[BaseInvocation.get_invocations()]  # type: ignore
+
+
+class Batch(BaseModel):
+    data: list[InvocationsUnion] = Field(description="List of invocation values to substitute into the target node, one per batch run")
+    node_id: str = Field(description="ID of the node to batch")
+
+
+class BatchProcess(BaseModel):
+    batch_id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4()), description="Identifier for this batch")
+    sessions: list[str] = Field(description="IDs of the sessions created by this batch process", default_factory=list)
+    batches: list[Batch] = Field(
+        description="List of batch configs to apply to this session",
+        default_factory=list,
+    )
+    batch_indices: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
+    graph: Graph = Field(description="The graph being executed")
+
+
+class BatchManagerBase(ABC):
+    @abstractmethod
+    def start(
+        self,
+        invoker: Invoker
+    ):
+        pass
+
+    @abstractmethod
+    def run_batch_process(
+        self,
+        batches: list[Batch],
+        graph: Graph
+    ) -> BatchProcess:
+        pass
+
+    @abstractmethod
+    def cancel_batch_process(
+        self,
+        batch_process_id: str
+    ):
+        pass
+
+
+class BatchManager(BatchManagerBase):
+    """Responsible for managing currently running and scheduled batch jobs"""
+
+    __invoker: Invoker
+    __batches: list[BatchProcess]
+
+    def start(self, invoker: Invoker) -> None:
+        # if we do want multithreading at some point, we could make this configurable
+        self.__invoker = invoker
+        self.__batches = list()
+        local_handler.register(
+            event_name=EventServiceBase.session_event, _func=self.on_event
+        )
+
+    async def on_event(self, event: Event):
+        event_name = event[1]["event"]
+
+        match event_name:
+            case "graph_execution_state_complete":
+                await self.process(event)
+            case "invocation_error":
+                await self.process(event)
+
+        return event
+
+    async def process(self, event: Event):
+        data = event[1]["data"]
+        batch_target = None
+        for batch in self.__batches:
+            if data["graph_execution_state_id"] in batch.sessions:
+                batch_target = batch
+                break
+
+        if batch_target is None:
+            return
+
+        if sum(batch_target.batch_indices) == 0:
+            self.__batches = [batch for batch in self.__batches if batch != batch_target]
+            return
+
+        batch_target.batch_indices = self._next_batch_index(batch_target)
+        ges = self._next_batch_session(batch_target)
+        batch_target.sessions.append(ges.id)
+        self.__invoker.services.graph_execution_manager.set(ges)
+        self.__invoker.invoke(ges, invoke_all=True)
+
+    def _next_batch_session(self, batch_process: BatchProcess) -> GraphExecutionState:
+        graph = copy.deepcopy(batch_process.graph)
+        batches = batch_process.batches
+        g = graph.nx_graph_flat()
+        sorted_nodes = nx.topological_sort(g)
+        for npath in sorted_nodes:
+            node = graph.get_node(npath)
+            (index, batch) = next(((i, b) for i, b in enumerate(batches) if b.node_id in node.id), (None, None))
+            if batch:
+                batch_index = batch_process.batch_indices[index]
+                datum = batch.data[batch_index]
+                datum.id = node.id
+                graph.update_node(npath, datum)
+
+        return GraphExecutionState(graph=graph)
+
+    def _next_batch_index(self, batch_process: BatchProcess):
+        batch_indices = batch_process.batch_indices.copy()
+        for index in range(len(batch_indices)):
+            if batch_indices[index] > 0:
+                batch_indices[index] -= 1
+                break
+        return batch_indices
+
+
+    def run_batch_process(
+        self,
+        batches: list[Batch],
+        graph: Graph
+    ) -> BatchProcess:
+        batch_indices = list()
+        for batch in batches:
+            batch_indices.append(len(batch.data) - 1)
+        batch_process = BatchProcess(
+            batches=batches,
+            batch_indices=batch_indices,
+            graph=graph,
+        )
+        ges = self._next_batch_session(batch_process)
+        batch_process.sessions.append(ges.id)
+        self.__batches.append(batch_process)
+        self.__invoker.services.graph_execution_manager.set(ges)
+        self.__invoker.invoke(ges, invoke_all=True)
+        return batch_process
+
+    def cancel_batch_process(
+        self,
+        batch_process_id: str
+    ):
+        self.__batches = [batch for batch in self.__batches if batch.batch_id != batch_process_id]
diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py
index aa23f39756..d7f021df14 100644
--- a/invokeai/app/services/graph.py
+++ b/invokeai/app/services/graph.py
@@ -245,13 +245,6 @@ InvocationsUnion = Union[BaseInvocation.get_invocations()]  # type: ignore

 InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]  # type: ignore

-
-class Batch(BaseModel):
-    batch_id: Optional[str] = Field(default_factory=uuid.uuid4().__str__, description="Identifier for this batch")
-    data: list[InvocationsUnion] = Field(description="Mapping of ")
-    node_id: str = Field(description="ID of the node to batch")
-
-
 class Graph(BaseModel):
     id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
     # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
@@ -262,16 +255,13 @@ class Graph(BaseModel):
         description="The connections between nodes and their fields in this graph",
         default_factory=list,
     )
-    batches: list[Batch] = Field(
-        description="List of batch configs to apply to this session",
-        default_factory=list,
-    )

     def add_node(self, node: BaseInvocation) -> None:
         """Adds a node to a graph

         :raises NodeAlreadyInGraphError: the node is already present in the graph.
         """
+
         if node.id in self.nodes:
             raise NodeAlreadyInGraphError()
@@ -744,8 +734,6 @@ class GraphExecutionState(BaseModel):
         default_factory=list,
     )

-    batch_indices: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
-
     # The results of executed nodes
     results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(
         description="The results of node executions", default_factory=dict
@@ -787,7 +775,6 @@ class GraphExecutionState(BaseModel):

         # TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
         # possibly with a timeout?
-        self._apply_batch_config()
         # If there are no prepared nodes, prepare some nodes
         next_node = self._get_next_node()
         if next_node is None:
@@ -879,7 +866,7 @@ class GraphExecutionState(BaseModel):
         new_node = copy.deepcopy(node)

         # Create the node id (use a random uuid)
-        new_node.id = str(f"{uuid.uuid4()}-{node.id}")
+        new_node.id = str(uuid.uuid4())

         # Set the iteration index for iteration invocations
         if isinstance(new_node, IterateInvocation):
@@ -918,20 +905,6 @@ class GraphExecutionState(BaseModel):
         iterators = [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.get_node(n), IterateInvocation)]
         return iterators

-    def _apply_batch_config(self):
-        g = self.graph.nx_graph_flat()
-        sorted_nodes = nx.topological_sort(g)
-        batchable_nodes = [n for n in sorted_nodes if n not in self.executed]
-        for npath in batchable_nodes:
-            node = self.graph.get_node(npath)
-            (index, batch) = next(((i,b) for i,b in enumerate(self.graph.batches) if b.node_id in node.id), (None, None))
-            if batch:
-                batch_index = self.batch_indices[index]
-                datum = batch.data[batch_index]
-                datum.id = node.id
-                self.graph.update_node(npath, datum)
-
-
     def _prepare(self) -> Optional[str]:
         # Get flattened source graph
         g = self.graph.nx_graph_flat()
@@ -963,6 +936,7 @@ class GraphExecutionState(BaseModel):
             ),
             None,
         )
+
         if next_node_id == None:
             return None
diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py
index 8af17c7643..839df11469 100644
--- a/invokeai/app/services/invocation_services.py
+++ b/invokeai/app/services/invocation_services.py
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING

 if TYPE_CHECKING:
     from logging import Logger
+    from invokeai.app.services.batch_manager import BatchManagerBase
     from invokeai.app.services.board_images import BoardImagesServiceABC
     from invokeai.app.services.boards import BoardServiceABC
     from invokeai.app.services.images import ImageServiceABC
@@ -21,6 +22,7 @@ class InvocationServices:
     """Services that can be used by invocations"""

     # TODO: Just forward-declared everything due to circular dependencies. Fix structure.
+    batch_manager: "BatchManagerBase"
     board_images: "BoardImagesServiceABC"
     boards: "BoardServiceABC"
     configuration: "InvokeAIAppConfig"
@@ -36,6 +38,7 @@ class InvocationServices:

     def __init__(
         self,
+        batch_manager: "BatchManagerBase",
         board_images: "BoardImagesServiceABC",
         boards: "BoardServiceABC",
         configuration: "InvokeAIAppConfig",
@@ -49,6 +52,7 @@ class InvocationServices:
         processor: "InvocationProcessorABC",
         queue: "InvocationQueueABC",
     ):
+        self.batch_manager = batch_manager
         self.board_images = board_images
         self.boards = boards
         self.boards = boards
diff --git a/invokeai/app/services/invoker.py b/invokeai/app/services/invoker.py
index d9ca7e6410..1a7b0de27e 100644
--- a/invokeai/app/services/invoker.py
+++ b/invokeai/app/services/invoker.py
@@ -20,6 +20,7 @@ class Invoker:
     def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> Optional[str]:
         """Determines the next node to invoke and enqueues it, preparing if needed.
         Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
+
         # Get the next invocation
         invocation = graph_execution_state.next()
         if not invocation:
@@ -40,9 +41,9 @@ class Invoker:

         return invocation.id

-    def create_execution_state(self, graph: Optional[Graph] = None, batch_indices: list[int] = list()) -> GraphExecutionState:
+    def create_execution_state(self, graph: Optional[Graph] = None) -> GraphExecutionState:
         """Creates a new execution state for the given graph"""
-        new_state = GraphExecutionState(graph=Graph() if graph is None else graph, batch_indices=batch_indices)
+        new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
         self.services.graph_execution_manager.set(new_state)
         return new_state
diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py
index 116bf94f15..50fe217e05 100644
--- a/invokeai/app/services/processor.py
+++ b/invokeai/app/services/processor.py
@@ -6,7 +6,6 @@ from ..invocations.baseinvocation import InvocationContext
 from .invocation_queue import InvocationQueueItem
 from .invoker import InvocationProcessorABC, Invoker
 from ..models.exceptions import CanceledException
-from .graph import GraphExecutionState

 import invokeai.backend.util.logging as logger

@@ -71,6 +70,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                     error=traceback.format_exc(),
                 )
                 continue
+
             # get the source node id to provide to clients (the prepared node id is not as useful)
             source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]

@@ -154,15 +154,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                         error_type=e.__class__.__name__,
                         error=traceback.format_exc(),
                     )
-                elif queue_item.invoke_all and sum(graph_execution_state.batch_indices) > 0:
-                    batch_indicies = graph_execution_state.batch_indices.copy()
-                    for index in range(len(batch_indicies)):
-                        if batch_indicies[index] > 0:
-                            batch_indicies[index] -= 1
-                            break
-                    new_ges = GraphExecutionState(graph=graph_execution_state.graph, batch_indices=batch_indicies)
-                    self.__invoker.services.graph_execution_manager.set(new_ges)
-                    self.__invoker.invoke(new_ges, invoke_all=True)
                 elif is_complete:
                     self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
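
Note for reviewers, not part of the patch: the sketch below mirrors
BatchManager._next_batch_index and the stop condition in process() to
show which index combinations a BatchProcess visits. The first
non-zero index is decremented after each session, so the batches are
walked sequentially rather than as a full cross product; data lists of
sizes 3 and 2 yield four sessions, not six.

    def walk(batch_sizes: list[int]) -> list[list[int]]:
        # Start at the last datum of every batch, as run_batch_process does.
        indices = [size - 1 for size in batch_sizes]
        visited = [indices.copy()]
        # process() stops scheduling once every index has reached zero.
        while sum(indices) > 0:
            for i in range(len(indices)):
                if indices[i] > 0:
                    indices[i] -= 1
                    break
            visited.append(indices.copy())
        return visited

    print(walk([3, 2]))  # [[2, 1], [1, 1], [0, 1], [0, 0]]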