diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index 721760b222..9dc1429d92 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -102,6 +102,29 @@ def generate_matching_edges( return edges +class SessionError(Exception): + """Raised when a session error has occurred""" + pass + + +def invoke_all(context: CliContext): + """Runs all invocations in the specified session""" + context.invoker.invoke(context.session, invoke_all=True) + while not context.session.is_complete(): + # Wait some time + session = context.get_session() + time.sleep(0.1) + + # Print any errors + if context.session.has_error(): + for n in context.session.errors: + print( + f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {session.errors[n]}" + ) + + raise SessionError() + + def invoke_cli(): args = Args() config = args.parse_args() @@ -134,7 +157,6 @@ def invoke_cli(): invoker = Invoker(services) session: GraphExecutionState = invoker.create_execution_state() - parser = get_command_parser() # Uncomment to print out previous sessions at startup @@ -151,8 +173,7 @@ def invoke_cli(): try: # Refresh the state of the session - session = invoker.services.graph_execution_manager.get(session.id) - history = list(get_graph_execution_history(session)) + history = list(get_graph_execution_history(context.session)) # Split the command for piping cmds = cmd_input.split("|") @@ -164,7 +185,7 @@ def invoke_cli(): raise InvalidArgs("Empty command") # Parse args to create invocation - args = vars(parser.parse_args(shlex.split(cmd.strip()))) + args = vars(context.parser.parse_args(shlex.split(cmd.strip()))) # Override defaults for field_name, field_default in context.defaults.items(): @@ -176,11 +197,11 @@ def invoke_cli(): command = CliCommand(command=args) # Run any CLI commands immediately - # TODO: this won't behave as expected if piping and using e.g. history, - # since invocations are gathered and then run together at the end. - # This is more efficient if the CLI is running against a distributed - # backend, so it's preferable not to change that behavior. if isinstance(command.command, BaseCommand): + # Invoke all current nodes to preserve operation order + invoke_all(context) + + # Run the command command.command.run(context) continue @@ -193,7 +214,7 @@ def invoke_cli(): from_node = ( next(filter(lambda n: n[0].id == from_id, new_invocations))[0] if current_id != start_id - else session.graph.get_node(from_id) + else context.session.graph.get_node(from_id) ) matching_edges = generate_matching_edges( from_node, command.command @@ -203,7 +224,7 @@ def invoke_cli(): # Parse provided links if "link_node" in args and args["link_node"]: for link in args["link_node"]: - link_node = session.graph.get_node(link) + link_node = context.session.graph.get_node(link) matching_edges = generate_matching_edges( link_node, command.command ) @@ -227,37 +248,24 @@ def invoke_cli(): current_id = current_id + 1 - # Command line was parsed successfully - # Add the invocations to the session - for invocation in new_invocations: - session.add_node(invocation[0]) - for edge in invocation[1]: + # Add the node to the session + context.session.add_node(command.command) + for edge in edges: print(edge) - session.add_edge(edge) + context.session.add_edge(edge) - # Execute all available invocations - invoker.invoke(session, invoke_all=True) - while not session.is_complete(): - # Wait some time - session = context.get_session() - time.sleep(0.1) - - # Print any errors - if session.has_error(): - for n in session.errors: - print( - f"Error in node {n} (source node {session.prepared_source_mapping[n]}): {session.errors[n]}" - ) - - # Start a new session - print("Creating a new session") - session = invoker.create_execution_state() - context.session = session + # Execute all remaining nodes + invoke_all(context) except InvalidArgs: print('Invalid command, use "help" to list commands') continue + except SessionError: + # Start a new session + print("Session error: creating a new session") + context.session = context.invoker.create_execution_state() + except ExitCli: break