[cli] Update CLI to define commands as Pydantic objects

This commit is contained in:
Kyle Schouviller 2023-03-04 14:46:02 -08:00
parent bdc7b8b75a
commit ebc4b52f41
3 changed files with 254 additions and 162 deletions

View File

View File

@ -0,0 +1,202 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from abc import ABC, abstractmethod
import argparse
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
from pydantic import BaseModel, Field
from ..invocations.image import ImageField
from ..services.graph import GraphExecutionState
from ..services.invoker import Invoker
def add_parsers(
subparsers,
commands: list[type],
command_field: str = "type",
exclude_fields: list[str] = ["id", "type"],
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
):
"""Adds parsers for each command to the subparsers"""
# Create subparsers for each command
for command in commands:
hints = get_type_hints(command)
cmd_name = get_args(hints[command_field])[0]
command_parser = subparsers.add_parser(cmd_name, help=command.__doc__)
if add_arguments is not None:
add_arguments(command_parser)
# Convert all fields to arguments
fields = command.__fields__ # type: ignore
for name, field in fields.items():
if name in exclude_fields:
continue
if get_origin(field.type_) == Literal:
allowed_values = get_args(field.type_)
allowed_types = set()
for val in allowed_values:
allowed_types.add(type(val))
allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
command_parser.add_argument(
f"--{name}",
dest=name,
type=field_type,
default=field.default,
choices=allowed_values,
help=field.field_info.description,
)
else:
command_parser.add_argument(
f"--{name}",
dest=name,
type=field.type_,
default=field.default,
help=field.field_info.description,
)
class CliContext:
invoker: Invoker
session: GraphExecutionState
parser: argparse.ArgumentParser
defaults: dict[str, Any]
def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
self.invoker = invoker
self.session = session
self.parser = parser
self.defaults = dict()
def get_session(self):
self.session = self.invoker.services.graph_execution_manager.get(self.session.id)
return self.session
class ExitCli(Exception):
"""Exception to exit the CLI"""
pass
class BaseCommand(ABC, BaseModel):
"""A CLI command"""
# All commands must include a type name like this:
# type: Literal['your_command_name'] = 'your_command_name'
@classmethod
def get_all_subclasses(cls):
subclasses = []
toprocess = [cls]
while len(toprocess) > 0:
next = toprocess.pop(0)
next_subclasses = next.__subclasses__()
subclasses.extend(next_subclasses)
toprocess.extend(next_subclasses)
return subclasses
@classmethod
def get_commands(cls):
return tuple(BaseCommand.get_all_subclasses())
@classmethod
def get_commands_map(cls):
# Get the type strings out of the literals and into a dictionary
return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseCommand.get_all_subclasses()))
@abstractmethod
def run(self, context: CliContext) -> None:
"""Run the command. Raise ExitCli to exit."""
pass
class ExitCommand(BaseCommand):
"""Exits the CLI"""
type: Literal['exit'] = 'exit'
def run(self, context: CliContext) -> None:
raise ExitCli()
class HelpCommand(BaseCommand):
"""Shows help"""
type: Literal['help'] = 'help'
def run(self, context: CliContext) -> None:
context.parser.print_help()
def get_graph_execution_history(
graph_execution_state: GraphExecutionState,
) -> Iterable[str]:
"""Gets the history of fully-executed invocations for a graph execution"""
return (
n
for n in reversed(graph_execution_state.executed_history)
if n in graph_execution_state.graph.nodes
)
def get_invocation_command(invocation) -> str:
fields = invocation.__fields__.items()
type_hints = get_type_hints(type(invocation))
command = [invocation.type]
for name, field in fields:
if name in ["id", "type"]:
continue
# TODO: add links
# Skip image fields when serializing command
type_hint = type_hints.get(name) or None
if type_hint is ImageField or ImageField in get_args(type_hint):
continue
field_value = getattr(invocation, name)
field_default = field.default
if field_value != field_default:
if type_hint is str or str in get_args(type_hint):
command.append(f'--{name} "{field_value}"')
else:
command.append(f"--{name} {field_value}")
return " ".join(command)
class HistoryCommand(BaseCommand):
"""Shows the invocation history"""
type: Literal['history'] = 'history'
# Inputs
# fmt: off
count: int = Field(default=5, gt=0, description="The number of history entries to show")
# fmt: on
def run(self, context: CliContext) -> None:
history = list(get_graph_execution_history(context.get_session()))
for i in range(min(self.count, len(history))):
entry_id = history[-1 - i]
entry = context.get_session().graph.get_node(entry_id)
print(f"{entry_id}: {get_invocation_command(entry)}")
class SetDefaultCommand(BaseCommand):
"""Sets a default value for a field"""
type: Literal['default'] = 'default'
# Inputs
# fmt: off
field: str = Field(description="The field to set the default for")
value: str = Field(description="The value to set the default to, or None to clear the default")
# fmt: on
def run(self, context: CliContext) -> None:
if self.value is None:
if self.field in context.defaults:
del context.defaults[self.field]
else:
context.defaults[self.field] = self.value

View File

@ -5,13 +5,7 @@ import os
import shlex
import time
from typing import (
Any,
Dict,
Iterable,
Literal,
Union,
get_args,
get_origin,
get_type_hints,
)
@ -19,9 +13,9 @@ from pydantic import BaseModel
from pydantic.fields import Field
from ..backend import Args
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
from .invocations.image import ImageField
from .services.events import EventServiceBase
from .services.generate_initializer import get_generate
from .services.graph import EdgeConnection, GraphExecutionState
@ -33,15 +27,33 @@ from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage
class InvocationCommand(BaseModel):
invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
class CliCommand(BaseModel):
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
class InvalidArgs(Exception):
pass
def get_invocation_parser() -> argparse.ArgumentParser:
def add_invocation_args(command_parser):
# Add linking capability
command_parser.add_argument(
"--link",
"-l",
action="append",
nargs=3,
help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)",
)
command_parser.add_argument(
"--link_node",
"-ln",
action="append",
help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)",
)
def get_command_parser() -> argparse.ArgumentParser:
# Create invocation parser
parser = argparse.ArgumentParser()
@ -49,129 +61,19 @@ def get_invocation_parser() -> argparse.ArgumentParser:
raise InvalidArgs
parser.exit = exit
subparsers = parser.add_subparsers(dest="type")
invocation_parsers = dict()
# Add history parser
history_parser = subparsers.add_parser(
"history", help="Shows the invocation history"
)
history_parser.add_argument(
"count",
nargs="?",
default=5,
type=int,
help="The number of history entries to show",
)
# Add default parser
default_parser = subparsers.add_parser(
"default", help="Define a default value for all inputs with a specified name"
)
default_parser.add_argument("input", type=str, help="The input field")
default_parser.add_argument("value", help="The default value")
default_parser = subparsers.add_parser(
"reset_default", help="Resets a default value"
)
default_parser.add_argument("input", type=str, help="The input field")
# Create subparsers for each invocation
invocations = BaseInvocation.get_all_subclasses()
for invocation in invocations:
hints = get_type_hints(invocation)
cmd_name = get_args(hints["type"])[0]
command_parser = subparsers.add_parser(cmd_name, help=invocation.__doc__)
invocation_parsers[cmd_name] = command_parser
add_parsers(subparsers, invocations, add_arguments=add_invocation_args)
# Add linking capability
command_parser.add_argument(
"--link",
"-l",
action="append",
nargs=3,
help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)",
)
command_parser.add_argument(
"--link_node",
"-ln",
action="append",
help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)",
)
# Convert all fields to arguments
fields = invocation.__fields__
for name, field in fields.items():
if name in ["id", "type"]:
continue
if get_origin(field.type_) == Literal:
allowed_values = get_args(field.type_)
allowed_types = set()
for val in allowed_values:
allowed_types.add(type(val))
allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
command_parser.add_argument(
f"--{name}",
dest=name,
type=field_type,
default=field.default,
choices=allowed_values,
help=field.field_info.description,
)
else:
command_parser.add_argument(
f"--{name}",
dest=name,
type=field.type_,
default=field.default,
help=field.field_info.description,
)
# Create subparsers for each command
commands = BaseCommand.get_all_subclasses()
add_parsers(subparsers, commands, exclude_fields=["type"])
return parser
def get_invocation_command(invocation) -> str:
fields = invocation.__fields__.items()
type_hints = get_type_hints(type(invocation))
command = [invocation.type]
for name, field in fields:
if name in ["id", "type"]:
continue
# TODO: add links
# Skip image fields when serializing command
type_hint = type_hints.get(name) or None
if type_hint is ImageField or ImageField in get_args(type_hint):
continue
field_value = getattr(invocation, name)
field_default = field.default
if field_value != field_default:
if type_hint is str or str in get_args(type_hint):
command.append(f'--{name} "{field_value}"')
else:
command.append(f"--{name} {field_value}")
return " ".join(command)
def get_graph_execution_history(
graph_execution_state: GraphExecutionState,
) -> Iterable[str]:
"""Gets the history of fully-executed invocations for a graph execution"""
return (
n
for n in reversed(graph_execution_state.executed_history)
if n in graph_execution_state.graph.nodes
)
def generate_matching_edges(
a: BaseInvocation, b: BaseInvocation
) -> list[tuple[EdgeConnection, EdgeConnection]]:
@ -233,13 +135,12 @@ def invoke_cli():
invoker = Invoker(services)
session: GraphExecutionState = invoker.create_execution_state()
parser = get_invocation_parser()
parser = get_command_parser()
# Uncomment to print out previous sessions at startup
# print(services.session_manager.list())
# Defaults storage
defaults: Dict[str, Any] = dict()
context = CliContext(invoker, session, parser)
while True:
try:
@ -248,13 +149,6 @@ def invoke_cli():
# Ctrl-c exits
break
if cmd_input in ["exit", "q"]:
break
if cmd_input in ["--help", "help", "h", "?"]:
parser.print_help()
continue
try:
# Refresh the state of the session
session = invoker.services.graph_execution_manager.get(session.id)
@ -272,35 +166,23 @@ def invoke_cli():
# Parse args to create invocation
args = vars(parser.parse_args(shlex.split(cmd.strip())))
# Check for special commands
# TODO: These might be better as Pydantic models, similar to the invocations
if args["type"] == "history":
history_count = args["count"] or 5
for i in range(min(history_count, len(history))):
entry_id = history[-1 - i]
entry = session.graph.get_node(entry_id)
print(f"{entry_id}: {get_invocation_command(entry.invocation)}")
continue
if args["type"] == "reset_default":
if args["input"] in defaults:
del defaults[args["input"]]
continue
if args["type"] == "default":
field = args["input"]
field_value = args["value"]
defaults[field] = field_value
continue
# Override defaults
for field_name, field_default in defaults.items():
for field_name, field_default in context.defaults.items():
if field_name in args:
args[field_name] = field_default
# Parse invocation
args["id"] = current_id
command = InvocationCommand(invocation=args)
command = CliCommand(command=args)
# Run any CLI commands immediately
# TODO: this won't behave as expected if piping and using e.g. history,
# since invocations are gathered and then run together at the end.
# This is more efficient if the CLI is running against a distributed
# backend, so it's preferable not to change that behavior.
if isinstance(command.command, BaseCommand):
command.command.run(context)
continue
# Pipe previous command output (if there was a previous command)
edges = []
@ -314,7 +196,7 @@ def invoke_cli():
else session.graph.get_node(from_id)
)
matching_edges = generate_matching_edges(
from_node, command.invocation
from_node, command.command
)
edges.extend(matching_edges)
@ -323,22 +205,25 @@ def invoke_cli():
for link in args["link_node"]:
link_node = session.graph.get_node(link)
matching_edges = generate_matching_edges(
link_node, command.invocation
link_node, command.command
)
matching_destinations = [e[1] for e in matching_edges]
edges = [e for e in edges if e[1] not in matching_destinations]
edges.extend(matching_edges)
if "link" in args and args["link"]:
for link in args["link"]:
edges = [e for e in edges if e[1].node_id != command.command.id and e[1].field != link[2]]
edges.append(
(
EdgeConnection(node_id=link[1], field=link[0]),
EdgeConnection(
node_id=command.invocation.id, field=link[2]
node_id=command.command.id, field=link[2]
),
)
)
new_invocations.append((command.invocation, edges))
new_invocations.append((command.command, edges))
current_id = current_id + 1
@ -347,13 +232,14 @@ def invoke_cli():
for invocation in new_invocations:
session.add_node(invocation[0])
for edge in invocation[1]:
print(edge)
session.add_edge(edge)
# Execute all available invocations
invoker.invoke(session, invoke_all=True)
while not session.is_complete():
# Wait some time
session = invoker.services.graph_execution_manager.get(session.id)
session = context.get_session()
time.sleep(0.1)
# Print any errors
@ -366,11 +252,15 @@ def invoke_cli():
# Start a new session
print("Creating a new session")
session = invoker.create_execution_state()
context.session = session
except InvalidArgs:
print('Invalid command, use "help" to list commands')
continue
except ExitCli:
break
except SystemExit:
continue