mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[cli] Update CLI to define commands as Pydantic objects
This commit is contained in:
parent
bdc7b8b75a
commit
ebc4b52f41
0
invokeai/app/cli/__init__.py
Normal file
0
invokeai/app/cli/__init__.py
Normal file
202
invokeai/app/cli/commands.py
Normal file
202
invokeai/app/cli/commands.py
Normal file
@ -0,0 +1,202 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import argparse
|
||||
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..invocations.image import ImageField
|
||||
from ..services.graph import GraphExecutionState
|
||||
from ..services.invoker import Invoker
|
||||
|
||||
|
||||
def add_parsers(
|
||||
subparsers,
|
||||
commands: list[type],
|
||||
command_field: str = "type",
|
||||
exclude_fields: list[str] = ["id", "type"],
|
||||
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
|
||||
):
|
||||
"""Adds parsers for each command to the subparsers"""
|
||||
|
||||
# Create subparsers for each command
|
||||
for command in commands:
|
||||
hints = get_type_hints(command)
|
||||
cmd_name = get_args(hints[command_field])[0]
|
||||
command_parser = subparsers.add_parser(cmd_name, help=command.__doc__)
|
||||
|
||||
if add_arguments is not None:
|
||||
add_arguments(command_parser)
|
||||
|
||||
# Convert all fields to arguments
|
||||
fields = command.__fields__ # type: ignore
|
||||
for name, field in fields.items():
|
||||
if name in exclude_fields:
|
||||
continue
|
||||
|
||||
if get_origin(field.type_) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=field.default,
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=field.default,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
|
||||
class CliContext:
|
||||
invoker: Invoker
|
||||
session: GraphExecutionState
|
||||
parser: argparse.ArgumentParser
|
||||
defaults: dict[str, Any]
|
||||
|
||||
def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
|
||||
self.invoker = invoker
|
||||
self.session = session
|
||||
self.parser = parser
|
||||
self.defaults = dict()
|
||||
|
||||
def get_session(self):
|
||||
self.session = self.invoker.services.graph_execution_manager.get(self.session.id)
|
||||
return self.session
|
||||
|
||||
|
||||
class ExitCli(Exception):
|
||||
"""Exception to exit the CLI"""
|
||||
pass
|
||||
|
||||
|
||||
class BaseCommand(ABC, BaseModel):
|
||||
"""A CLI command"""
|
||||
|
||||
# All commands must include a type name like this:
|
||||
# type: Literal['your_command_name'] = 'your_command_name'
|
||||
|
||||
@classmethod
|
||||
def get_all_subclasses(cls):
|
||||
subclasses = []
|
||||
toprocess = [cls]
|
||||
while len(toprocess) > 0:
|
||||
next = toprocess.pop(0)
|
||||
next_subclasses = next.__subclasses__()
|
||||
subclasses.extend(next_subclasses)
|
||||
toprocess.extend(next_subclasses)
|
||||
return subclasses
|
||||
|
||||
@classmethod
|
||||
def get_commands(cls):
|
||||
return tuple(BaseCommand.get_all_subclasses())
|
||||
|
||||
@classmethod
|
||||
def get_commands_map(cls):
|
||||
# Get the type strings out of the literals and into a dictionary
|
||||
return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseCommand.get_all_subclasses()))
|
||||
|
||||
@abstractmethod
|
||||
def run(self, context: CliContext) -> None:
|
||||
"""Run the command. Raise ExitCli to exit."""
|
||||
pass
|
||||
|
||||
|
||||
class ExitCommand(BaseCommand):
|
||||
"""Exits the CLI"""
|
||||
type: Literal['exit'] = 'exit'
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
raise ExitCli()
|
||||
|
||||
|
||||
class HelpCommand(BaseCommand):
|
||||
"""Shows help"""
|
||||
type: Literal['help'] = 'help'
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
context.parser.print_help()
|
||||
|
||||
|
||||
def get_graph_execution_history(
|
||||
graph_execution_state: GraphExecutionState,
|
||||
) -> Iterable[str]:
|
||||
"""Gets the history of fully-executed invocations for a graph execution"""
|
||||
return (
|
||||
n
|
||||
for n in reversed(graph_execution_state.executed_history)
|
||||
if n in graph_execution_state.graph.nodes
|
||||
)
|
||||
|
||||
|
||||
def get_invocation_command(invocation) -> str:
|
||||
fields = invocation.__fields__.items()
|
||||
type_hints = get_type_hints(type(invocation))
|
||||
command = [invocation.type]
|
||||
for name, field in fields:
|
||||
if name in ["id", "type"]:
|
||||
continue
|
||||
|
||||
# TODO: add links
|
||||
|
||||
# Skip image fields when serializing command
|
||||
type_hint = type_hints.get(name) or None
|
||||
if type_hint is ImageField or ImageField in get_args(type_hint):
|
||||
continue
|
||||
|
||||
field_value = getattr(invocation, name)
|
||||
field_default = field.default
|
||||
if field_value != field_default:
|
||||
if type_hint is str or str in get_args(type_hint):
|
||||
command.append(f'--{name} "{field_value}"')
|
||||
else:
|
||||
command.append(f"--{name} {field_value}")
|
||||
|
||||
return " ".join(command)
|
||||
|
||||
|
||||
class HistoryCommand(BaseCommand):
|
||||
"""Shows the invocation history"""
|
||||
type: Literal['history'] = 'history'
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
count: int = Field(default=5, gt=0, description="The number of history entries to show")
|
||||
# fmt: on
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
history = list(get_graph_execution_history(context.get_session()))
|
||||
for i in range(min(self.count, len(history))):
|
||||
entry_id = history[-1 - i]
|
||||
entry = context.get_session().graph.get_node(entry_id)
|
||||
print(f"{entry_id}: {get_invocation_command(entry)}")
|
||||
|
||||
|
||||
class SetDefaultCommand(BaseCommand):
|
||||
"""Sets a default value for a field"""
|
||||
type: Literal['default'] = 'default'
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
field: str = Field(description="The field to set the default for")
|
||||
value: str = Field(description="The value to set the default to, or None to clear the default")
|
||||
# fmt: on
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
if self.value is None:
|
||||
if self.field in context.defaults:
|
||||
del context.defaults[self.field]
|
||||
else:
|
||||
context.defaults[self.field] = self.value
|
@ -5,13 +5,7 @@ import os
|
||||
import shlex
|
||||
import time
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterable,
|
||||
Literal,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
@ -19,9 +13,9 @@ from pydantic import BaseModel
|
||||
from pydantic.fields import Field
|
||||
|
||||
from ..backend import Args
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .invocations.image import ImageField
|
||||
from .services.events import EventServiceBase
|
||||
from .services.generate_initializer import get_generate
|
||||
from .services.graph import EdgeConnection, GraphExecutionState
|
||||
@ -33,15 +27,33 @@ from .services.processor import DefaultInvocationProcessor
|
||||
from .services.sqlite import SqliteItemStorage
|
||||
|
||||
|
||||
class InvocationCommand(BaseModel):
|
||||
invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
|
||||
class CliCommand(BaseModel):
|
||||
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
|
||||
|
||||
|
||||
class InvalidArgs(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def get_invocation_parser() -> argparse.ArgumentParser:
|
||||
def add_invocation_args(command_parser):
|
||||
# Add linking capability
|
||||
command_parser.add_argument(
|
||||
"--link",
|
||||
"-l",
|
||||
action="append",
|
||||
nargs=3,
|
||||
help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)",
|
||||
)
|
||||
|
||||
command_parser.add_argument(
|
||||
"--link_node",
|
||||
"-ln",
|
||||
action="append",
|
||||
help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)",
|
||||
)
|
||||
|
||||
|
||||
def get_command_parser() -> argparse.ArgumentParser:
|
||||
# Create invocation parser
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@ -49,129 +61,19 @@ def get_invocation_parser() -> argparse.ArgumentParser:
|
||||
raise InvalidArgs
|
||||
|
||||
parser.exit = exit
|
||||
|
||||
subparsers = parser.add_subparsers(dest="type")
|
||||
invocation_parsers = dict()
|
||||
|
||||
# Add history parser
|
||||
history_parser = subparsers.add_parser(
|
||||
"history", help="Shows the invocation history"
|
||||
)
|
||||
history_parser.add_argument(
|
||||
"count",
|
||||
nargs="?",
|
||||
default=5,
|
||||
type=int,
|
||||
help="The number of history entries to show",
|
||||
)
|
||||
|
||||
# Add default parser
|
||||
default_parser = subparsers.add_parser(
|
||||
"default", help="Define a default value for all inputs with a specified name"
|
||||
)
|
||||
default_parser.add_argument("input", type=str, help="The input field")
|
||||
default_parser.add_argument("value", help="The default value")
|
||||
|
||||
default_parser = subparsers.add_parser(
|
||||
"reset_default", help="Resets a default value"
|
||||
)
|
||||
default_parser.add_argument("input", type=str, help="The input field")
|
||||
|
||||
# Create subparsers for each invocation
|
||||
invocations = BaseInvocation.get_all_subclasses()
|
||||
for invocation in invocations:
|
||||
hints = get_type_hints(invocation)
|
||||
cmd_name = get_args(hints["type"])[0]
|
||||
command_parser = subparsers.add_parser(cmd_name, help=invocation.__doc__)
|
||||
invocation_parsers[cmd_name] = command_parser
|
||||
add_parsers(subparsers, invocations, add_arguments=add_invocation_args)
|
||||
|
||||
# Add linking capability
|
||||
command_parser.add_argument(
|
||||
"--link",
|
||||
"-l",
|
||||
action="append",
|
||||
nargs=3,
|
||||
help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)",
|
||||
)
|
||||
|
||||
command_parser.add_argument(
|
||||
"--link_node",
|
||||
"-ln",
|
||||
action="append",
|
||||
help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)",
|
||||
)
|
||||
|
||||
# Convert all fields to arguments
|
||||
fields = invocation.__fields__
|
||||
for name, field in fields.items():
|
||||
if name in ["id", "type"]:
|
||||
continue
|
||||
|
||||
if get_origin(field.type_) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=field.default,
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=field.default,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
# Create subparsers for each command
|
||||
commands = BaseCommand.get_all_subclasses()
|
||||
add_parsers(subparsers, commands, exclude_fields=["type"])
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_invocation_command(invocation) -> str:
|
||||
fields = invocation.__fields__.items()
|
||||
type_hints = get_type_hints(type(invocation))
|
||||
command = [invocation.type]
|
||||
for name, field in fields:
|
||||
if name in ["id", "type"]:
|
||||
continue
|
||||
|
||||
# TODO: add links
|
||||
|
||||
# Skip image fields when serializing command
|
||||
type_hint = type_hints.get(name) or None
|
||||
if type_hint is ImageField or ImageField in get_args(type_hint):
|
||||
continue
|
||||
|
||||
field_value = getattr(invocation, name)
|
||||
field_default = field.default
|
||||
if field_value != field_default:
|
||||
if type_hint is str or str in get_args(type_hint):
|
||||
command.append(f'--{name} "{field_value}"')
|
||||
else:
|
||||
command.append(f"--{name} {field_value}")
|
||||
|
||||
return " ".join(command)
|
||||
|
||||
|
||||
def get_graph_execution_history(
|
||||
graph_execution_state: GraphExecutionState,
|
||||
) -> Iterable[str]:
|
||||
"""Gets the history of fully-executed invocations for a graph execution"""
|
||||
return (
|
||||
n
|
||||
for n in reversed(graph_execution_state.executed_history)
|
||||
if n in graph_execution_state.graph.nodes
|
||||
)
|
||||
|
||||
|
||||
def generate_matching_edges(
|
||||
a: BaseInvocation, b: BaseInvocation
|
||||
) -> list[tuple[EdgeConnection, EdgeConnection]]:
|
||||
@ -233,13 +135,12 @@ def invoke_cli():
|
||||
invoker = Invoker(services)
|
||||
session: GraphExecutionState = invoker.create_execution_state()
|
||||
|
||||
parser = get_invocation_parser()
|
||||
parser = get_command_parser()
|
||||
|
||||
# Uncomment to print out previous sessions at startup
|
||||
# print(services.session_manager.list())
|
||||
|
||||
# Defaults storage
|
||||
defaults: Dict[str, Any] = dict()
|
||||
context = CliContext(invoker, session, parser)
|
||||
|
||||
while True:
|
||||
try:
|
||||
@ -248,13 +149,6 @@ def invoke_cli():
|
||||
# Ctrl-c exits
|
||||
break
|
||||
|
||||
if cmd_input in ["exit", "q"]:
|
||||
break
|
||||
|
||||
if cmd_input in ["--help", "help", "h", "?"]:
|
||||
parser.print_help()
|
||||
continue
|
||||
|
||||
try:
|
||||
# Refresh the state of the session
|
||||
session = invoker.services.graph_execution_manager.get(session.id)
|
||||
@ -272,35 +166,23 @@ def invoke_cli():
|
||||
# Parse args to create invocation
|
||||
args = vars(parser.parse_args(shlex.split(cmd.strip())))
|
||||
|
||||
# Check for special commands
|
||||
# TODO: These might be better as Pydantic models, similar to the invocations
|
||||
if args["type"] == "history":
|
||||
history_count = args["count"] or 5
|
||||
for i in range(min(history_count, len(history))):
|
||||
entry_id = history[-1 - i]
|
||||
entry = session.graph.get_node(entry_id)
|
||||
print(f"{entry_id}: {get_invocation_command(entry.invocation)}")
|
||||
continue
|
||||
|
||||
if args["type"] == "reset_default":
|
||||
if args["input"] in defaults:
|
||||
del defaults[args["input"]]
|
||||
continue
|
||||
|
||||
if args["type"] == "default":
|
||||
field = args["input"]
|
||||
field_value = args["value"]
|
||||
defaults[field] = field_value
|
||||
continue
|
||||
|
||||
# Override defaults
|
||||
for field_name, field_default in defaults.items():
|
||||
for field_name, field_default in context.defaults.items():
|
||||
if field_name in args:
|
||||
args[field_name] = field_default
|
||||
|
||||
# Parse invocation
|
||||
args["id"] = current_id
|
||||
command = InvocationCommand(invocation=args)
|
||||
command = CliCommand(command=args)
|
||||
|
||||
# Run any CLI commands immediately
|
||||
# TODO: this won't behave as expected if piping and using e.g. history,
|
||||
# since invocations are gathered and then run together at the end.
|
||||
# This is more efficient if the CLI is running against a distributed
|
||||
# backend, so it's preferable not to change that behavior.
|
||||
if isinstance(command.command, BaseCommand):
|
||||
command.command.run(context)
|
||||
continue
|
||||
|
||||
# Pipe previous command output (if there was a previous command)
|
||||
edges = []
|
||||
@ -314,7 +196,7 @@ def invoke_cli():
|
||||
else session.graph.get_node(from_id)
|
||||
)
|
||||
matching_edges = generate_matching_edges(
|
||||
from_node, command.invocation
|
||||
from_node, command.command
|
||||
)
|
||||
edges.extend(matching_edges)
|
||||
|
||||
@ -323,22 +205,25 @@ def invoke_cli():
|
||||
for link in args["link_node"]:
|
||||
link_node = session.graph.get_node(link)
|
||||
matching_edges = generate_matching_edges(
|
||||
link_node, command.invocation
|
||||
link_node, command.command
|
||||
)
|
||||
matching_destinations = [e[1] for e in matching_edges]
|
||||
edges = [e for e in edges if e[1] not in matching_destinations]
|
||||
edges.extend(matching_edges)
|
||||
|
||||
if "link" in args and args["link"]:
|
||||
for link in args["link"]:
|
||||
edges = [e for e in edges if e[1].node_id != command.command.id and e[1].field != link[2]]
|
||||
edges.append(
|
||||
(
|
||||
EdgeConnection(node_id=link[1], field=link[0]),
|
||||
EdgeConnection(
|
||||
node_id=command.invocation.id, field=link[2]
|
||||
node_id=command.command.id, field=link[2]
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
new_invocations.append((command.invocation, edges))
|
||||
new_invocations.append((command.command, edges))
|
||||
|
||||
current_id = current_id + 1
|
||||
|
||||
@ -347,13 +232,14 @@ def invoke_cli():
|
||||
for invocation in new_invocations:
|
||||
session.add_node(invocation[0])
|
||||
for edge in invocation[1]:
|
||||
print(edge)
|
||||
session.add_edge(edge)
|
||||
|
||||
# Execute all available invocations
|
||||
invoker.invoke(session, invoke_all=True)
|
||||
while not session.is_complete():
|
||||
# Wait some time
|
||||
session = invoker.services.graph_execution_manager.get(session.id)
|
||||
session = context.get_session()
|
||||
time.sleep(0.1)
|
||||
|
||||
# Print any errors
|
||||
@ -366,11 +252,15 @@ def invoke_cli():
|
||||
# Start a new session
|
||||
print("Creating a new session")
|
||||
session = invoker.create_execution_state()
|
||||
context.session = session
|
||||
|
||||
except InvalidArgs:
|
||||
print('Invalid command, use "help" to list commands')
|
||||
continue
|
||||
|
||||
except ExitCli:
|
||||
break
|
||||
|
||||
except SystemExit:
|
||||
continue
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user