mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[nodes] Add better error handling to processor and CLI (#2828)
* [nodes] Add better error handling to processor and CLI * [nodes] Use more explicit name for marking node execution error * [nodes] Update the processor call to error
This commit is contained in:
parent
e0405031a7
commit
b7d5a3e0b5
@ -27,7 +27,7 @@ from .services.events import EventServiceBase
|
||||
|
||||
|
||||
class InvocationCommand(BaseModel):
|
||||
invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type")
|
||||
invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
|
||||
|
||||
|
||||
class InvalidArgs(Exception):
|
||||
@ -84,7 +84,7 @@ def get_invocation_parser() -> argparse.ArgumentParser:
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list]
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
@ -184,7 +184,7 @@ def invoke_cli():
|
||||
)
|
||||
|
||||
invoker = Invoker(services)
|
||||
session = invoker.create_execution_state()
|
||||
session: GraphExecutionState = invoker.create_execution_state()
|
||||
|
||||
parser = get_invocation_parser()
|
||||
|
||||
@ -219,6 +219,9 @@ def invoke_cli():
|
||||
current_id = start_id
|
||||
new_invocations = list()
|
||||
for cmd in cmds:
|
||||
if cmd is None or cmd.strip() == '':
|
||||
raise InvalidArgs('Empty command')
|
||||
|
||||
# Parse args to create invocation
|
||||
args = vars(parser.parse_args(shlex.split(cmd.strip())))
|
||||
|
||||
@ -289,6 +292,15 @@ def invoke_cli():
|
||||
session = invoker.services.graph_execution_manager.get(session.id)
|
||||
time.sleep(0.1)
|
||||
|
||||
# Print any errors
|
||||
if session.has_error():
|
||||
for n in session.errors:
|
||||
print(f'Error in node {n} (source node {session.prepared_source_mapping[n]}): {session.errors[n]}')
|
||||
|
||||
# Start a new session
|
||||
print("Creating a new session")
|
||||
session = invoker.create_execution_state()
|
||||
|
||||
except InvalidArgs:
|
||||
print('Invalid command, use "help" to list commands')
|
||||
continue
|
||||
|
@ -55,6 +55,21 @@ class EventServiceBase:
|
||||
)
|
||||
)
|
||||
|
||||
def emit_invocation_error(self,
|
||||
graph_execution_state_id: str,
|
||||
invocation_id: str,
|
||||
error: str
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_session_event(
|
||||
event_name = 'invocation_error',
|
||||
payload = dict(
|
||||
graph_execution_state_id = graph_execution_state_id,
|
||||
invocation_id = invocation_id,
|
||||
error = error
|
||||
)
|
||||
)
|
||||
|
||||
def emit_invocation_started(self,
|
||||
graph_execution_state_id: str,
|
||||
invocation_id: str
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
import traceback
|
||||
from types import NoneType
|
||||
import uuid
|
||||
import networkx as nx
|
||||
@ -153,8 +154,8 @@ class CollectInvocation(BaseInvocation):
|
||||
return CollectInvocationOutput(collection = copy.copy(self.collection))
|
||||
|
||||
|
||||
InvocationsUnion = Union[BaseInvocation.get_invocations()]
|
||||
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]
|
||||
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
|
||||
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
@ -486,11 +487,11 @@ class Graph(BaseModel):
|
||||
type_tree.add_nodes_from(input_field_types)
|
||||
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
|
||||
type_degrees = type_tree.in_degree(type_tree.nodes)
|
||||
if sum((t[1] == 0 for t in type_degrees)) != 1:
|
||||
if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
|
||||
return False # There is more than one root type
|
||||
|
||||
# Get the input root type
|
||||
input_root_type = next(t[0] for t in type_degrees if t[1] == 0)
|
||||
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
|
||||
|
||||
# Verify that all outputs are lists
|
||||
if not all((get_origin(f) == list for f in output_fields)):
|
||||
@ -545,6 +546,9 @@ class GraphExecutionState(BaseModel):
|
||||
# The results of executed nodes
|
||||
results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(description="The results of node executions", default_factory=dict)
|
||||
|
||||
# Errors raised when executing nodes
|
||||
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
|
||||
|
||||
# Map of prepared/executed nodes to their original nodes
|
||||
prepared_source_mapping: dict[str, str] = Field(description="The map of prepared nodes to original graph nodes", default_factory=dict)
|
||||
|
||||
@ -594,9 +598,17 @@ class GraphExecutionState(BaseModel):
|
||||
self.executed.add(source_node)
|
||||
self.executed_history.append(source_node)
|
||||
|
||||
def set_node_error(self, node_id: str, error: str):
|
||||
"""Marks a node as errored"""
|
||||
self.errors[node_id] = error
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
"""Returns true if the graph is complete"""
|
||||
return all((k in self.executed for k in self.graph.nodes))
|
||||
return self.has_error() or all((k in self.executed for k in self.graph.nodes))
|
||||
|
||||
def has_error(self) -> bool:
|
||||
"""Returns true if the graph has any errors"""
|
||||
return len(self.errors) > 0
|
||||
|
||||
def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
|
||||
"""Prepares an iteration node and connects all edges, returning the new node id"""
|
||||
@ -709,11 +721,11 @@ class GraphExecutionState(BaseModel):
|
||||
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
|
||||
# TODO: Handle a node mapping to none
|
||||
eg = self.execution_graph.nx_graph_flat()
|
||||
prepared_parent_mappings = [[(n,self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations]
|
||||
prepared_parent_mappings = [[(n,self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations] # type: ignore
|
||||
|
||||
# Create execution node for each iteration
|
||||
for iteration_mappings in prepared_parent_mappings:
|
||||
create_results = self._create_execution_node(next_node_id, iteration_mappings)
|
||||
create_results = self._create_execution_node(next_node_id, iteration_mappings) # type: ignore
|
||||
if create_results is not None:
|
||||
new_node_ids.extend(create_results)
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
from threading import Event, Thread
|
||||
import traceback
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from .invocation_queue import InvocationQueueItem
|
||||
from .invoker import InvocationProcessorABC, Invoker
|
||||
@ -61,18 +62,34 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
result = outputs.dict()
|
||||
)
|
||||
|
||||
# Queue any further commands if invoking all
|
||||
is_complete = graph_execution_state.is_complete()
|
||||
if queue_item.invoke_all and not is_complete:
|
||||
self.__invoker.invoke(graph_execution_state, invoke_all = True)
|
||||
elif is_complete:
|
||||
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Log the error, mark the invocation as failed, and emit an event
|
||||
print(f'Error invoking {invocation.id}: {e}')
|
||||
error = traceback.format_exc()
|
||||
|
||||
# Save error
|
||||
graph_execution_state.set_node_error(invocation.id, error)
|
||||
|
||||
# Save the state changes
|
||||
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
|
||||
|
||||
# Send error event
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
graph_execution_state_id = graph_execution_state.id,
|
||||
invocation_id = invocation.id,
|
||||
error = error
|
||||
)
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Queue any further commands if invoking all
|
||||
is_complete = graph_execution_state.is_complete()
|
||||
if queue_item.invoke_all and not is_complete:
|
||||
self.__invoker.invoke(graph_execution_state, invoke_all = True)
|
||||
elif is_complete:
|
||||
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
... # Log something?
|
||||
|
@ -1,4 +1,4 @@
|
||||
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation, TestEventService, create_edge, wait_until
|
||||
from .test_nodes import ErrorInvocation, ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation, TestEventService, create_edge, wait_until
|
||||
from ldm.invoke.app.services.processor import DefaultInvocationProcessor
|
||||
from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||
from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
|
||||
@ -21,9 +21,9 @@ def simple_graph():
|
||||
def mock_services() -> InvocationServices:
|
||||
# NOTE: none of these are actually called by the test invocations
|
||||
return InvocationServices(
|
||||
generate = None,
|
||||
generate = None, # type: ignore
|
||||
events = TestEventService(),
|
||||
images = None,
|
||||
images = None, # type: ignore
|
||||
queue = MemoryInvocationQueue(),
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||
processor = DefaultInvocationProcessor()
|
||||
@ -79,3 +79,22 @@ def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
|
||||
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
assert g.is_complete()
|
||||
|
||||
def test_handles_errors(mock_invoker: Invoker):
|
||||
g = mock_invoker.create_execution_state()
|
||||
g.graph.add_node(ErrorInvocation(id = "1"))
|
||||
|
||||
mock_invoker.invoke(g, invoke_all=True)
|
||||
|
||||
def has_executed_all(g: GraphExecutionState):
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
return g.is_complete()
|
||||
|
||||
wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1)
|
||||
mock_invoker.stop()
|
||||
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
assert g.has_error()
|
||||
assert g.is_complete()
|
||||
|
||||
assert all((i in g.errors for i in g.source_prepared_mapping['1']))
|
||||
|
@ -32,6 +32,12 @@ class PromptTestInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
|
||||
return PromptTestInvocationOutput(prompt = self.prompt)
|
||||
|
||||
class ErrorInvocation(BaseInvocation):
|
||||
type: Literal['test_error'] = 'test_error'
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
|
||||
raise Exception("This invocation is supposed to fail")
|
||||
|
||||
class ImageTestInvocationOutput(BaseInvocationOutput):
|
||||
type: Literal['test_image_output'] = 'test_image_output'
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user