mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

commit 4bad96d9d6 (parent 4f9c728db0)

WIP running graphs as batches
@@ -241,6 +241,13 @@ InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
 InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
 
 
+
+class Batch(BaseModel):
+    batch_id: Optional[str] = Field(default_factory=uuid.uuid4().__str__, description="Identifier for this batch")
+    data: list[InvocationsUnion] = Field(description="Mapping of ")
+    node_id: str = Field(description="ID of the node to batch")
+
+
 class Graph(BaseModel):
     id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
     # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
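Review note: `default_factory=uuid.uuid4().__str__` evaluates `uuid.uuid4()` once, at class-definition time, and stores the bound `__str__` of that single UUID, so every Batch would share the same id. Below is a standalone sketch of the model with a per-instance factory. It types `data` as `list[dict]` because the Invoker hunk further down consumes each entry via `datum.keys()`; the node id and prompts are hypothetical examples, not taken from the diff.

    import uuid
    from typing import Optional

    from pydantic import BaseModel, Field

    class Batch(BaseModel):
        # lambda defers uuid.uuid4() so each instance gets a fresh id
        batch_id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4()), description="Identifier for this batch")
        data: list[dict] = Field(default_factory=list, description="Parameter overrides, one dict per run")
        node_id: str = Field(description="ID of the node to batch")

    batch = Batch(node_id="txt2img_1", data=[{"prompt": "a cat"}, {"prompt": "a dog"}])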
@@ -251,13 +258,16 @@ class Graph(BaseModel):
         description="The connections between nodes and their fields in this graph",
         default_factory=list,
     )
+    batches: list[Batch] = Field(
+        description="List of batch configs to apply to this session",
+        default_factory=list,
+    )
 
     def add_node(self, node: BaseInvocation) -> None:
         """Adds a node to a graph
 
         :raises NodeAlreadyInGraphError: the node is already present in the graph.
         """
 
         if node.id in self.nodes:
             raise NodeAlreadyInGraphError()
 
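Continuing the sketch above, attaching a batch config to a graph would look like the following; the `Graph` here is pared down to just the new field, and the ids and parameter names are illustrative.

    class Graph(BaseModel):
        batches: list[Batch] = Field(description="List of batch configs to apply to this session", default_factory=list)

    g = Graph()
    g.batches.append(Batch(node_id="txt2img_1", data=[{"cfg_scale": 7.0}, {"cfg_scale": 9.0}]))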
@@ -793,6 +803,8 @@ class GraphExecutionState(BaseModel):
     # TODO: Store a reference to the graph instead of the actual graph?
     graph: Graph = Field(description="The graph being executed")
 
+    batch_index: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
+
     # The graph of materialized nodes
     execution_graph: Graph = Field(
         description="The expanded graph of activated and executed nodes",
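The counters in `batch_index` run downward: `create_execution_state` in the last hunk seeds each one with `len(batch.data) - 1`, and `next()` decrements them toward zero. Note that the counters drain one at a time, so batches are swept sequentially rather than as a cross product. A small self-contained sketch of that arithmetic, with hypothetical batch sizes:

    data_lengths = [3, 2]                          # two hypothetical batches
    batch_index = [n - 1 for n in data_lengths]    # seeded as below: [2, 1]
    passes = 1                                     # the initial run
    while sum(batch_index) != 0:                   # the check used in next()
        for i in range(len(batch_index)):
            if batch_index[i] > 0:
                batch_index[i] -= 1                # consume one datum
                passes += 1
                break
    print(passes)                                  # 4 = 1 + (3-1) + (2-1), not 3 * 2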
@@ -865,6 +877,13 @@ class GraphExecutionState(BaseModel):
         if next_node is not None:
             self._prepare_inputs(next_node)
 
+        if sum(self.batch_index) != 0:
+            for index in self.batch_index:
+                if self.batch_index[index] > 0:
+                    self.executed.clear()
+                    self.batch_index[index] -= 1
+                    return next(self)
+
         # If next is still none, there's no next node, return None
         return next_node
 
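Two apparent issues in the added loop: `for index in self.batch_index` iterates the counter values and then reuses them as list indices (with `batch_index == [2]`, `self.batch_index[2]` raises IndexError), and `return next(self)` invokes the builtin, which requires `__next__`, whereas the method here is named `next()`. A hedged sketch of the apparent intent:

    if sum(self.batch_index) != 0:
        for index in range(len(self.batch_index)):   # positions, not counter values
            if self.batch_index[index] > 0:
                self.executed.clear()                # forget finished nodes so they re-run
                self.batch_index[index] -= 1         # consume one datum of this batch
                return self.next()                   # re-select under the new counters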
@@ -954,7 +973,7 @@ class GraphExecutionState(BaseModel):
         new_node = copy.deepcopy(node)
 
         # Create the node id (use a random uuid)
-        new_node.id = str(uuid.uuid4())
+        new_node.id = str(f"{uuid.uuid4()}-{node.id}")
 
         # Set the iteration index for iteration invocations
         if isinstance(new_node, IterateInvocation):
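Appending the source node's id to the execution-node id is what lets the Invoker hunk below find a batch with the substring test `b.node_id in invocation.id`. A minimal illustration (ids are hypothetical); note that substring matching can false-positive when one node id occurs inside another, e.g. `node_1` inside `my_node_1`:

    import uuid

    source_id = "txt2img_1"                       # hypothetical source-node id
    exec_id = str(f"{uuid.uuid4()}-{source_id}")  # the new id format (str() is redundant on an f-string)
    assert source_id in exec_id                   # the lookup the Invoker relies on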
@@ -21,11 +21,17 @@ class Invoker:
     ) -> Optional[str]:
         """Determines the next node to invoke and enqueues it, preparing if needed.
         Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
 
         # Get the next invocation
         invocation = graph_execution_state.next()
         if not invocation:
             return None
+        (index, batch) = next(((i,b) for i,b in enumerate(graph_execution_state.graph.batches) if b.node_id in invocation.id), (None, None))
+        if batch:
+            # assert(isinstance(invocation.type, batch.node_type), f"Type mismatch between nodes and batch config on {invocation.id}")
+            batch_index = graph_execution_state.batch_index[index]
+            datum = batch.data[batch_index]
+            for param in datum.keys():
+                invocation[param] = datum[param]
 
         # Save the execution state
         self.services.graph_execution_manager.set(graph_execution_state)
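`invocation[param] = datum[param]` would raise TypeError: pydantic models support attribute assignment, not item assignment. Also, `if batch:` happens to work here, but `is not None` states the intent of the `(None, None)` default more directly. A hedged sketch of the presumable intent:

    index, batch = next(
        ((i, b) for i, b in enumerate(graph_execution_state.graph.batches)
         if b.node_id in invocation.id),
        (None, None),
    )
    if batch is not None:
        datum = batch.data[graph_execution_state.batch_index[index]]
        for param, value in datum.items():
            setattr(invocation, param, value)   # set fields by attribute, not subscript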
@@ -45,6 +51,11 @@ class Invoker:
     def create_execution_state(self, graph: Optional[Graph] = None) -> GraphExecutionState:
         """Creates a new execution state for the given graph"""
         new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
+        if graph.batches:
+            batch_index = list()
+            for batch in graph.batches:
+                batch_index.append(len(batch.data)-1)
+            new_state.batch_index = batch_index
         self.services.graph_execution_manager.set(new_state)
         return new_state
 
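The first line already tolerates `graph=None`, but the next line dereferences `graph.batches` unguarded, so calling `create_execution_state()` with no graph raises AttributeError. A guarded sketch that also collapses the seeding loop:

    new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
    if graph is not None and graph.batches:
        # seed each counter at the last index; next() counts it down to zero
        new_state.batch_index = [len(batch.data) - 1 for batch in graph.batches]
    self.services.graph_execution_manager.set(new_state)
    return new_state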