mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into bugfix/restore-pytorch-lightning
This commit is contained in:
commit
63b9ec4c5e
@ -102,6 +102,29 @@ def generate_matching_edges(
|
||||
return edges
|
||||
|
||||
|
||||
class SessionError(Exception):
|
||||
"""Raised when a session error has occurred"""
|
||||
pass
|
||||
|
||||
|
||||
def invoke_all(context: CliContext):
|
||||
"""Runs all invocations in the specified session"""
|
||||
context.invoker.invoke(context.session, invoke_all=True)
|
||||
while not context.session.is_complete():
|
||||
# Wait some time
|
||||
session = context.get_session()
|
||||
time.sleep(0.1)
|
||||
|
||||
# Print any errors
|
||||
if context.session.has_error():
|
||||
for n in context.session.errors:
|
||||
print(
|
||||
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {session.errors[n]}"
|
||||
)
|
||||
|
||||
raise SessionError()
|
||||
|
||||
|
||||
def invoke_cli():
|
||||
args = Args()
|
||||
config = args.parse_args()
|
||||
@ -134,7 +157,6 @@ def invoke_cli():
|
||||
|
||||
invoker = Invoker(services)
|
||||
session: GraphExecutionState = invoker.create_execution_state()
|
||||
|
||||
parser = get_command_parser()
|
||||
|
||||
# Uncomment to print out previous sessions at startup
|
||||
@ -151,8 +173,7 @@ def invoke_cli():
|
||||
|
||||
try:
|
||||
# Refresh the state of the session
|
||||
session = invoker.services.graph_execution_manager.get(session.id)
|
||||
history = list(get_graph_execution_history(session))
|
||||
history = list(get_graph_execution_history(context.session))
|
||||
|
||||
# Split the command for piping
|
||||
cmds = cmd_input.split("|")
|
||||
@ -164,7 +185,7 @@ def invoke_cli():
|
||||
raise InvalidArgs("Empty command")
|
||||
|
||||
# Parse args to create invocation
|
||||
args = vars(parser.parse_args(shlex.split(cmd.strip())))
|
||||
args = vars(context.parser.parse_args(shlex.split(cmd.strip())))
|
||||
|
||||
# Override defaults
|
||||
for field_name, field_default in context.defaults.items():
|
||||
@ -176,11 +197,11 @@ def invoke_cli():
|
||||
command = CliCommand(command=args)
|
||||
|
||||
# Run any CLI commands immediately
|
||||
# TODO: this won't behave as expected if piping and using e.g. history,
|
||||
# since invocations are gathered and then run together at the end.
|
||||
# This is more efficient if the CLI is running against a distributed
|
||||
# backend, so it's preferable not to change that behavior.
|
||||
if isinstance(command.command, BaseCommand):
|
||||
# Invoke all current nodes to preserve operation order
|
||||
invoke_all(context)
|
||||
|
||||
# Run the command
|
||||
command.command.run(context)
|
||||
continue
|
||||
|
||||
@ -193,7 +214,7 @@ def invoke_cli():
|
||||
from_node = (
|
||||
next(filter(lambda n: n[0].id == from_id, new_invocations))[0]
|
||||
if current_id != start_id
|
||||
else session.graph.get_node(from_id)
|
||||
else context.session.graph.get_node(from_id)
|
||||
)
|
||||
matching_edges = generate_matching_edges(
|
||||
from_node, command.command
|
||||
@ -203,7 +224,7 @@ def invoke_cli():
|
||||
# Parse provided links
|
||||
if "link_node" in args and args["link_node"]:
|
||||
for link in args["link_node"]:
|
||||
link_node = session.graph.get_node(link)
|
||||
link_node = context.session.graph.get_node(link)
|
||||
matching_edges = generate_matching_edges(
|
||||
link_node, command.command
|
||||
)
|
||||
@ -227,37 +248,24 @@ def invoke_cli():
|
||||
|
||||
current_id = current_id + 1
|
||||
|
||||
# Command line was parsed successfully
|
||||
# Add the invocations to the session
|
||||
for invocation in new_invocations:
|
||||
session.add_node(invocation[0])
|
||||
for edge in invocation[1]:
|
||||
# Add the node to the session
|
||||
context.session.add_node(command.command)
|
||||
for edge in edges:
|
||||
print(edge)
|
||||
session.add_edge(edge)
|
||||
context.session.add_edge(edge)
|
||||
|
||||
# Execute all available invocations
|
||||
invoker.invoke(session, invoke_all=True)
|
||||
while not session.is_complete():
|
||||
# Wait some time
|
||||
session = context.get_session()
|
||||
time.sleep(0.1)
|
||||
|
||||
# Print any errors
|
||||
if session.has_error():
|
||||
for n in session.errors:
|
||||
print(
|
||||
f"Error in node {n} (source node {session.prepared_source_mapping[n]}): {session.errors[n]}"
|
||||
)
|
||||
|
||||
# Start a new session
|
||||
print("Creating a new session")
|
||||
session = invoker.create_execution_state()
|
||||
context.session = session
|
||||
# Execute all remaining nodes
|
||||
invoke_all(context)
|
||||
|
||||
except InvalidArgs:
|
||||
print('Invalid command, use "help" to list commands')
|
||||
continue
|
||||
|
||||
except SessionError:
|
||||
# Start a new session
|
||||
print("Session error: creating a new session")
|
||||
context.session = context.invoker.create_execution_state()
|
||||
|
||||
except ExitCli:
|
||||
break
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user