mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Create batch manager
This commit is contained in:
parent
bb681a8a11
commit
55b921818d
@ -29,6 +29,7 @@ from ..services.invoker import Invoker
|
|||||||
from ..services.processor import DefaultInvocationProcessor
|
from ..services.processor import DefaultInvocationProcessor
|
||||||
from ..services.sqlite import SqliteItemStorage
|
from ..services.sqlite import SqliteItemStorage
|
||||||
from ..services.model_manager_service import ModelManagerService
|
from ..services.model_manager_service import ModelManagerService
|
||||||
|
from ..services.batch_manager import BatchManager
|
||||||
from .events import FastAPIEventService
|
from .events import FastAPIEventService
|
||||||
|
|
||||||
|
|
||||||
@ -115,11 +116,14 @@ class ApiDependencies:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
batch_manager = BatchManager()
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
model_manager=ModelManagerService(config, logger),
|
model_manager=ModelManagerService(config, logger),
|
||||||
events=events,
|
events=events,
|
||||||
latents=latents,
|
latents=latents,
|
||||||
images=images,
|
images=images,
|
||||||
|
batch_manager=batch_manager,
|
||||||
boards=boards,
|
boards=boards,
|
||||||
board_images=board_images,
|
board_images=board_images,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
|
@ -15,6 +15,7 @@ from ...services.graph import (
|
|||||||
GraphExecutionState,
|
GraphExecutionState,
|
||||||
NodeAlreadyExecutedError,
|
NodeAlreadyExecutedError,
|
||||||
)
|
)
|
||||||
|
from ...services.batch_manager import Batch
|
||||||
from ...services.item_storage import PaginatedResults
|
from ...services.item_storage import PaginatedResults
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
@ -33,12 +34,24 @@ async def create_session(
|
|||||||
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with")
|
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with")
|
||||||
) -> GraphExecutionState:
|
) -> GraphExecutionState:
|
||||||
"""Creates a new session, optionally initializing it with an invocation graph"""
|
"""Creates a new session, optionally initializing it with an invocation graph"""
|
||||||
|
session = ApiDependencies.invoker.create_execution_state(graph)
|
||||||
|
return session
|
||||||
|
|
||||||
batch_indices = list()
|
|
||||||
if graph.batches:
|
@session_router.post(
|
||||||
for batch in graph.batches:
|
"/batch",
|
||||||
batch_indices.append(len(batch.data)-1)
|
operation_id="create_batch",
|
||||||
session = ApiDependencies.invoker.create_execution_state(graph, batch_indices)
|
responses={
|
||||||
|
200: {"model": GraphExecutionState},
|
||||||
|
400: {"description": "Invalid json"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def create_batch(
|
||||||
|
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
|
||||||
|
batches: list[Batch] = Body(description="Batch config to apply to the given graph")
|
||||||
|
) -> GraphExecutionState:
|
||||||
|
"""Creates a new session, optionally initializing it with an invocation graph"""
|
||||||
|
session = ApiDependencies.invoker.services.batch_manager.run_batch_process(batches, graph)
|
||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
160
invokeai/app/services/batch_manager.py
Normal file
160
invokeai/app/services/batch_manager.py
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
|
||||||
|
import networkx as nx
|
||||||
|
import uuid
|
||||||
|
import copy
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from fastapi_events.handlers.local import local_handler
|
||||||
|
from fastapi_events.typing import Event
|
||||||
|
from typing import (
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
|
BaseInvocation,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.events import EventServiceBase
|
||||||
|
from invokeai.app.services.graph import Graph, GraphExecutionState
|
||||||
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
|
||||||
|
|
||||||
|
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
|
||||||
|
class Batch(BaseModel):
|
||||||
|
data: list[InvocationsUnion] = Field(description="Mapping of ")
|
||||||
|
node_id: str = Field(description="ID of the node to batch")
|
||||||
|
|
||||||
|
|
||||||
|
class BatchProcess(BaseModel):
|
||||||
|
batch_id: Optional[str] = Field(default_factory=uuid.uuid4().__str__, description="Identifier for this batch")
|
||||||
|
sessions: list[str] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
|
||||||
|
batches: list[Batch] = Field(
|
||||||
|
description="List of batch configs to apply to this session",
|
||||||
|
default_factory=list,
|
||||||
|
)
|
||||||
|
batch_indices: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
|
||||||
|
graph: Graph = Field(description="The graph being executed")
|
||||||
|
|
||||||
|
|
||||||
|
class BatchManagerBase(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def start(
|
||||||
|
self,
|
||||||
|
invoker: Invoker
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def run_batch_process(
|
||||||
|
self,
|
||||||
|
batches: list[Batch],
|
||||||
|
graph: Graph
|
||||||
|
) -> BatchProcess:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def cancel_batch_process(
|
||||||
|
self,
|
||||||
|
batch_process_id: str
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BatchManager(BatchManagerBase):
|
||||||
|
"""Responsible for managing currently running and scheduled batch jobs"""
|
||||||
|
__invoker: Invoker
|
||||||
|
__batches: list[BatchProcess]
|
||||||
|
|
||||||
|
|
||||||
|
def start(self, invoker) -> None:
|
||||||
|
# if we do want multithreading at some point, we could make this configurable
|
||||||
|
self.__invoker = invoker
|
||||||
|
self.__batches = list()
|
||||||
|
local_handler.register(
|
||||||
|
event_name=EventServiceBase.session_event, _func=self.on_event
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_event(self, event: Event):
|
||||||
|
event_name = event[1]["event"]
|
||||||
|
|
||||||
|
match event_name:
|
||||||
|
case "graph_execution_state_complete":
|
||||||
|
await self.process(event)
|
||||||
|
case "invocation_error":
|
||||||
|
await self.process(event)
|
||||||
|
|
||||||
|
return event
|
||||||
|
|
||||||
|
async def process(self, event: Event):
|
||||||
|
data = event[1]["data"]
|
||||||
|
batchTarget = None
|
||||||
|
for batch in self.__batches:
|
||||||
|
if data['graph_execution_state_id'] in batch.sessions:
|
||||||
|
batchTarget = batch
|
||||||
|
break
|
||||||
|
|
||||||
|
if batchTarget == None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if sum(batchTarget.batch_indices) == 0:
|
||||||
|
self.__batches = [batch for batch in self.__batches if batch != batchTarget]
|
||||||
|
return
|
||||||
|
|
||||||
|
batchTarget.batch_indices = self._next_batch_index(batchTarget)
|
||||||
|
ges = self._next_batch_session(batchTarget)
|
||||||
|
batchTarget.sessions.append(ges.id)
|
||||||
|
self.__invoker.services.graph_execution_manager.set(ges)
|
||||||
|
self.__invoker.invoke(ges, invoke_all=True)
|
||||||
|
|
||||||
|
def _next_batch_session(self, batch_process: BatchProcess) -> GraphExecutionState:
|
||||||
|
graph = copy.deepcopy(batch_process.graph)
|
||||||
|
batches = batch_process.batches
|
||||||
|
g = graph.nx_graph_flat()
|
||||||
|
sorted_nodes = nx.topological_sort(g)
|
||||||
|
for npath in sorted_nodes:
|
||||||
|
node = graph.get_node(npath)
|
||||||
|
(index, batch) = next(((i,b) for i,b in enumerate(batches) if b.node_id in node.id), (None, None))
|
||||||
|
if batch:
|
||||||
|
batch_index = batch_process.batch_indices[index]
|
||||||
|
datum = batch.data[batch_index]
|
||||||
|
datum.id = node.id
|
||||||
|
graph.update_node(npath, datum)
|
||||||
|
|
||||||
|
return GraphExecutionState(graph=graph)
|
||||||
|
|
||||||
|
|
||||||
|
def _next_batch_index(self, batch_process: BatchProcess):
|
||||||
|
batch_indicies = batch_process.batch_indices.copy()
|
||||||
|
for index in range(len(batch_indicies)):
|
||||||
|
if batch_indicies[index] > 0:
|
||||||
|
batch_indicies[index] -= 1
|
||||||
|
break
|
||||||
|
return batch_indicies
|
||||||
|
|
||||||
|
|
||||||
|
def run_batch_process(
|
||||||
|
self,
|
||||||
|
batches: list[Batch],
|
||||||
|
graph: Graph
|
||||||
|
) -> BatchProcess:
|
||||||
|
batch_indices = list()
|
||||||
|
for batch in batches:
|
||||||
|
batch_indices.append(len(batch.data)-1)
|
||||||
|
batch_process = BatchProcess(
|
||||||
|
batches = batches,
|
||||||
|
batch_indices = batch_indices,
|
||||||
|
graph = graph,
|
||||||
|
)
|
||||||
|
ges = self._next_batch_session(batch_process)
|
||||||
|
batch_process.sessions.append(ges.id)
|
||||||
|
self.__batches.append(batch_process)
|
||||||
|
self.__invoker.services.graph_execution_manager.set(ges)
|
||||||
|
self.__invoker.invoke(ges, invoke_all=True)
|
||||||
|
return batch_process
|
||||||
|
|
||||||
|
def cancel_batch_process(
|
||||||
|
self,
|
||||||
|
batch_process_id: str
|
||||||
|
):
|
||||||
|
self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]
|
@ -245,13 +245,6 @@ InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
|
|||||||
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
|
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Batch(BaseModel):
|
|
||||||
batch_id: Optional[str] = Field(default_factory=uuid.uuid4().__str__, description="Identifier for this batch")
|
|
||||||
data: list[InvocationsUnion] = Field(description="Mapping of ")
|
|
||||||
node_id: str = Field(description="ID of the node to batch")
|
|
||||||
|
|
||||||
|
|
||||||
class Graph(BaseModel):
|
class Graph(BaseModel):
|
||||||
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
|
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
|
||||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||||
@ -262,16 +255,13 @@ class Graph(BaseModel):
|
|||||||
description="The connections between nodes and their fields in this graph",
|
description="The connections between nodes and their fields in this graph",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
batches: list[Batch] = Field(
|
|
||||||
description="List of batch configs to apply to this session",
|
|
||||||
default_factory=list,
|
|
||||||
)
|
|
||||||
|
|
||||||
def add_node(self, node: BaseInvocation) -> None:
|
def add_node(self, node: BaseInvocation) -> None:
|
||||||
"""Adds a node to a graph
|
"""Adds a node to a graph
|
||||||
|
|
||||||
:raises NodeAlreadyInGraphError: the node is already present in the graph.
|
:raises NodeAlreadyInGraphError: the node is already present in the graph.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if node.id in self.nodes:
|
if node.id in self.nodes:
|
||||||
raise NodeAlreadyInGraphError()
|
raise NodeAlreadyInGraphError()
|
||||||
|
|
||||||
@ -744,8 +734,6 @@ class GraphExecutionState(BaseModel):
|
|||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
|
|
||||||
batch_indices: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
|
|
||||||
|
|
||||||
# The results of executed nodes
|
# The results of executed nodes
|
||||||
results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(
|
results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(
|
||||||
description="The results of node executions", default_factory=dict
|
description="The results of node executions", default_factory=dict
|
||||||
@ -787,7 +775,6 @@ class GraphExecutionState(BaseModel):
|
|||||||
# TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
|
# TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
|
||||||
# possibly with a timeout?
|
# possibly with a timeout?
|
||||||
|
|
||||||
self._apply_batch_config()
|
|
||||||
# If there are no prepared nodes, prepare some nodes
|
# If there are no prepared nodes, prepare some nodes
|
||||||
next_node = self._get_next_node()
|
next_node = self._get_next_node()
|
||||||
if next_node is None:
|
if next_node is None:
|
||||||
@ -879,7 +866,7 @@ class GraphExecutionState(BaseModel):
|
|||||||
new_node = copy.deepcopy(node)
|
new_node = copy.deepcopy(node)
|
||||||
|
|
||||||
# Create the node id (use a random uuid)
|
# Create the node id (use a random uuid)
|
||||||
new_node.id = str(f"{uuid.uuid4()}-{node.id}")
|
new_node.id = str(uuid.uuid4())
|
||||||
|
|
||||||
# Set the iteration index for iteration invocations
|
# Set the iteration index for iteration invocations
|
||||||
if isinstance(new_node, IterateInvocation):
|
if isinstance(new_node, IterateInvocation):
|
||||||
@ -918,20 +905,6 @@ class GraphExecutionState(BaseModel):
|
|||||||
iterators = [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.get_node(n), IterateInvocation)]
|
iterators = [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.get_node(n), IterateInvocation)]
|
||||||
return iterators
|
return iterators
|
||||||
|
|
||||||
def _apply_batch_config(self):
|
|
||||||
g = self.graph.nx_graph_flat()
|
|
||||||
sorted_nodes = nx.topological_sort(g)
|
|
||||||
batchable_nodes = [n for n in sorted_nodes if n not in self.executed]
|
|
||||||
for npath in batchable_nodes:
|
|
||||||
node = self.graph.get_node(npath)
|
|
||||||
(index, batch) = next(((i,b) for i,b in enumerate(self.graph.batches) if b.node_id in node.id), (None, None))
|
|
||||||
if batch:
|
|
||||||
batch_index = self.batch_indices[index]
|
|
||||||
datum = batch.data[batch_index]
|
|
||||||
datum.id = node.id
|
|
||||||
self.graph.update_node(npath, datum)
|
|
||||||
|
|
||||||
|
|
||||||
def _prepare(self) -> Optional[str]:
|
def _prepare(self) -> Optional[str]:
|
||||||
# Get flattened source graph
|
# Get flattened source graph
|
||||||
g = self.graph.nx_graph_flat()
|
g = self.graph.nx_graph_flat()
|
||||||
@ -963,6 +936,7 @@ class GraphExecutionState(BaseModel):
|
|||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if next_node_id == None:
|
if next_node_id == None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
|
from invokeai.app.services.batch_manager import BatchManagerBase
|
||||||
from invokeai.app.services.board_images import BoardImagesServiceABC
|
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||||
from invokeai.app.services.boards import BoardServiceABC
|
from invokeai.app.services.boards import BoardServiceABC
|
||||||
from invokeai.app.services.images import ImageServiceABC
|
from invokeai.app.services.images import ImageServiceABC
|
||||||
@ -21,6 +22,7 @@ class InvocationServices:
|
|||||||
"""Services that can be used by invocations"""
|
"""Services that can be used by invocations"""
|
||||||
|
|
||||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||||
|
batch_manager: "BatchManagerBase"
|
||||||
board_images: "BoardImagesServiceABC"
|
board_images: "BoardImagesServiceABC"
|
||||||
boards: "BoardServiceABC"
|
boards: "BoardServiceABC"
|
||||||
configuration: "InvokeAIAppConfig"
|
configuration: "InvokeAIAppConfig"
|
||||||
@ -36,6 +38,7 @@ class InvocationServices:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
batch_manager: "BatchManagerBase",
|
||||||
board_images: "BoardImagesServiceABC",
|
board_images: "BoardImagesServiceABC",
|
||||||
boards: "BoardServiceABC",
|
boards: "BoardServiceABC",
|
||||||
configuration: "InvokeAIAppConfig",
|
configuration: "InvokeAIAppConfig",
|
||||||
@ -49,6 +52,7 @@ class InvocationServices:
|
|||||||
processor: "InvocationProcessorABC",
|
processor: "InvocationProcessorABC",
|
||||||
queue: "InvocationQueueABC",
|
queue: "InvocationQueueABC",
|
||||||
):
|
):
|
||||||
|
self.batch_manager = batch_manager
|
||||||
self.board_images = board_images
|
self.board_images = board_images
|
||||||
self.boards = boards
|
self.boards = boards
|
||||||
self.boards = boards
|
self.boards = boards
|
||||||
|
@ -20,6 +20,7 @@ class Invoker:
|
|||||||
def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> Optional[str]:
|
def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> Optional[str]:
|
||||||
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
||||||
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
|
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
|
||||||
|
|
||||||
# Get the next invocation
|
# Get the next invocation
|
||||||
invocation = graph_execution_state.next()
|
invocation = graph_execution_state.next()
|
||||||
if not invocation:
|
if not invocation:
|
||||||
@ -40,9 +41,9 @@ class Invoker:
|
|||||||
|
|
||||||
return invocation.id
|
return invocation.id
|
||||||
|
|
||||||
def create_execution_state(self, graph: Optional[Graph] = None, batch_indices: list[int] = list()) -> GraphExecutionState:
|
def create_execution_state(self, graph: Optional[Graph] = None) -> GraphExecutionState:
|
||||||
"""Creates a new execution state for the given graph"""
|
"""Creates a new execution state for the given graph"""
|
||||||
new_state = GraphExecutionState(graph=Graph() if graph is None else graph, batch_indices=batch_indices)
|
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
|
||||||
self.services.graph_execution_manager.set(new_state)
|
self.services.graph_execution_manager.set(new_state)
|
||||||
return new_state
|
return new_state
|
||||||
|
|
||||||
|
@ -6,7 +6,6 @@ from ..invocations.baseinvocation import InvocationContext
|
|||||||
from .invocation_queue import InvocationQueueItem
|
from .invocation_queue import InvocationQueueItem
|
||||||
from .invoker import InvocationProcessorABC, Invoker
|
from .invoker import InvocationProcessorABC, Invoker
|
||||||
from ..models.exceptions import CanceledException
|
from ..models.exceptions import CanceledException
|
||||||
from .graph import GraphExecutionState
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
@ -71,6 +70,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
error=traceback.format_exc(),
|
error=traceback.format_exc(),
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
|
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
|
||||||
|
|
||||||
@ -154,15 +154,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
error_type=e.__class__.__name__,
|
error_type=e.__class__.__name__,
|
||||||
error=traceback.format_exc(),
|
error=traceback.format_exc(),
|
||||||
)
|
)
|
||||||
elif queue_item.invoke_all and sum(graph_execution_state.batch_indices) > 0:
|
|
||||||
batch_indicies = graph_execution_state.batch_indices.copy()
|
|
||||||
for index in range(len(batch_indicies)):
|
|
||||||
if batch_indicies[index] > 0:
|
|
||||||
batch_indicies[index] -= 1
|
|
||||||
break
|
|
||||||
new_ges = GraphExecutionState(graph=graph_execution_state.graph, batch_indices=batch_indicies)
|
|
||||||
self.__invoker.services.graph_execution_manager.set(new_ges)
|
|
||||||
self.__invoker.invoke(new_ges, invoke_all=True)
|
|
||||||
elif is_complete:
|
elif is_complete:
|
||||||
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
|
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user