Merge branch 'main' into bugfix/restore-pytorch-lightning

This commit is contained in:
blessedcoolant 2023-03-09 16:57:14 +13:00 committed by GitHub
commit 63b9ec4c5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -102,6 +102,29 @@ def generate_matching_edges(
return edges return edges
class SessionError(Exception):
"""Raised when a session error has occurred"""
pass
def invoke_all(context: CliContext):
"""Runs all invocations in the specified session"""
context.invoker.invoke(context.session, invoke_all=True)
while not context.session.is_complete():
# Wait some time
session = context.get_session()
time.sleep(0.1)
# Print any errors
if context.session.has_error():
for n in context.session.errors:
print(
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {session.errors[n]}"
)
raise SessionError()
def invoke_cli(): def invoke_cli():
args = Args() args = Args()
config = args.parse_args() config = args.parse_args()
@ -134,7 +157,6 @@ def invoke_cli():
invoker = Invoker(services) invoker = Invoker(services)
session: GraphExecutionState = invoker.create_execution_state() session: GraphExecutionState = invoker.create_execution_state()
parser = get_command_parser() parser = get_command_parser()
# Uncomment to print out previous sessions at startup # Uncomment to print out previous sessions at startup
@ -151,8 +173,7 @@ def invoke_cli():
try: try:
# Refresh the state of the session # Refresh the state of the session
session = invoker.services.graph_execution_manager.get(session.id) history = list(get_graph_execution_history(context.session))
history = list(get_graph_execution_history(session))
# Split the command for piping # Split the command for piping
cmds = cmd_input.split("|") cmds = cmd_input.split("|")
@ -164,7 +185,7 @@ def invoke_cli():
raise InvalidArgs("Empty command") raise InvalidArgs("Empty command")
# Parse args to create invocation # Parse args to create invocation
args = vars(parser.parse_args(shlex.split(cmd.strip()))) args = vars(context.parser.parse_args(shlex.split(cmd.strip())))
# Override defaults # Override defaults
for field_name, field_default in context.defaults.items(): for field_name, field_default in context.defaults.items():
@ -176,11 +197,11 @@ def invoke_cli():
command = CliCommand(command=args) command = CliCommand(command=args)
# Run any CLI commands immediately # Run any CLI commands immediately
# TODO: this won't behave as expected if piping and using e.g. history,
# since invocations are gathered and then run together at the end.
# This is more efficient if the CLI is running against a distributed
# backend, so it's preferable not to change that behavior.
if isinstance(command.command, BaseCommand): if isinstance(command.command, BaseCommand):
# Invoke all current nodes to preserve operation order
invoke_all(context)
# Run the command
command.command.run(context) command.command.run(context)
continue continue
@ -193,7 +214,7 @@ def invoke_cli():
from_node = ( from_node = (
next(filter(lambda n: n[0].id == from_id, new_invocations))[0] next(filter(lambda n: n[0].id == from_id, new_invocations))[0]
if current_id != start_id if current_id != start_id
else session.graph.get_node(from_id) else context.session.graph.get_node(from_id)
) )
matching_edges = generate_matching_edges( matching_edges = generate_matching_edges(
from_node, command.command from_node, command.command
@ -203,7 +224,7 @@ def invoke_cli():
# Parse provided links # Parse provided links
if "link_node" in args and args["link_node"]: if "link_node" in args and args["link_node"]:
for link in args["link_node"]: for link in args["link_node"]:
link_node = session.graph.get_node(link) link_node = context.session.graph.get_node(link)
matching_edges = generate_matching_edges( matching_edges = generate_matching_edges(
link_node, command.command link_node, command.command
) )
@ -227,37 +248,24 @@ def invoke_cli():
current_id = current_id + 1 current_id = current_id + 1
# Command line was parsed successfully # Add the node to the session
# Add the invocations to the session context.session.add_node(command.command)
for invocation in new_invocations: for edge in edges:
session.add_node(invocation[0])
for edge in invocation[1]:
print(edge) print(edge)
session.add_edge(edge) context.session.add_edge(edge)
# Execute all available invocations # Execute all remaining nodes
invoker.invoke(session, invoke_all=True) invoke_all(context)
while not session.is_complete():
# Wait some time
session = context.get_session()
time.sleep(0.1)
# Print any errors
if session.has_error():
for n in session.errors:
print(
f"Error in node {n} (source node {session.prepared_source_mapping[n]}): {session.errors[n]}"
)
# Start a new session
print("Creating a new session")
session = invoker.create_execution_state()
context.session = session
except InvalidArgs: except InvalidArgs:
print('Invalid command, use "help" to list commands') print('Invalid command, use "help" to list commands')
continue continue
except SessionError:
# Start a new session
print("Session error: creating a new session")
context.session = context.invoker.create_execution_state()
except ExitCli: except ExitCli:
break break