InvokeAI/invokeai/app/cli_app.py

276 lines
8.9 KiB
Python
Raw Normal View History

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import argparse
import os
2023-03-03 06:02:00 +00:00
import shlex
import time
2023-03-03 06:02:00 +00:00
from typing import (
Union,
get_type_hints,
)
from pydantic import BaseModel
from pydantic.fields import Field
from ..backend import Args
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
2023-03-03 06:02:00 +00:00
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
2023-03-11 16:32:57 +00:00
from .services.model_manager_initializer import get_model_manager
2023-03-11 22:00:00 +00:00
from .services.restoration_services import RestorationServices
2023-03-15 06:09:30 +00:00
from .services.graph import Edge, EdgeConnection, GraphExecutionState
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
2023-03-03 06:02:00 +00:00
from .services.invoker import Invoker
from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage
class CliCommand(BaseModel):
    """Pydantic wrapper that turns one parsed CLI line into a concrete command object.

    ``command`` is a discriminated union over every CLI command type and every
    invocation type. Pydantic picks the concrete class from the ``type`` key of
    the input dict; ``get_command_parser`` stores the chosen sub-command name
    under ``type`` (``add_subparsers(dest="type")``).
    """

    # The union members are computed once, when this class body runs, so
    # presumably only command/invocation classes already imported at that point
    # are included (hence the star import of .invocations above) — confirm.
    # ``+`` concatenates the two collections (presumably tuples) that Union[...]
    # then receives; ``type: ignore`` silences the checker about this dynamic form.
    command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
class InvalidArgs(Exception):
    """Signals that a line of CLI input could not be parsed into a command.

    Raised by the argparse ``exit`` override and for empty pipeline segments;
    the REPL reacts by printing a usage hint.
    """
def add_invocation_args(command_parser):
    """Add the node-linking options shared by every invocation sub-command.

    Args:
        command_parser: the argparse (sub)parser for one invocation type.

    Options added (both repeatable, collected into lists):
        --link / -l: three values describing one explicit edge into this node.
            ``invoke_cli`` reads them as ``source_field source_node dest_field``.
        --link_node / -ln: a node whose outputs are linked to every field of
            this invocation that has the same name.
    """
    # Add linking capability
    command_parser.add_argument(
        "--link",
        "-l",
        action="append",
        nargs=3,
        # The order in this help text must match how invoke_cli unpacks each
        # link: link[0] = source field, link[1] = source node, link[2] = dest field.
        # NOTE(review): invoke_cli uses the source node id verbatim; relative
        # references such as -1 are not resolved anywhere visible — confirm.
        help="A link in the format 'source_field source_node dest_field'. source_node can be relative to history (e.g. -1)",
    )
    command_parser.add_argument(
        "--link_node",
        "-ln",
        action="append",
        help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)",
    )
def get_command_parser() -> argparse.ArgumentParser:
    """Build the top-level parser for one CLI command.

    Every invocation class and every CLI command class gets its own
    sub-command; the chosen sub-command name is stored under ``type``.
    Invocation sub-commands additionally accept the ``--link`` /
    ``--link_node`` options from ``add_invocation_args``.

    Returns:
        The configured parser. Its ``exit`` is replaced so that parse errors
        raise ``InvalidArgs`` instead of terminating the process.
    """

    def _raise_invalid_args(*_args, **_kwargs):
        # argparse calls ``exit`` on errors; raise so the REPL can recover.
        raise InvalidArgs

    parser = argparse.ArgumentParser()
    parser.exit = _raise_invalid_args

    subparsers = parser.add_subparsers(dest="type")

    # One sub-command per invocation type, each with the linking options.
    add_parsers(
        subparsers,
        BaseInvocation.get_all_subclasses(),
        add_arguments=add_invocation_args,
    )

    # One sub-command per CLI command; its "type" field is not exposed as an option.
    add_parsers(
        subparsers,
        BaseCommand.get_all_subclasses(),
        exclude_fields=["type"],
    )

    return parser
2023-03-03 06:02:00 +00:00
def generate_matching_edges(
    a: BaseInvocation, b: BaseInvocation
) -> list[Edge]:
    """Generates all possible edges between two invocations.

    Fields are matched purely by name. Every field declared (via type hints) on
    ``a``'s output type that is also declared on ``b``'s class produces one edge
    from ``a.<field>`` to ``b.<field>``. The ``type`` and ``id`` fields are never
    linked. Field types are not compared here.

    Args:
        a: source invocation; ``type(a).get_output_type()`` supplies its fields.
        b: destination invocation.

    Returns:
        One ``Edge`` per shared field name, ordered by field name.
    """
    output_fields = get_type_hints(type(a).get_output_type())
    input_fields = get_type_hints(type(b))

    # Remove identity/discriminator fields, which must never be piped.
    matching_fields = (output_fields.keys() & input_fields.keys()) - {"type", "id"}

    # Sort so edge order does not depend on set iteration order, which varies
    # between processes because str hashing is randomized.
    return [
        Edge(
            source=EdgeConnection(node_id=a.id, field=field),
            destination=EdgeConnection(node_id=b.id, field=field),
        )
        for field in sorted(matching_fields)
    ]
class SessionError(Exception):
    """Signals that at least one node in the current session reported an error.

    The REPL handles it by starting a fresh execution session.
    """
def invoke_all(context: CliContext):
    """Run every pending node of ``context.session`` and block until it finishes.

    Completion is polled through ``context.get_session()`` every 0.1 s. If the
    session then reports errors, each failing node is printed together with the
    source node it was prepared from, and ``SessionError`` is raised.
    """
    context.invoker.invoke(context.session, invoke_all=True)

    # Poll until done. The error check below reads ``context.session``; presumably
    # get_session() refreshes that attribute from storage — confirm in CliContext.
    while True:
        if context.get_session().is_complete():
            break
        time.sleep(0.1)

    if not context.session.has_error():
        return

    # Print any errors, then let the caller decide how to recover.
    errors = context.session.errors
    for node_id in errors:
        print(
            f"Error in node {node_id} (source node {context.session.prepared_source_mapping[node_id]}): {errors[node_id]}"
        )
    raise SessionError()
def invoke_cli():
    """Run the interactive InvokeAI command-line loop.

    Sets up the invocation services. Session state is stored through
    ``SqliteItemStorage`` in ``outputs/invokeai.db`` and images in ``outputs``.
    The loop then reads lines from stdin. A line may hold several commands
    separated by ``|``:

    * CLI commands run immediately, after all pending nodes are executed.
    * Invocations become nodes in the current session. Each one is auto-linked
      to the previous node by matching field names, unless overridden by
      ``--link_node`` / ``--link``.

    After each line, all pending nodes are executed.

    The loop ends on Ctrl-C or end of input at the prompt, or when a command
    raises ``ExitCli``. The invoker is always stopped on the way out.
    """
    config = Args()
    config.parse_args()
    model_manager = get_model_manager(config)

    events = EventServiceBase()

    output_folder = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "../../../outputs")
    )

    # TODO: build a file/path manager?
    db_location = os.path.join(output_folder, "invokeai.db")

    services = InvocationServices(
        model_manager=model_manager,
        events=events,
        images=DiskImageStorage(output_folder),
        queue=MemoryInvocationQueue(),
        graph_execution_manager=SqliteItemStorage[GraphExecutionState](
            filename=db_location, table_name="graph_executions"
        ),
        processor=DefaultInvocationProcessor(),
        restoration=RestorationServices(config),
    )

    invoker = Invoker(services)
    session: GraphExecutionState = invoker.create_execution_state()
    parser = get_command_parser()

    # Uncomment to print out previous sessions at startup
    # print(services.session_manager.list())

    context = CliContext(invoker, session, parser)

    # try/finally so the invoker is stopped however the loop is left
    # (break, ExitCli, or an exception such as Ctrl-C during execution).
    try:
        while True:
            try:
                cmd_input = input("> ")
            except (KeyboardInterrupt, EOFError):
                # Ctrl-C or end of input (Ctrl-D / closed stdin) exits
                break

            try:
                # Refresh the state of the session
                history = list(get_graph_execution_history(context.session))

                # Split the command for piping
                cmds = cmd_input.split("|")
                start_id = len(history)
                current_id = start_id
                new_invocations = list()
                for cmd in cmds:
                    if not cmd.strip():
                        raise InvalidArgs("Empty command")

                    # Parse args to create invocation
                    args = vars(context.parser.parse_args(shlex.split(cmd.strip())))

                    # Override defaults
                    # NOTE(review): this overwrites every field listed in
                    # context.defaults, including values the user passed
                    # explicitly on this line — confirm that is intended.
                    for field_name, field_default in context.defaults.items():
                        if field_name in args:
                            args[field_name] = field_default

                    # Parse invocation
                    args["id"] = current_id
                    command = CliCommand(command=args)

                    # Run any CLI commands immediately
                    if isinstance(command.command, BaseCommand):
                        # Invoke all current nodes to preserve operation order
                        invoke_all(context)

                        # Run the command
                        command.command.run(context)
                        continue

                    # Pipe previous command output (if there was a previous command)
                    edges: list[Edge] = list()
                    if len(history) > 0 or current_id != start_id:
                        from_id = (
                            history[0] if current_id == start_id else str(current_id - 1)
                        )
                        from_node = (
                            next(filter(lambda n: n[0].id == from_id, new_invocations))[0]
                            if current_id != start_id
                            else context.session.graph.get_node(from_id)
                        )
                        matching_edges = generate_matching_edges(
                            from_node, command.command
                        )
                        edges.extend(matching_edges)

                    # Parse provided links
                    if "link_node" in args and args["link_node"]:
                        for link in args["link_node"]:
                            link_node = context.session.graph.get_node(link)
                            matching_edges = generate_matching_edges(
                                link_node, command.command
                            )
                            # Linked fields replace any existing edge into the same destination.
                            matching_destinations = [e.destination for e in matching_edges]
                            edges = [e for e in edges if e.destination not in matching_destinations]
                            edges.extend(matching_edges)

                    if "link" in args and args["link"]:
                        for link in args["link"]:
                            # link = (source_field, source_node, dest_field).
                            # Drop only an edge that already targets this node's
                            # dest_field; keep every other edge into this node.
                            edges = [
                                e
                                for e in edges
                                if not (
                                    e.destination.node_id == command.command.id
                                    and e.destination.field == link[2]
                                )
                            ]
                            edges.append(
                                Edge(
                                    source=EdgeConnection(node_id=link[1], field=link[0]),
                                    destination=EdgeConnection(
                                        node_id=command.command.id, field=link[2]
                                    ),
                                )
                            )

                    new_invocations.append((command.command, edges))

                    current_id = current_id + 1

                    # Add the node to the session
                    context.session.add_node(command.command)
                    for edge in edges:
                        print(edge)
                        context.session.add_edge(edge)

                # Execute all remaining nodes
                invoke_all(context)

            except InvalidArgs:
                print('Invalid command, use "help" to list commands')
                continue

            except SessionError:
                # Start a new session
                print("Session error: creating a new session")
                context.session = context.invoker.create_execution_state()

            except ExitCli:
                break

            except SystemExit:
                # Only the top-level parser's exit is overridden; presumably the
                # sub-parsers (and --help) still call sys.exit — keep the REPL alive.
                continue

            except Exception as e:
                # REPL boundary: report bad input (e.g. field validation errors or
                # unknown nodes) instead of crashing the whole CLI.
                print(f"Error: {type(e).__name__}: {e}")
                continue
    finally:
        invoker.stop()
if __name__ == "__main__":
    # Script entry point: start the interactive CLI loop.
    invoke_cli()