mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[nodes] Add better error handling to processor and CLI (#2828)
* [nodes] Add better error handling to processor and CLI * [nodes] Use more explicit name for marking node execution error * [nodes] Update the processor call to error
This commit is contained in:
parent
e0405031a7
commit
b7d5a3e0b5
@ -27,7 +27,7 @@ from .services.events import EventServiceBase
|
|||||||
|
|
||||||
|
|
||||||
class InvocationCommand(BaseModel):
|
class InvocationCommand(BaseModel):
|
||||||
invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type")
|
invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class InvalidArgs(Exception):
|
class InvalidArgs(Exception):
|
||||||
@ -84,7 +84,7 @@ def get_invocation_parser() -> argparse.ArgumentParser:
|
|||||||
for val in allowed_values:
|
for val in allowed_values:
|
||||||
allowed_types.add(type(val))
|
allowed_types.add(type(val))
|
||||||
allowed_types_list = list(allowed_types)
|
allowed_types_list = list(allowed_types)
|
||||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list]
|
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||||
|
|
||||||
command_parser.add_argument(
|
command_parser.add_argument(
|
||||||
f"--{name}",
|
f"--{name}",
|
||||||
@ -184,7 +184,7 @@ def invoke_cli():
|
|||||||
)
|
)
|
||||||
|
|
||||||
invoker = Invoker(services)
|
invoker = Invoker(services)
|
||||||
session = invoker.create_execution_state()
|
session: GraphExecutionState = invoker.create_execution_state()
|
||||||
|
|
||||||
parser = get_invocation_parser()
|
parser = get_invocation_parser()
|
||||||
|
|
||||||
@ -219,6 +219,9 @@ def invoke_cli():
|
|||||||
current_id = start_id
|
current_id = start_id
|
||||||
new_invocations = list()
|
new_invocations = list()
|
||||||
for cmd in cmds:
|
for cmd in cmds:
|
||||||
|
if cmd is None or cmd.strip() == '':
|
||||||
|
raise InvalidArgs('Empty command')
|
||||||
|
|
||||||
# Parse args to create invocation
|
# Parse args to create invocation
|
||||||
args = vars(parser.parse_args(shlex.split(cmd.strip())))
|
args = vars(parser.parse_args(shlex.split(cmd.strip())))
|
||||||
|
|
||||||
@ -289,6 +292,15 @@ def invoke_cli():
|
|||||||
session = invoker.services.graph_execution_manager.get(session.id)
|
session = invoker.services.graph_execution_manager.get(session.id)
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Print any errors
|
||||||
|
if session.has_error():
|
||||||
|
for n in session.errors:
|
||||||
|
print(f'Error in node {n} (source node {session.prepared_source_mapping[n]}): {session.errors[n]}')
|
||||||
|
|
||||||
|
# Start a new session
|
||||||
|
print("Creating a new session")
|
||||||
|
session = invoker.create_execution_state()
|
||||||
|
|
||||||
except InvalidArgs:
|
except InvalidArgs:
|
||||||
print('Invalid command, use "help" to list commands')
|
print('Invalid command, use "help" to list commands')
|
||||||
continue
|
continue
|
||||||
|
@ -55,6 +55,21 @@ class EventServiceBase:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def emit_invocation_error(self,
|
||||||
|
graph_execution_state_id: str,
|
||||||
|
invocation_id: str,
|
||||||
|
error: str
|
||||||
|
) -> None:
|
||||||
|
"""Emitted when an invocation has completed"""
|
||||||
|
self.__emit_session_event(
|
||||||
|
event_name = 'invocation_error',
|
||||||
|
payload = dict(
|
||||||
|
graph_execution_state_id = graph_execution_state_id,
|
||||||
|
invocation_id = invocation_id,
|
||||||
|
error = error
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def emit_invocation_started(self,
|
def emit_invocation_started(self,
|
||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
invocation_id: str
|
invocation_id: str
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
|
import traceback
|
||||||
from types import NoneType
|
from types import NoneType
|
||||||
import uuid
|
import uuid
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
@ -153,8 +154,8 @@ class CollectInvocation(BaseInvocation):
|
|||||||
return CollectInvocationOutput(collection = copy.copy(self.collection))
|
return CollectInvocationOutput(collection = copy.copy(self.collection))
|
||||||
|
|
||||||
|
|
||||||
InvocationsUnion = Union[BaseInvocation.get_invocations()]
|
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
|
||||||
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]
|
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class Graph(BaseModel):
|
class Graph(BaseModel):
|
||||||
@ -486,11 +487,11 @@ class Graph(BaseModel):
|
|||||||
type_tree.add_nodes_from(input_field_types)
|
type_tree.add_nodes_from(input_field_types)
|
||||||
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
|
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
|
||||||
type_degrees = type_tree.in_degree(type_tree.nodes)
|
type_degrees = type_tree.in_degree(type_tree.nodes)
|
||||||
if sum((t[1] == 0 for t in type_degrees)) != 1:
|
if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
|
||||||
return False # There is more than one root type
|
return False # There is more than one root type
|
||||||
|
|
||||||
# Get the input root type
|
# Get the input root type
|
||||||
input_root_type = next(t[0] for t in type_degrees if t[1] == 0)
|
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
|
||||||
|
|
||||||
# Verify that all outputs are lists
|
# Verify that all outputs are lists
|
||||||
if not all((get_origin(f) == list for f in output_fields)):
|
if not all((get_origin(f) == list for f in output_fields)):
|
||||||
@ -545,6 +546,9 @@ class GraphExecutionState(BaseModel):
|
|||||||
# The results of executed nodes
|
# The results of executed nodes
|
||||||
results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(description="The results of node executions", default_factory=dict)
|
results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(description="The results of node executions", default_factory=dict)
|
||||||
|
|
||||||
|
# Errors raised when executing nodes
|
||||||
|
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
|
||||||
|
|
||||||
# Map of prepared/executed nodes to their original nodes
|
# Map of prepared/executed nodes to their original nodes
|
||||||
prepared_source_mapping: dict[str, str] = Field(description="The map of prepared nodes to original graph nodes", default_factory=dict)
|
prepared_source_mapping: dict[str, str] = Field(description="The map of prepared nodes to original graph nodes", default_factory=dict)
|
||||||
|
|
||||||
@ -594,9 +598,17 @@ class GraphExecutionState(BaseModel):
|
|||||||
self.executed.add(source_node)
|
self.executed.add(source_node)
|
||||||
self.executed_history.append(source_node)
|
self.executed_history.append(source_node)
|
||||||
|
|
||||||
|
def set_node_error(self, node_id: str, error: str):
|
||||||
|
"""Marks a node as errored"""
|
||||||
|
self.errors[node_id] = error
|
||||||
|
|
||||||
def is_complete(self) -> bool:
|
def is_complete(self) -> bool:
|
||||||
"""Returns true if the graph is complete"""
|
"""Returns true if the graph is complete"""
|
||||||
return all((k in self.executed for k in self.graph.nodes))
|
return self.has_error() or all((k in self.executed for k in self.graph.nodes))
|
||||||
|
|
||||||
|
def has_error(self) -> bool:
|
||||||
|
"""Returns true if the graph has any errors"""
|
||||||
|
return len(self.errors) > 0
|
||||||
|
|
||||||
def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
|
def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
|
||||||
"""Prepares an iteration node and connects all edges, returning the new node id"""
|
"""Prepares an iteration node and connects all edges, returning the new node id"""
|
||||||
@ -709,11 +721,11 @@ class GraphExecutionState(BaseModel):
|
|||||||
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
|
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
|
||||||
# TODO: Handle a node mapping to none
|
# TODO: Handle a node mapping to none
|
||||||
eg = self.execution_graph.nx_graph_flat()
|
eg = self.execution_graph.nx_graph_flat()
|
||||||
prepared_parent_mappings = [[(n,self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations]
|
prepared_parent_mappings = [[(n,self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations] # type: ignore
|
||||||
|
|
||||||
# Create execution node for each iteration
|
# Create execution node for each iteration
|
||||||
for iteration_mappings in prepared_parent_mappings:
|
for iteration_mappings in prepared_parent_mappings:
|
||||||
create_results = self._create_execution_node(next_node_id, iteration_mappings)
|
create_results = self._create_execution_node(next_node_id, iteration_mappings) # type: ignore
|
||||||
if create_results is not None:
|
if create_results is not None:
|
||||||
new_node_ids.extend(create_results)
|
new_node_ids.extend(create_results)
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from threading import Event, Thread
|
from threading import Event, Thread
|
||||||
|
import traceback
|
||||||
from ..invocations.baseinvocation import InvocationContext
|
from ..invocations.baseinvocation import InvocationContext
|
||||||
from .invocation_queue import InvocationQueueItem
|
from .invocation_queue import InvocationQueueItem
|
||||||
from .invoker import InvocationProcessorABC, Invoker
|
from .invoker import InvocationProcessorABC, Invoker
|
||||||
@ -61,18 +62,34 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
result = outputs.dict()
|
result = outputs.dict()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Queue any further commands if invoking all
|
|
||||||
is_complete = graph_execution_state.is_complete()
|
|
||||||
if queue_item.invoke_all and not is_complete:
|
|
||||||
self.__invoker.invoke(graph_execution_state, invoke_all = True)
|
|
||||||
elif is_complete:
|
|
||||||
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# TODO: Log the error, mark the invocation as failed, and emit an event
|
error = traceback.format_exc()
|
||||||
print(f'Error invoking {invocation.id}: {e}')
|
|
||||||
|
# Save error
|
||||||
|
graph_execution_state.set_node_error(invocation.id, error)
|
||||||
|
|
||||||
|
# Save the state changes
|
||||||
|
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
|
||||||
|
|
||||||
|
# Send error event
|
||||||
|
self.__invoker.services.events.emit_invocation_error(
|
||||||
|
graph_execution_state_id = graph_execution_state.id,
|
||||||
|
invocation_id = invocation.id,
|
||||||
|
error = error
|
||||||
|
)
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Queue any further commands if invoking all
|
||||||
|
is_complete = graph_execution_state.is_complete()
|
||||||
|
if queue_item.invoke_all and not is_complete:
|
||||||
|
self.__invoker.invoke(graph_execution_state, invoke_all = True)
|
||||||
|
elif is_complete:
|
||||||
|
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
... # Log something?
|
... # Log something?
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation, TestEventService, create_edge, wait_until
|
from .test_nodes import ErrorInvocation, ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation, TestEventService, create_edge, wait_until
|
||||||
from ldm.invoke.app.services.processor import DefaultInvocationProcessor
|
from ldm.invoke.app.services.processor import DefaultInvocationProcessor
|
||||||
from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||||
from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
|
from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
|
||||||
@ -21,9 +21,9 @@ def simple_graph():
|
|||||||
def mock_services() -> InvocationServices:
|
def mock_services() -> InvocationServices:
|
||||||
# NOTE: none of these are actually called by the test invocations
|
# NOTE: none of these are actually called by the test invocations
|
||||||
return InvocationServices(
|
return InvocationServices(
|
||||||
generate = None,
|
generate = None, # type: ignore
|
||||||
events = TestEventService(),
|
events = TestEventService(),
|
||||||
images = None,
|
images = None, # type: ignore
|
||||||
queue = MemoryInvocationQueue(),
|
queue = MemoryInvocationQueue(),
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||||
processor = DefaultInvocationProcessor()
|
processor = DefaultInvocationProcessor()
|
||||||
@ -79,3 +79,22 @@ def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
|
|||||||
|
|
||||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||||
assert g.is_complete()
|
assert g.is_complete()
|
||||||
|
|
||||||
|
def test_handles_errors(mock_invoker: Invoker):
|
||||||
|
g = mock_invoker.create_execution_state()
|
||||||
|
g.graph.add_node(ErrorInvocation(id = "1"))
|
||||||
|
|
||||||
|
mock_invoker.invoke(g, invoke_all=True)
|
||||||
|
|
||||||
|
def has_executed_all(g: GraphExecutionState):
|
||||||
|
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||||
|
return g.is_complete()
|
||||||
|
|
||||||
|
wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1)
|
||||||
|
mock_invoker.stop()
|
||||||
|
|
||||||
|
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||||
|
assert g.has_error()
|
||||||
|
assert g.is_complete()
|
||||||
|
|
||||||
|
assert all((i in g.errors for i in g.source_prepared_mapping['1']))
|
||||||
|
@ -32,6 +32,12 @@ class PromptTestInvocation(BaseInvocation):
|
|||||||
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
|
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
|
||||||
return PromptTestInvocationOutput(prompt = self.prompt)
|
return PromptTestInvocationOutput(prompt = self.prompt)
|
||||||
|
|
||||||
|
class ErrorInvocation(BaseInvocation):
|
||||||
|
type: Literal['test_error'] = 'test_error'
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
|
||||||
|
raise Exception("This invocation is supposed to fail")
|
||||||
|
|
||||||
class ImageTestInvocationOutput(BaseInvocationOutput):
|
class ImageTestInvocationOutput(BaseInvocationOutput):
|
||||||
type: Literal['test_image_output'] = 'test_image_output'
|
type: Literal['test_image_output'] = 'test_image_output'
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user