Create batch manager

Brandon Rising 2023-07-31 15:45:35 -04:00
parent bb681a8a11
commit 55b921818d
7 changed files with 193 additions and 46 deletions

View File: invokeai/app/api/dependencies.py

@@ -29,6 +29,7 @@ from ..services.invoker import Invoker
 from ..services.processor import DefaultInvocationProcessor
 from ..services.sqlite import SqliteItemStorage
 from ..services.model_manager_service import ModelManagerService
+from ..services.batch_manager import BatchManager
 from .events import FastAPIEventService
@@ -115,11 +116,14 @@ class ApiDependencies:
             )
         )

+        batch_manager = BatchManager()
+
         services = InvocationServices(
             model_manager=ModelManagerService(config, logger),
             events=events,
             latents=latents,
             images=images,
+            batch_manager=batch_manager,
             boards=boards,
             board_images=board_images,
             queue=MemoryInvocationQueue(),
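
With the manager constructed in ApiDependencies and handed to InvocationServices, any code that can reach the invoker can reach the batch manager through its services, which is exactly how the /batch route below uses it. A minimal sketch of that access path (the id string is hypothetical):

    # after ApiDependencies.initialize() has run at app startup
    batch_manager = ApiDependencies.invoker.services.batch_manager
    batch_manager.cancel_batch_process("batch-id-to-cancel")  # hypothetical id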

View File: invokeai/app/api/routers/sessions.py

@@ -15,6 +15,7 @@ from ...services.graph import (
     GraphExecutionState,
     NodeAlreadyExecutedError,
 )
+from ...services.batch_manager import Batch
 from ...services.item_storage import PaginatedResults
 from ..dependencies import ApiDependencies
@@ -33,12 +34,24 @@ async def create_session(
     graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with")
 ) -> GraphExecutionState:
     """Creates a new session, optionally initializing it with an invocation graph"""
-    batch_indices = list()
-    if graph.batches:
-        for batch in graph.batches:
-            batch_indices.append(len(batch.data)-1)
-    session = ApiDependencies.invoker.create_execution_state(graph, batch_indices)
+    session = ApiDependencies.invoker.create_execution_state(graph)
     return session


+@session_router.post(
+    "/batch",
+    operation_id="create_batch",
+    responses={
+        200: {"model": GraphExecutionState},
+        400: {"description": "Invalid json"},
+    },
+)
+async def create_batch(
+    graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
+    batches: list[Batch] = Body(description="Batch config to apply to the given graph")
+) -> GraphExecutionState:
+    """Creates and starts a new batch process, applying the given batch config to the given graph"""
+    session = ApiDependencies.invoker.services.batch_manager.run_batch_process(batches, graph)
+    return session
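
For orientation, a client call against the new endpoint might look like the sketch below. This is a hedged example: it assumes the session router is mounted under /api/v1/sessions on the usual local port, and the graph, node id, and invocation payloads are made up; each entry in data must be a serialized invocation compatible with the node being swapped.

    import requests

    # hypothetical serialized invocations for the node being varied per session
    invocation_a = {"id": "prompt_node", "type": "compel", "prompt": "a red house"}
    invocation_b = {"id": "prompt_node", "type": "compel", "prompt": "a blue house"}

    payload = {
        "graph": {"id": "g1", "nodes": {}, "edges": []},  # hypothetical minimal graph
        "batches": [
            {"node_id": "prompt_node", "data": [invocation_a, invocation_b]},
        ],
    }
    resp = requests.post("http://localhost:9090/api/v1/sessions/batch", json=payload)
    print(resp.json())

Note that the handler appears to return the BatchProcess from run_batch_process even though the declared response model is GraphExecutionState.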

View File: invokeai/app/services/batch_manager.py

@@ -0,0 +1,160 @@
+import networkx as nx
+import uuid
+import copy
+from abc import ABC, abstractmethod
+from pydantic import BaseModel, Field
+from fastapi_events.handlers.local import local_handler
+from fastapi_events.typing import Event
+from typing import (
+    Optional,
+    Union,
+)
+
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+)
+from invokeai.app.services.events import EventServiceBase
+from invokeai.app.services.graph import Graph, GraphExecutionState
+from invokeai.app.services.invoker import Invoker
+
+InvocationsUnion = Union[BaseInvocation.get_invocations()]  # type: ignore
+
+
+class Batch(BaseModel):
+    data: list[InvocationsUnion] = Field(description="Mapping of ")
+    node_id: str = Field(description="ID of the node to batch")
+
+
+class BatchProcess(BaseModel):
+    batch_id: Optional[str] = Field(default_factory=uuid.uuid4().__str__, description="Identifier for this batch")
+    sessions: list[str] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
+    batches: list[Batch] = Field(
+        description="List of batch configs to apply to this session",
+        default_factory=list,
+    )
+    batch_indices: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
+    graph: Graph = Field(description="The graph being executed")
+
+
+class BatchManagerBase(ABC):
+    @abstractmethod
+    def start(
+        self,
+        invoker: Invoker
+    ):
+        pass
+
+    @abstractmethod
+    def run_batch_process(
+        self,
+        batches: list[Batch],
+        graph: Graph
+    ) -> BatchProcess:
+        pass
+
+    @abstractmethod
+    def cancel_batch_process(
+        self,
+        batch_process_id: str
+    ):
+        pass
+
+
+class BatchManager(BatchManagerBase):
+    """Responsible for managing currently running and scheduled batch jobs"""
+
+    __invoker: Invoker
+    __batches: list[BatchProcess]
+
+    def start(self, invoker: Invoker) -> None:
+        # if we do want multithreading at some point, we could make this configurable
+        self.__invoker = invoker
+        self.__batches = list()
+        local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event)
+
+    async def on_event(self, event: Event):
+        event_name = event[1]["event"]
+
+        match event_name:
+            case "graph_execution_state_complete":
+                await self.process(event)
+            case "invocation_error":
+                await self.process(event)
+        return event
+
+    async def process(self, event: Event):
+        data = event[1]["data"]
+        # find the batch process that owns the session this event refers to
+        batchTarget = None
+        for batch in self.__batches:
+            if data["graph_execution_state_id"] in batch.sessions:
+                batchTarget = batch
+                break
+
+        if batchTarget is None:
+            return
+
+        if sum(batchTarget.batch_indices) == 0:
+            # all indices exhausted: the batch process is done, stop tracking it
+            self.__batches = [batch for batch in self.__batches if batch != batchTarget]
+            return
+
+        batchTarget.batch_indices = self._next_batch_index(batchTarget)
+        ges = self._next_batch_session(batchTarget)
+        batchTarget.sessions.append(ges.id)
+        self.__invoker.services.graph_execution_manager.set(ges)
+        self.__invoker.invoke(ges, invoke_all=True)
+
+    def _next_batch_session(self, batch_process: BatchProcess) -> GraphExecutionState:
+        # copy the template graph and swap in the batch datum for each batched node
+        graph = copy.deepcopy(batch_process.graph)
+        batches = batch_process.batches
+        g = graph.nx_graph_flat()
+        sorted_nodes = nx.topological_sort(g)
+        for npath in sorted_nodes:
+            node = graph.get_node(npath)
+            (index, batch) = next(((i, b) for i, b in enumerate(batches) if b.node_id in node.id), (None, None))
+            if batch:
+                batch_index = batch_process.batch_indices[index]
+                datum = batch.data[batch_index]
+                datum.id = node.id
+                graph.update_node(npath, datum)
+        return GraphExecutionState(graph=graph)
+
+    def _next_batch_index(self, batch_process: BatchProcess):
+        # decrement the first nonzero index, leaving the rest untouched
+        batch_indices = batch_process.batch_indices.copy()
+        for index in range(len(batch_indices)):
+            if batch_indices[index] > 0:
+                batch_indices[index] -= 1
+                break
+        return batch_indices
+
+    def run_batch_process(
+        self,
+        batches: list[Batch],
+        graph: Graph
+    ) -> BatchProcess:
+        # start every batch at its highest index and count down from there
+        batch_indices = list()
+        for batch in batches:
+            batch_indices.append(len(batch.data) - 1)
+        batch_process = BatchProcess(
+            batches=batches,
+            batch_indices=batch_indices,
+            graph=graph,
+        )
+        ges = self._next_batch_session(batch_process)
+        batch_process.sessions.append(ges.id)
+        self.__batches.append(batch_process)
+        self.__invoker.services.graph_execution_manager.set(ges)
+        self.__invoker.invoke(ges, invoke_all=True)
+        return batch_process
+
+    def cancel_batch_process(
+        self,
+        batch_process_id: str
+    ):
+        self.__batches = [batch for batch in self.__batches if batch.batch_id != batch_process_id]
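
The index handling above is easy to misread: _next_batch_index decrements only the first nonzero entry, so a process runs sum(indices) + 1 sessions in total rather than the full cross product of batch items. A standalone mirror of that countdown, using plain lists so it runs without InvokeAI installed (all names local to this sketch):

    def next_batch_index(indices: list[int]) -> list[int]:
        # mirrors BatchManager._next_batch_index: decrement the first nonzero entry
        out = indices.copy()
        for i in range(len(out)):
            if out[i] > 0:
                out[i] -= 1
                break
        return out

    indices = [2, 1]  # e.g. one batch of 3 items and one of 2, as set by run_batch_process
    schedule = [indices]
    while sum(indices) > 0:
        indices = next_batch_index(indices)
        schedule.append(indices)

    print(schedule)  # [[2, 1], [1, 1], [0, 1], [0, 0]] -> 4 sessions, not 3 * 2 = 6

Each session then takes batch.data[index] for every batch via _next_batch_session, so only those four index combinations are ever rendered.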

View File: invokeai/app/services/graph.py

@@ -245,13 +245,6 @@ InvocationsUnion = Union[BaseInvocation.get_invocations()]  # type: ignore
 InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]  # type: ignore

-class Batch(BaseModel):
-    batch_id: Optional[str] = Field(default_factory=uuid.uuid4().__str__, description="Identifier for this batch")
-    data: list[InvocationsUnion] = Field(description="Mapping of ")
-    node_id: str = Field(description="ID of the node to batch")
-
 class Graph(BaseModel):
     id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
     # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
@@ -262,16 +255,13 @@ class Graph(BaseModel):
         description="The connections between nodes and their fields in this graph",
         default_factory=list,
     )
-    batches: list[Batch] = Field(
-        description="List of batch configs to apply to this session",
-        default_factory=list,
-    )

     def add_node(self, node: BaseInvocation) -> None:
         """Adds a node to a graph

         :raises NodeAlreadyInGraphError: the node is already present in the graph.
         """
         if node.id in self.nodes:
             raise NodeAlreadyInGraphError()
@@ -744,8 +734,6 @@ class GraphExecutionState(BaseModel):
         default_factory=list,
     )
-    batch_indices: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list)

     # The results of executed nodes
     results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(
         description="The results of node executions", default_factory=dict
@@ -787,7 +775,6 @@ class GraphExecutionState(BaseModel):
         # TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
         # possibly with a timeout?
-        self._apply_batch_config()

         # If there are no prepared nodes, prepare some nodes
         next_node = self._get_next_node()
         if next_node is None:
@@ -879,7 +866,7 @@ class GraphExecutionState(BaseModel):
         new_node = copy.deepcopy(node)

         # Create the node id (use a random uuid)
-        new_node.id = str(f"{uuid.uuid4()}-{node.id}")
+        new_node.id = str(uuid.uuid4())

         # Set the iteration index for iteration invocations
         if isinstance(new_node, IterateInvocation):
@@ -918,20 +905,6 @@ class GraphExecutionState(BaseModel):
         iterators = [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.get_node(n), IterateInvocation)]
         return iterators

-    def _apply_batch_config(self):
-        g = self.graph.nx_graph_flat()
-        sorted_nodes = nx.topological_sort(g)
-        batchable_nodes = [n for n in sorted_nodes if n not in self.executed]
-        for npath in batchable_nodes:
-            node = self.graph.get_node(npath)
-            (index, batch) = next(((i, b) for i, b in enumerate(self.graph.batches) if b.node_id in node.id), (None, None))
-            if batch:
-                batch_index = self.batch_indices[index]
-                datum = batch.data[batch_index]
-                datum.id = node.id
-                self.graph.update_node(npath, datum)
-
     def _prepare(self) -> Optional[str]:
         # Get flattened source graph
         g = self.graph.nx_graph_flat()
@@ -963,6 +936,7 @@ class GraphExecutionState(BaseModel):
             ),
             None,
         )
+
         if next_node_id == None:
             return None

View File: invokeai/app/services/invocation_services.py

@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from logging import Logger
+    from invokeai.app.services.batch_manager import BatchManagerBase
     from invokeai.app.services.board_images import BoardImagesServiceABC
     from invokeai.app.services.boards import BoardServiceABC
     from invokeai.app.services.images import ImageServiceABC
@@ -21,6 +22,7 @@ class InvocationServices:
     """Services that can be used by invocations"""

     # TODO: Just forward-declared everything due to circular dependencies. Fix structure.
+    batch_manager: "BatchManagerBase"
     board_images: "BoardImagesServiceABC"
     boards: "BoardServiceABC"
     configuration: "InvokeAIAppConfig"
@@ -36,6 +38,7 @@ class InvocationServices:
     def __init__(
         self,
+        batch_manager: "BatchManagerBase",
         board_images: "BoardImagesServiceABC",
         boards: "BoardServiceABC",
         configuration: "InvokeAIAppConfig",
@@ -49,6 +52,7 @@ class InvocationServices:
         processor: "InvocationProcessorABC",
         queue: "InvocationQueueABC",
     ):
+        self.batch_manager = batch_manager
         self.board_images = board_images
         self.boards = boards
         self.boards = boards

View File: invokeai/app/services/invoker.py

@@ -20,6 +20,7 @@ class Invoker:
     def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> Optional[str]:
         """Determines the next node to invoke and enqueues it, preparing if needed.
         Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
+
         # Get the next invocation
         invocation = graph_execution_state.next()
         if not invocation:
@@ -40,9 +41,9 @@ class Invoker:
         return invocation.id

-    def create_execution_state(self, graph: Optional[Graph] = None, batch_indices: list[int] = list()) -> GraphExecutionState:
+    def create_execution_state(self, graph: Optional[Graph] = None) -> GraphExecutionState:
         """Creates a new execution state for the given graph"""
-        new_state = GraphExecutionState(graph=Graph() if graph is None else graph, batch_indices=batch_indices)
+        new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
         self.services.graph_execution_manager.set(new_state)
         return new_state
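
With the batch bookkeeping removed, starting a session is back to the two-step pattern that BatchManager itself uses in run_batch_process. A sketch, assuming an invoker that has already been started:

    state = invoker.create_execution_state(graph)  # graph may be None for an empty Graph
    invoker.invoke(state, invoke_all=True)  # enqueue nodes until the graph is exhausted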

View File: invokeai/app/services/processor.py

@@ -6,7 +6,6 @@ from ..invocations.baseinvocation import InvocationContext
 from .invocation_queue import InvocationQueueItem
 from .invoker import InvocationProcessorABC, Invoker
 from ..models.exceptions import CanceledException
-from .graph import GraphExecutionState

 import invokeai.backend.util.logging as logger
@@ -71,6 +70,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                         error=traceback.format_exc(),
                     )
                     continue
+
                 # get the source node id to provide to clients (the prepared node id is not as useful)
                 source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
@@ -154,15 +154,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                             error_type=e.__class__.__name__,
                             error=traceback.format_exc(),
                         )
-                elif queue_item.invoke_all and sum(graph_execution_state.batch_indices) > 0:
-                    batch_indicies = graph_execution_state.batch_indices.copy()
-                    for index in range(len(batch_indicies)):
-                        if batch_indicies[index] > 0:
-                            batch_indicies[index] -= 1
-                            break
-                    new_ges = GraphExecutionState(graph=graph_execution_state.graph, batch_indices=batch_indicies)
-                    self.__invoker.services.graph_execution_manager.set(new_ges)
-                    self.__invoker.invoke(new_ges, invoke_all=True)
                 elif is_complete:
                     self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
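
The follow-up scheduling that used to live in this elif branch now happens in BatchManager.process, driven by the session events the processor emits. Judging only from the fields on_event reads, the handler consumes (name, payload) tuples shaped roughly like this sketch (the id value is hypothetical):

    event = (
        "session_event",
        {
            "event": "graph_execution_state_complete",  # or "invocation_error"
            "data": {"graph_execution_state_id": "some-session-id"},
        },
    )
    # BatchManager.on_event reads event[1]["event"] to dispatch, and
    # event[1]["data"]["graph_execution_state_id"] to find the owning BatchProcess.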