From b7d5a3e0b5ab6254dd5dd1a27020890e8d91806e Mon Sep 17 00:00:00 2001
From: Kyle Schouviller
Date: Mon, 27 Feb 2023 10:01:07 -0800
Subject: [PATCH] [nodes] Add better error handling to processor and CLI (#2828)

* [nodes] Add better error handling to processor and CLI

* [nodes] Use more explicit name for marking node execution error

* [nodes] Update the processor call to error
---
 ldm/invoke/app/cli_app.py            | 18 ++++++++++++---
 ldm/invoke/app/services/events.py    | 15 +++++++++++++
 ldm/invoke/app/services/graph.py     | 26 ++++++++++++++++------
 ldm/invoke/app/services/processor.py | 33 +++++++++++++++++++++-------
 tests/nodes/test_invoker.py          | 25 ++++++++++++++++++---
 tests/nodes/test_nodes.py            |  6 +++++
 6 files changed, 102 insertions(+), 21 deletions(-)

diff --git a/ldm/invoke/app/cli_app.py b/ldm/invoke/app/cli_app.py
index 9081f3b083..51ccb9d41e 100644
--- a/ldm/invoke/app/cli_app.py
+++ b/ldm/invoke/app/cli_app.py
@@ -27,7 +27,7 @@ from .services.events import EventServiceBase
 
 
 class InvocationCommand(BaseModel):
-    invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type")
+    invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
 
 
 class InvalidArgs(Exception):
@@ -84,7 +84,7 @@ def get_invocation_parser() -> argparse.ArgumentParser:
                 for val in allowed_values:
                     allowed_types.add(type(val))
                 allowed_types_list = list(allowed_types)
-                field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list]
+                field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
 
                 command_parser.add_argument(
                     f"--{name}",
@@ -184,7 +184,7 @@ def invoke_cli():
     )
 
     invoker = Invoker(services)
-    session = invoker.create_execution_state()
+    session: GraphExecutionState = invoker.create_execution_state()
 
     parser = get_invocation_parser()
 
@@ -219,6 +219,9 @@ def invoke_cli():
             current_id = start_id
             new_invocations = list()
             for cmd in cmds:
+                if cmd is None or cmd.strip() == '':
+                    raise InvalidArgs('Empty command')
+
                 # Parse args to create invocation
                 args = vars(parser.parse_args(shlex.split(cmd.strip())))
 
@@ -288,6 +291,15 @@ def invoke_cli():
                 # Wait some time
                 session = invoker.services.graph_execution_manager.get(session.id)
                 time.sleep(0.1)
+
+            # Print any errors
+            if session.has_error():
+                for n in session.errors:
+                    print(f'Error in node {n} (source node {session.prepared_source_mapping[n]}): {session.errors[n]}')
+
+                # Start a new session
+                print("Creating a new session")
+                session = invoker.create_execution_state()
 
         except InvalidArgs:
             print('Invalid command, use "help" to list commands')
diff --git a/ldm/invoke/app/services/events.py b/ldm/invoke/app/services/events.py
index 7b850b61ac..c7b6367d68 100644
--- a/ldm/invoke/app/services/events.py
+++ b/ldm/invoke/app/services/events.py
@@ -54,6 +54,21 @@ class EventServiceBase:
                 result = result
             )
         )
+
+    def emit_invocation_error(self,
+        graph_execution_state_id: str,
+        invocation_id: str,
+        error: str
+    ) -> None:
+        """Emitted when an invocation has encountered an error"""
+        self.__emit_session_event(
+            event_name = 'invocation_error',
+            payload = dict(
+                graph_execution_state_id = graph_execution_state_id,
+                invocation_id = invocation_id,
+                error = error
+            )
+        )
 
     def emit_invocation_started(self,
         graph_execution_state_id: str,
diff --git a/ldm/invoke/app/services/graph.py b/ldm/invoke/app/services/graph.py
index 8d1583fc8b..059ebca2d4 100644
--- a/ldm/invoke/app/services/graph.py
+++ b/ldm/invoke/app/services/graph.py
@@ -2,6 +2,7 @@
 
 import copy
 import itertools
+import traceback
 from types import NoneType
 import uuid
 import networkx as nx
@@ -153,8 +154,8 @@ class CollectInvocation(BaseInvocation):
         return CollectInvocationOutput(collection = copy.copy(self.collection))
 
 
-InvocationsUnion = Union[BaseInvocation.get_invocations()]
-InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]
+InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
+InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
 
 
 class Graph(BaseModel):
@@ -486,11 +487,11 @@ class Graph(BaseModel):
         type_tree.add_nodes_from(input_field_types)
         type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
         type_degrees = type_tree.in_degree(type_tree.nodes)
-        if sum((t[1] == 0 for t in type_degrees)) != 1:
+        if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
             return False # There is more than one root type
 
         # Get the input root type
-        input_root_type = next(t[0] for t in type_degrees if t[1] == 0)
+        input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
 
         # Verify that all outputs are lists
         if not all((get_origin(f) == list for f in output_fields)):
@@ -545,6 +546,9 @@ class GraphExecutionState(BaseModel):
     # The results of executed nodes
     results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(description="The results of node executions", default_factory=dict)
 
+    # Errors raised when executing nodes
+    errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
+
     # Map of prepared/executed nodes to their original nodes
     prepared_source_mapping: dict[str, str] = Field(description="The map of prepared nodes to original graph nodes", default_factory=dict)
 
@@ -593,10 +597,18 @@ class GraphExecutionState(BaseModel):
             if all([n in self.executed for n in prepared_nodes]):
                 self.executed.add(source_node)
                 self.executed_history.append(source_node)
+
+    def set_node_error(self, node_id: str, error: str):
+        """Marks a node as errored"""
+        self.errors[node_id] = error
 
     def is_complete(self) -> bool:
         """Returns true if the graph is complete"""
-        return all((k in self.executed for k in self.graph.nodes))
+        return self.has_error() or all((k in self.executed for k in self.graph.nodes))
+
+    def has_error(self) -> bool:
+        """Returns true if the graph has any errors"""
+        return len(self.errors) > 0
 
     def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
         """Prepares an iteration node and connects all edges, returning the new node id"""
@@ -709,11 +721,11 @@ class GraphExecutionState(BaseModel):
         # For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
         # TODO: Handle a node mapping to none
         eg = self.execution_graph.nx_graph_flat()
-        prepared_parent_mappings = [[(n,self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations]
+        prepared_parent_mappings = [[(n,self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations] # type: ignore
 
         # Create execution node for each iteration
         for iteration_mappings in prepared_parent_mappings:
-            create_results = self._create_execution_node(next_node_id, iteration_mappings)
+            create_results = self._create_execution_node(next_node_id, iteration_mappings) # type: ignore
             if create_results is not None:
                 new_node_ids.extend(create_results)
 
diff --git a/ldm/invoke/app/services/processor.py b/ldm/invoke/app/services/processor.py
index 9ea4349bbf..1825d404e5 100644
--- a/ldm/invoke/app/services/processor.py
+++ b/ldm/invoke/app/services/processor.py
@@ -1,4 +1,5 @@
 from threading import Event, Thread
+import traceback
 from ..invocations.baseinvocation import InvocationContext
 from .invocation_queue import InvocationQueueItem
 from .invoker import InvocationProcessorABC, Invoker
@@ -61,18 +62,34 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                         result = outputs.dict()
                     )
 
-                    # Queue any further commands if invoking all
-                    is_complete = graph_execution_state.is_complete()
-                    if queue_item.invoke_all and not is_complete:
-                        self.__invoker.invoke(graph_execution_state, invoke_all = True)
-                    elif is_complete:
-                        self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
                 except KeyboardInterrupt:
                     pass
+
                 except Exception as e:
-                    # TODO: Log the error, mark the invocation as failed, and emit an event
-                    print(f'Error invoking {invocation.id}: {e}')
+                    error = traceback.format_exc()
+
+                    # Save error
+                    graph_execution_state.set_node_error(invocation.id, error)
+
+                    # Save the state changes
+                    self.__invoker.services.graph_execution_manager.set(graph_execution_state)
+
+                    # Send error event
+                    self.__invoker.services.events.emit_invocation_error(
+                        graph_execution_state_id = graph_execution_state.id,
+                        invocation_id = invocation.id,
+                        error = error
+                    )
+
                     pass
 
+
+                # Queue any further commands if invoking all
+                is_complete = graph_execution_state.is_complete()
+                if queue_item.invoke_all and not is_complete:
+                    self.__invoker.invoke(graph_execution_state, invoke_all = True)
+                elif is_complete:
+                    self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
+
             except KeyboardInterrupt:
                 ... # Log something?
diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py
index e9109728d5..8ca2931841 100644
--- a/tests/nodes/test_invoker.py
+++ b/tests/nodes/test_invoker.py
@@ -1,4 +1,4 @@
-from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation, TestEventService, create_edge, wait_until
+from .test_nodes import ErrorInvocation, ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation, TestEventService, create_edge, wait_until
 from ldm.invoke.app.services.processor import DefaultInvocationProcessor
 from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory
 from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
@@ -21,9 +21,9 @@ def simple_graph():
 def mock_services() -> InvocationServices:
     # NOTE: none of these are actually called by the test invocations
     return InvocationServices(
-        generate = None,
+        generate = None, # type: ignore
         events = TestEventService(),
-        images = None,
+        images = None, # type: ignore
         queue = MemoryInvocationQueue(),
         graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
         processor = DefaultInvocationProcessor()
     )
@@ -79,3 +79,22 @@ def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
 
     g = mock_invoker.services.graph_execution_manager.get(g.id)
     assert g.is_complete()
+
+def test_handles_errors(mock_invoker: Invoker):
+    g = mock_invoker.create_execution_state()
+    g.graph.add_node(ErrorInvocation(id = "1"))
+
+    mock_invoker.invoke(g, invoke_all=True)
+
+    def has_executed_all(g: GraphExecutionState):
+        g = mock_invoker.services.graph_execution_manager.get(g.id)
+        return g.is_complete()
+
+    wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1)
+    mock_invoker.stop()
+
+    g = mock_invoker.services.graph_execution_manager.get(g.id)
+    assert g.has_error()
+    assert g.is_complete()
+
+    assert all((i in g.errors for i in g.source_prepared_mapping['1']))
diff --git a/tests/nodes/test_nodes.py b/tests/nodes/test_nodes.py
index fea2e75e95..e07dcb8594 100644
--- a/tests/nodes/test_nodes.py
+++ b/tests/nodes/test_nodes.py
@@ -32,6 +32,12 @@ class PromptTestInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
         return PromptTestInvocationOutput(prompt = self.prompt)
 
+class ErrorInvocation(BaseInvocation):
+    type: Literal['test_error'] = 'test_error'
+
+    def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
+        raise Exception("This invocation is supposed to fail")
+
 class ImageTestInvocationOutput(BaseInvocationOutput):
     type: Literal['test_image_output'] = 'test_image_output'
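
Usage note (not part of the patch): the sketch below shows how a caller outside the CLI might consume the error surface this change introduces. It is a minimal, illustrative example that relies only on what the diff adds (GraphExecutionState.has_error(), the errors dict, and prepared_source_mapping) plus the graph_execution_manager polling pattern already used in invoke_cli(); the helper name wait_and_report is hypothetical, and the import paths are assumed from the modules touched above.

    import time

    from ldm.invoke.app.services.graph import GraphExecutionState
    from ldm.invoke.app.services.invoker import Invoker


    def wait_and_report(invoker: Invoker, session: GraphExecutionState) -> bool:
        """Polls a session until it finishes, then prints any node errors.

        Returns True when the session completed without errors.
        """
        # is_complete() now also returns True when any node has errored,
        # so this loop terminates on failure as well as on success.
        while not session.is_complete():
            session = invoker.services.graph_execution_manager.get(session.id)
            time.sleep(0.1)

        if session.has_error():
            # errors is keyed by prepared (execution) node id; map each entry
            # back to its source node via prepared_source_mapping, as the CLI does.
            for node_id, err in session.errors.items():
                source = session.prepared_source_mapping[node_id]
                print(f'Error in node {node_id} (source node {source}): {err}')
            return False

        return True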