[nodes] Add better error handling to processor and CLI (#2828)

* [nodes] Add better error handling to processor and CLI

* [nodes] Use more explicit name for marking node execution error

* [nodes] Update the processor call to error
This commit is contained in:
Kyle Schouviller 2023-02-27 10:01:07 -08:00 committed by GitHub
parent e0405031a7
commit b7d5a3e0b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 102 additions and 21 deletions

View File

@ -27,7 +27,7 @@ from .services.events import EventServiceBase
class InvocationCommand(BaseModel): class InvocationCommand(BaseModel):
invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type") invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
class InvalidArgs(Exception): class InvalidArgs(Exception):
@ -84,7 +84,7 @@ def get_invocation_parser() -> argparse.ArgumentParser:
for val in allowed_values: for val in allowed_values:
allowed_types.add(type(val)) allowed_types.add(type(val))
allowed_types_list = list(allowed_types) allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
command_parser.add_argument( command_parser.add_argument(
f"--{name}", f"--{name}",
@ -184,7 +184,7 @@ def invoke_cli():
) )
invoker = Invoker(services) invoker = Invoker(services)
session = invoker.create_execution_state() session: GraphExecutionState = invoker.create_execution_state()
parser = get_invocation_parser() parser = get_invocation_parser()
@ -219,6 +219,9 @@ def invoke_cli():
current_id = start_id current_id = start_id
new_invocations = list() new_invocations = list()
for cmd in cmds: for cmd in cmds:
if cmd is None or cmd.strip() == '':
raise InvalidArgs('Empty command')
# Parse args to create invocation # Parse args to create invocation
args = vars(parser.parse_args(shlex.split(cmd.strip()))) args = vars(parser.parse_args(shlex.split(cmd.strip())))
@ -289,6 +292,15 @@ def invoke_cli():
session = invoker.services.graph_execution_manager.get(session.id) session = invoker.services.graph_execution_manager.get(session.id)
time.sleep(0.1) time.sleep(0.1)
# Print any errors
if session.has_error():
for n in session.errors:
print(f'Error in node {n} (source node {session.prepared_source_mapping[n]}): {session.errors[n]}')
# Start a new session
print("Creating a new session")
session = invoker.create_execution_state()
except InvalidArgs: except InvalidArgs:
print('Invalid command, use "help" to list commands') print('Invalid command, use "help" to list commands')
continue continue

View File

@ -55,6 +55,21 @@ class EventServiceBase:
) )
) )
def emit_invocation_error(self,
    graph_execution_state_id: str,
    invocation_id: str,
    error: str
) -> None:
    """Emitted when an invocation has raised an error.

    Args:
        graph_execution_state_id: id of the graph execution state the
            failing invocation belongs to.
        invocation_id: id of the invocation that failed.
        error: text describing the failure; it is forwarded unchanged in
            the event payload.
    """
    # The payload keys mirror the parameter names so listeners can
    # correlate the error with the matching 'invocation_error' event source.
    self.__emit_session_event(
        event_name = 'invocation_error',
        payload = dict(
            graph_execution_state_id = graph_execution_state_id,
            invocation_id = invocation_id,
            error = error
        )
    )
def emit_invocation_started(self, def emit_invocation_started(self,
graph_execution_state_id: str, graph_execution_state_id: str,
invocation_id: str invocation_id: str

View File

@ -2,6 +2,7 @@
import copy import copy
import itertools import itertools
import traceback
from types import NoneType from types import NoneType
import uuid import uuid
import networkx as nx import networkx as nx
@ -153,8 +154,8 @@ class CollectInvocation(BaseInvocation):
return CollectInvocationOutput(collection = copy.copy(self.collection)) return CollectInvocationOutput(collection = copy.copy(self.collection))
InvocationsUnion = Union[BaseInvocation.get_invocations()] InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
class Graph(BaseModel): class Graph(BaseModel):
@ -486,11 +487,11 @@ class Graph(BaseModel):
type_tree.add_nodes_from(input_field_types) type_tree.add_nodes_from(input_field_types)
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])]) type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
type_degrees = type_tree.in_degree(type_tree.nodes) type_degrees = type_tree.in_degree(type_tree.nodes)
if sum((t[1] == 0 for t in type_degrees)) != 1: if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
return False # There is more than one root type return False # There is more than one root type
# Get the input root type # Get the input root type
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
# Verify that all outputs are lists # Verify that all outputs are lists
if not all((get_origin(f) == list for f in output_fields)): if not all((get_origin(f) == list for f in output_fields)):
@ -545,6 +546,9 @@ class GraphExecutionState(BaseModel):
# The results of executed nodes # The results of executed nodes
results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(description="The results of node executions", default_factory=dict) results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(description="The results of node executions", default_factory=dict)
# Errors raised when executing nodes
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
# Map of prepared/executed nodes to their original nodes # Map of prepared/executed nodes to their original nodes
prepared_source_mapping: dict[str, str] = Field(description="The map of prepared nodes to original graph nodes", default_factory=dict) prepared_source_mapping: dict[str, str] = Field(description="The map of prepared nodes to original graph nodes", default_factory=dict)
@ -594,9 +598,17 @@ class GraphExecutionState(BaseModel):
self.executed.add(source_node) self.executed.add(source_node)
self.executed_history.append(source_node) self.executed_history.append(source_node)
def set_node_error(self, node_id: str, error: str):
    """Records `error` for `node_id`, replacing any error already stored for that node"""
    self.errors.update({node_id: error})
def is_complete(self) -> bool: def is_complete(self) -> bool:
"""Returns true if the graph is complete""" """Returns true if the graph is complete"""
return all((k in self.executed for k in self.graph.nodes)) return self.has_error() or all((k in self.executed for k in self.graph.nodes))
def has_error(self) -> bool:
    """Returns true if at least one node error has been recorded"""
    # An empty errors mapping is falsy, so truthiness is the same check as len(...) > 0
    return bool(self.errors)
def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]: def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
"""Prepares an iteration node and connects all edges, returning the new node id""" """Prepares an iteration node and connects all edges, returning the new node id"""
@ -709,11 +721,11 @@ class GraphExecutionState(BaseModel):
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator # For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
# TODO: Handle a node mapping to none # TODO: Handle a node mapping to none
eg = self.execution_graph.nx_graph_flat() eg = self.execution_graph.nx_graph_flat()
prepared_parent_mappings = [[(n,self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations] prepared_parent_mappings = [[(n,self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations] # type: ignore
# Create execution node for each iteration # Create execution node for each iteration
for iteration_mappings in prepared_parent_mappings: for iteration_mappings in prepared_parent_mappings:
create_results = self._create_execution_node(next_node_id, iteration_mappings) create_results = self._create_execution_node(next_node_id, iteration_mappings) # type: ignore
if create_results is not None: if create_results is not None:
new_node_ids.extend(create_results) new_node_ids.extend(create_results)

View File

@ -1,4 +1,5 @@
from threading import Event, Thread from threading import Event, Thread
import traceback
from ..invocations.baseinvocation import InvocationContext from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker from .invoker import InvocationProcessorABC, Invoker
@ -61,18 +62,34 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
result = outputs.dict() result = outputs.dict()
) )
# Queue any further commands if invoking all
is_complete = graph_execution_state.is_complete()
if queue_item.invoke_all and not is_complete:
self.__invoker.invoke(graph_execution_state, invoke_all = True)
elif is_complete:
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
except Exception as e: except Exception as e:
# TODO: Log the error, mark the invocation as failed, and emit an event error = traceback.format_exc()
print(f'Error invoking {invocation.id}: {e}')
# Save error
graph_execution_state.set_node_error(invocation.id, error)
# Save the state changes
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
# Send error event
self.__invoker.services.events.emit_invocation_error(
graph_execution_state_id = graph_execution_state.id,
invocation_id = invocation.id,
error = error
)
pass pass
# Queue any further commands if invoking all
is_complete = graph_execution_state.is_complete()
if queue_item.invoke_all and not is_complete:
self.__invoker.invoke(graph_execution_state, invoke_all = True)
elif is_complete:
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
except KeyboardInterrupt: except KeyboardInterrupt:
... # Log something? ... # Log something?

View File

@ -1,4 +1,4 @@
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation, TestEventService, create_edge, wait_until from .test_nodes import ErrorInvocation, ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation, TestEventService, create_edge, wait_until
from ldm.invoke.app.services.processor import DefaultInvocationProcessor from ldm.invoke.app.services.processor import DefaultInvocationProcessor
from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory
from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
@ -21,9 +21,9 @@ def simple_graph():
def mock_services() -> InvocationServices: def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations # NOTE: none of these are actually called by the test invocations
return InvocationServices( return InvocationServices(
generate = None, generate = None, # type: ignore
events = TestEventService(), events = TestEventService(),
images = None, images = None, # type: ignore
queue = MemoryInvocationQueue(), queue = MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor() processor = DefaultInvocationProcessor()
@ -79,3 +79,22 @@ def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
g = mock_invoker.services.graph_execution_manager.get(g.id) g = mock_invoker.services.graph_execution_manager.get(g.id)
assert g.is_complete() assert g.is_complete()
def test_handles_errors(mock_invoker: Invoker):
    """A graph whose only node raises must finish as complete with the error recorded."""
    state = mock_invoker.create_execution_state()
    state.graph.add_node(ErrorInvocation(id = "1"))
    manager = mock_invoker.services.graph_execution_manager

    mock_invoker.invoke(state, invoke_all=True)

    # Poll the stored copy of the state: the processor saves its changes to
    # the execution manager rather than mutating the object held here.
    wait_until(lambda: manager.get(state.id).is_complete(), timeout = 5, interval = 1)
    mock_invoker.stop()

    # Re-read the persisted state before asserting on it.
    state = manager.get(state.id)
    assert state.has_error()
    assert state.is_complete()

    # Every prepared node created from source node "1" must have an error entry.
    prepared_ids = state.source_prepared_mapping['1']
    assert all(node_id in state.errors for node_id in prepared_ids)

View File

@ -32,6 +32,12 @@ class PromptTestInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
return PromptTestInvocationOutput(prompt = self.prompt) return PromptTestInvocationOutput(prompt = self.prompt)
class ErrorInvocation(BaseInvocation):
    """Test invocation whose `invoke` always raises, used to exercise error handling."""
    type: Literal['test_error'] = 'test_error'

    def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
        # Never returns: failing unconditionally is the whole purpose of this node.
        raise Exception("This invocation is supposed to fail")
class ImageTestInvocationOutput(BaseInvocationOutput): class ImageTestInvocationOutput(BaseInvocationOutput):
type: Literal['test_image_output'] = 'test_image_output' type: Literal['test_image_output'] = 'test_image_output'