InvokeAI/invokeai/app/cli_app.py

382 lines
12 KiB
Python
Raw Normal View History

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import argparse
import os
2023-03-03 06:02:00 +00:00
import shlex
import time
2023-03-03 06:02:00 +00:00
from typing import (
Any,
Dict,
Iterable,
Literal,
Union,
get_args,
get_origin,
get_type_hints,
)
from pydantic import BaseModel
from pydantic.fields import Field
from ..backend import Args
2023-03-03 06:02:00 +00:00
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
from .invocations.image import ImageField
2023-03-03 06:02:00 +00:00
from .services.events import EventServiceBase
from .services.generate_initializer import get_generate
2023-03-03 06:02:00 +00:00
from .services.graph import EdgeConnection, GraphExecutionState
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
2023-03-03 06:02:00 +00:00
from .services.invoker import Invoker
from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage
class InvocationCommand(BaseModel):
    """Pydantic wrapper that turns a dict of parsed CLI arguments into an invocation.

    ``invocation`` is a union over every class returned by
    ``BaseInvocation.get_invocations()``. Pydantic selects the concrete class
    using the ``type`` key of the input (``Field(discriminator="type")``).
    """

    invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type")  # type: ignore
class InvalidArgs(Exception):
    """Raised when a CLI command line cannot be turned into a valid command."""
def _str_to_bool(value: str) -> bool:
    """Convert a command-line string such as ``"true"`` or ``"0"`` to a bool.

    argparse's ``type=bool`` treats every non-empty string (including
    ``"False"``) as True, so boolean fields need an explicit converter.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "on", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "off", "0"):
        return False
    raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r}")


def _argparse_type(field_type: Any) -> Any:
    """Return the argparse ``type=`` converter to use for a field's type."""
    return _str_to_bool if field_type is bool else field_type


def _literal_converter(allowed_values: tuple) -> Any:
    """Build an argparse converter for a ``Literal`` whose values mix types.

    The raw string is matched against ``str()`` of each allowed value, and the
    original (typed) value is returned so argparse's ``choices`` check passes.
    """
    by_text = {str(value): value for value in allowed_values}

    def convert(text: str) -> Any:
        try:
            return by_text[text]
        except KeyError:
            choices = ", ".join(repr(value) for value in allowed_values)
            raise argparse.ArgumentTypeError(
                f"invalid choice: {text!r} (choose from {choices})"
            ) from None

    return convert


def get_invocation_parser() -> argparse.ArgumentParser:
    """Build the argparse parser used to parse every CLI command.

    The parser has one subcommand per special command (``history``,
    ``default``, ``reset_default``). It also has one subcommand per class
    returned by ``BaseInvocation.get_all_subclasses()``, named after the first
    ``Literal`` value of that class's ``type`` annotation.

    Each invocation subcommand accepts ``--link``/``--link_node`` plus one
    ``--<field>`` option per model field (except ``id`` and ``type``).
    Errors on the top-level parser raise :class:`InvalidArgs` instead of
    exiting the process.

    Returns:
        The configured top-level ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()

    def _raise_invalid_args(*args, **kwargs):
        # Top-level parser errors raise InvalidArgs instead of sys.exit().
        # Subparsers keep argparse's default exit(), which raises SystemExit.
        raise InvalidArgs

    parser.exit = _raise_invalid_args

    subparsers = parser.add_subparsers(dest="type")

    # history [count]
    history_parser = subparsers.add_parser(
        "history", help="Shows the invocation history"
    )
    history_parser.add_argument(
        "count",
        nargs="?",
        default=5,
        type=int,
        help="The number of history entries to show",
    )

    # default <input> <value>
    default_parser = subparsers.add_parser(
        "default", help="Define a default value for all inputs with a specified name"
    )
    default_parser.add_argument("input", type=str, help="The input field")
    default_parser.add_argument("value", help="The default value")

    # reset_default <input>
    reset_default_parser = subparsers.add_parser(
        "reset_default", help="Resets a default value"
    )
    reset_default_parser.add_argument("input", type=str, help="The input field")

    # Create a subparser for each invocation
    for invocation in BaseInvocation.get_all_subclasses():
        hints = get_type_hints(invocation)
        cmd_name = get_args(hints["type"])[0]
        command_parser = subparsers.add_parser(cmd_name, help=invocation.__doc__)

        # Add linking capability
        command_parser.add_argument(
            "--link",
            "-l",
            action="append",
            nargs=3,
            help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)",
        )
        command_parser.add_argument(
            "--link_node",
            "-ln",
            action="append",
            help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)",
        )

        # Convert all fields to arguments
        for name, field in invocation.__fields__.items():
            if name in ("id", "type"):
                continue

            if get_origin(field.type_) == Literal:
                allowed_values = get_args(field.type_)
                allowed_types = {type(value) for value in allowed_values}
                if len(allowed_types) == 1:
                    field_type = _argparse_type(next(iter(allowed_types)))
                else:
                    # Union[...] is neither constructible from a list nor
                    # callable, so map the raw text back to the typed value.
                    field_type = _literal_converter(allowed_values)
                command_parser.add_argument(
                    f"--{name}",
                    dest=name,
                    type=field_type,
                    default=field.default,
                    choices=allowed_values,
                    help=field.field_info.description,
                )
            else:
                command_parser.add_argument(
                    f"--{name}",
                    dest=name,
                    type=_argparse_type(field.type_),
                    default=field.default,
                    help=field.field_info.description,
                )

    return parser
def get_invocation_command(invocation) -> str:
    """Serialize an invocation back into a CLI command string.

    The command starts with ``invocation.type``, followed by
    ``--<field> <value>`` for every field whose value differs from its
    default. ``id``/``type`` are never emitted as options, and
    ImageField-typed fields are skipped.

    String values are quoted with ``shlex.quote`` so the result can be
    tokenized again by ``shlex.split`` (which the CLI uses on input), even
    when a value contains spaces or quote characters.

    Links are not serialized yet.
    """
    type_hints = get_type_hints(type(invocation))
    command = [invocation.type]

    for name, field in invocation.__fields__.items():
        if name in ("id", "type"):
            continue
        # TODO: add links

        type_hint = type_hints.get(name)
        # Skip image fields when serializing command
        if type_hint is ImageField or ImageField in get_args(type_hint):
            continue

        field_value = getattr(invocation, name)
        if field_value == field.default:
            continue

        if type_hint is str or str in get_args(type_hint):
            command.append(f"--{name} {shlex.quote(str(field_value))}")
        else:
            command.append(f"--{name} {field_value}")

    return " ".join(command)
def get_graph_execution_history(
    graph_execution_state: GraphExecutionState,
) -> Iterable[str]:
    """Lazily yield executed node ids, most recent first.

    Walks ``executed_history`` in reverse and keeps only the ids that are
    still present in ``graph.nodes``. The membership test is done per item
    as the result is consumed.
    """

    def _still_in_graph(node_id) -> bool:
        return node_id in graph_execution_state.graph.nodes

    return filter(_still_in_graph, reversed(graph_execution_state.executed_history))
def generate_matching_edges(
    a: BaseInvocation, b: BaseInvocation
) -> list[tuple[EdgeConnection, EdgeConnection]]:
    """Build an edge for every field name shared by ``a``'s output and ``b``'s inputs.

    Field names come from the type hints of ``a``'s output class (via
    ``get_output_type()``) and of ``b``'s class. ``id`` and ``type`` are
    never connected.

    Returns:
        A list of ``(EdgeConnection on a, EdgeConnection on b)`` pairs that
        use the same field name on both sides.
    """
    source_hints = get_type_hints(type(a).get_output_type())
    target_hints = get_type_hints(type(b))
    shared_names = set(source_hints).intersection(target_hints) - {"type", "id"}

    return [
        (
            EdgeConnection(node_id=a.id, field=shared),
            EdgeConnection(node_id=b.id, field=shared),
        )
        for shared in shared_names
    ]
def _explicit_option_names(tokens: Iterable[str]) -> set[str]:
    """Return the names of ``--option`` flags written literally in ``tokens``.

    Both ``--name value`` and ``--name=value`` forms are recognised. Options
    typed as an abbreviated prefix (which argparse also accepts) are not
    detected, so stored defaults still apply to them.
    """
    return {
        token[2:].split("=", 1)[0]
        for token in tokens
        if token.startswith("--") and len(token) > 2
    }


def _resolve_history_node_id(node_id: str, history: list[str]) -> str:
    """Resolve a negative, history-relative node reference to a real node id.

    ``history`` is ordered most recent first, so ``-1`` maps to
    ``history[0]``, ``-2`` to ``history[1]``, and so on. Anything that is not
    a negative integer within range is returned unchanged.
    """
    try:
        offset = int(node_id)
    except ValueError:
        return node_id
    if offset < 0 and -offset <= len(history):
        return history[-offset - 1]
    return node_id


def invoke_cli():
    """Run the interactive InvokeAI command-line loop.

    Reads commands from stdin until ``exit``/``q``, Ctrl-C, or end of input.
    A line may chain several commands with ``|``. Each invocation is
    connected, through same-named fields, to the previous command on the line
    or to the most recently executed node.

    The parsed invocations are added to the current session and executed.
    The loop polls the session until it completes, and starts a new session
    if any node reported an error. The invoker is always stopped on exit.
    """
    app_args = Args()
    config = app_args.parse_args()
    generate = get_generate(app_args, config)

    # NOTE: load model on first use, uncomment to load at startup
    # TODO: Make this a config option?
    # generate.load_model()

    events = EventServiceBase()

    output_folder = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "../../../outputs")
    )

    # TODO: build a file/path manager?
    db_location = os.path.join(output_folder, "invokeai.db")

    services = InvocationServices(
        generate=generate,
        events=events,
        images=DiskImageStorage(output_folder),
        queue=MemoryInvocationQueue(),
        graph_execution_manager=SqliteItemStorage[GraphExecutionState](
            filename=db_location, table_name="graph_executions"
        ),
        processor=DefaultInvocationProcessor(),
    )

    invoker = Invoker(services)
    session: GraphExecutionState = invoker.create_execution_state()

    parser = get_invocation_parser()

    # Uncomment to print out previous sessions at startup
    # print(services.session_manager.list())

    # Defaults storage: input name -> raw (string) value set via "default"
    defaults: Dict[str, Any] = dict()

    try:
        while True:
            try:
                cmd_input = input("> ")
            except (KeyboardInterrupt, EOFError):
                # Ctrl-C, or end of input (Ctrl-D / closed stdin), exits
                break

            if cmd_input in ["exit", "q"]:
                break

            if cmd_input in ["--help", "help", "h", "?"]:
                parser.print_help()
                continue

            try:
                # Refresh the state of the session
                session = invoker.services.graph_execution_manager.get(session.id)
                # Most recently executed node first
                history = list(get_graph_execution_history(session))

                # Split the command for piping
                cmds = cmd_input.split("|")
                start_id = len(history)
                current_id = start_id
                new_invocations = list()
                for cmd in cmds:
                    if cmd is None or cmd.strip() == "":
                        raise InvalidArgs("Empty command")

                    # Parse args to create invocation
                    tokens = shlex.split(cmd.strip())
                    args = vars(parser.parse_args(tokens))

                    # Check for special commands
                    # TODO: These might be better as Pydantic models, similar to the invocations
                    if args["type"] == "history":
                        history_count = args["count"] or 5
                        # history is newest-first: show the most recent
                        # entries, printed oldest to newest.
                        for entry_id in reversed(history[: max(history_count, 0)]):
                            # NOTE(review): this treats get_node() as returning a
                            # wrapper with `.invocation`, while the piping code
                            # below passes get_node() straight to
                            # generate_matching_edges; confirm which is right.
                            entry = session.graph.get_node(entry_id)
                            print(f"{entry_id}: {get_invocation_command(entry.invocation)}")
                        continue

                    if args["type"] == "reset_default":
                        defaults.pop(args["input"], None)
                        continue

                    if args["type"] == "default":
                        defaults[args["input"]] = args["value"]
                        continue

                    # Apply stored defaults, but never override a value the
                    # user passed explicitly in this command.
                    explicit_options = _explicit_option_names(tokens)
                    for field_name, field_default in defaults.items():
                        if field_name in args and field_name not in explicit_options:
                            args[field_name] = field_default

                    # Parse invocation
                    args["id"] = current_id
                    command = InvocationCommand(invocation=args)

                    # Pipe previous command output (if there was a previous command)
                    edges = []
                    if len(history) > 0 or current_id != start_id:
                        from_id = (
                            history[0] if current_id == start_id else str(current_id - 1)
                        )
                        from_node = (
                            next(filter(lambda n: n[0].id == from_id, new_invocations))[0]
                            if current_id != start_id
                            else session.graph.get_node(from_id)
                        )
                        edges.extend(
                            generate_matching_edges(from_node, command.invocation)
                        )

                    # Parse provided links
                    if args.get("link_node"):
                        for link in args["link_node"]:
                            link_node = session.graph.get_node(
                                _resolve_history_node_id(link, history)
                            )
                            edges.extend(
                                generate_matching_edges(link_node, command.invocation)
                            )

                    if args.get("link"):
                        # Documented format: dest_field source_node source_field
                        for dest_field, source_node, source_field in args["link"]:
                            edges.append(
                                (
                                    EdgeConnection(
                                        node_id=_resolve_history_node_id(
                                            source_node, history
                                        ),
                                        field=source_field,
                                    ),
                                    EdgeConnection(
                                        node_id=command.invocation.id, field=dest_field
                                    ),
                                )
                            )

                    new_invocations.append((command.invocation, edges))
                    current_id = current_id + 1

                # Command line was parsed successfully
                # Add the invocations to the session
                for new_invocation, new_edges in new_invocations:
                    session.add_node(new_invocation)
                    for edge in new_edges:
                        session.add_edge(edge)

                # Execute all available invocations
                invoker.invoke(session, invoke_all=True)
                while not session.is_complete():
                    # Wait some time
                    session = invoker.services.graph_execution_manager.get(session.id)
                    time.sleep(0.1)

                # Print any errors
                if session.has_error():
                    for n in session.errors:
                        print(
                            f"Error in node {n} (source node {session.prepared_source_mapping[n]}): {session.errors[n]}"
                        )

                    # Start a new session
                    print("Creating a new session")
                    session = invoker.create_execution_state()

            except InvalidArgs:
                print('Invalid command, use "help" to list commands')
                continue

            except SystemExit:
                # Raised by subparser errors / help (argparse default exit)
                continue
    finally:
        # Always stop the invoker, even if the loop exits via an exception
        invoker.stop()
# Run the interactive CLI when this module is executed directly.
if __name__ == "__main__":
    invoke_cli()