mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
WIP running graphs as batches
This commit is contained in:
parent
4f9c728db0
commit
4bad96d9d6
@ -241,6 +241,13 @@ InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
|
||||
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
|
||||
|
||||
|
||||
|
||||
class Batch(BaseModel):
|
||||
batch_id: Optional[str] = Field(default_factory=uuid.uuid4().__str__, description="Identifier for this batch")
|
||||
data: list[InvocationsUnion] = Field(description="Mapping of ")
|
||||
node_id: str = Field(description="ID of the node to batch")
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
|
||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||
@ -251,13 +258,16 @@ class Graph(BaseModel):
|
||||
description="The connections between nodes and their fields in this graph",
|
||||
default_factory=list,
|
||||
)
|
||||
batches: list[Batch] = Field(
|
||||
description="List of batch configs to apply to this session",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
def add_node(self, node: BaseInvocation) -> None:
|
||||
"""Adds a node to a graph
|
||||
|
||||
:raises NodeAlreadyInGraphError: the node is already present in the graph.
|
||||
"""
|
||||
|
||||
if node.id in self.nodes:
|
||||
raise NodeAlreadyInGraphError()
|
||||
|
||||
@ -793,6 +803,8 @@ class GraphExecutionState(BaseModel):
|
||||
# TODO: Store a reference to the graph instead of the actual graph?
|
||||
graph: Graph = Field(description="The graph being executed")
|
||||
|
||||
batch_index: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
|
||||
|
||||
# The graph of materialized nodes
|
||||
execution_graph: Graph = Field(
|
||||
description="The expanded graph of activated and executed nodes",
|
||||
@ -865,6 +877,13 @@ class GraphExecutionState(BaseModel):
|
||||
if next_node is not None:
|
||||
self._prepare_inputs(next_node)
|
||||
|
||||
if sum(self.batch_index) != 0:
|
||||
for index in self.batch_index:
|
||||
if self.batch_index[index] > 0:
|
||||
self.executed.clear()
|
||||
self.batch_index[index] -= 1
|
||||
return next(self)
|
||||
|
||||
# If next is still none, there's no next node, return None
|
||||
return next_node
|
||||
|
||||
@ -954,7 +973,7 @@ class GraphExecutionState(BaseModel):
|
||||
new_node = copy.deepcopy(node)
|
||||
|
||||
# Create the node id (use a random uuid)
|
||||
new_node.id = str(uuid.uuid4())
|
||||
new_node.id = str(f"{uuid.uuid4()}-{node.id}")
|
||||
|
||||
# Set the iteration index for iteration invocations
|
||||
if isinstance(new_node, IterateInvocation):
|
||||
|
@ -21,11 +21,17 @@ class Invoker:
|
||||
) -> Optional[str]:
|
||||
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
||||
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
|
||||
|
||||
# Get the next invocation
|
||||
invocation = graph_execution_state.next()
|
||||
if not invocation:
|
||||
return None
|
||||
(index, batch) = next(((i,b) for i,b in enumerate(graph_execution_state.graph.batches) if b.node_id in invocation.id), (None, None))
|
||||
if batch:
|
||||
# assert(isinstance(invocation.type, batch.node_type), f"Type mismatch between nodes and batch config on {invocation.id}")
|
||||
batch_index = graph_execution_state.batch_index[index]
|
||||
datum = batch.data[batch_index]
|
||||
for param in datum.keys():
|
||||
invocation[param] = datum[param]
|
||||
|
||||
# Save the execution state
|
||||
self.services.graph_execution_manager.set(graph_execution_state)
|
||||
@ -45,6 +51,11 @@ class Invoker:
|
||||
def create_execution_state(self, graph: Optional[Graph] = None) -> GraphExecutionState:
|
||||
"""Creates a new execution state for the given graph"""
|
||||
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
|
||||
if graph.batches:
|
||||
batch_index = list()
|
||||
for batch in graph.batches:
|
||||
batch_index.append(len(batch.data)-1)
|
||||
new_state.batch_index = batch_index
|
||||
self.services.graph_execution_manager.set(new_state)
|
||||
return new_state
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user