[cli] Update CLI to define commands as Pydantic objects

This commit is contained in:
Kyle Schouviller 2023-03-04 14:46:02 -08:00
parent bdc7b8b75a
commit ebc4b52f41
3 changed files with 254 additions and 162 deletions

View File

View File

@ -0,0 +1,202 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from abc import ABC, abstractmethod
import argparse
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
from pydantic import BaseModel, Field
from ..invocations.image import ImageField
from ..services.graph import GraphExecutionState
from ..services.invoker import Invoker
def add_parsers(
subparsers,
commands: list[type],
command_field: str = "type",
exclude_fields: list[str] = ["id", "type"],
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
):
"""Adds parsers for each command to the subparsers"""
# Create subparsers for each command
for command in commands:
hints = get_type_hints(command)
cmd_name = get_args(hints[command_field])[0]
command_parser = subparsers.add_parser(cmd_name, help=command.__doc__)
if add_arguments is not None:
add_arguments(command_parser)
# Convert all fields to arguments
fields = command.__fields__ # type: ignore
for name, field in fields.items():
if name in exclude_fields:
continue
if get_origin(field.type_) == Literal:
allowed_values = get_args(field.type_)
allowed_types = set()
for val in allowed_values:
allowed_types.add(type(val))
allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
command_parser.add_argument(
f"--{name}",
dest=name,
type=field_type,
default=field.default,
choices=allowed_values,
help=field.field_info.description,
)
else:
command_parser.add_argument(
f"--{name}",
dest=name,
type=field.type_,
default=field.default,
help=field.field_info.description,
)
class CliContext:
invoker: Invoker
session: GraphExecutionState
parser: argparse.ArgumentParser
defaults: dict[str, Any]
def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
self.invoker = invoker
self.session = session
self.parser = parser
self.defaults = dict()
def get_session(self):
self.session = self.invoker.services.graph_execution_manager.get(self.session.id)
return self.session
class ExitCli(Exception):
"""Exception to exit the CLI"""
pass
class BaseCommand(ABC, BaseModel):
"""A CLI command"""
# All commands must include a type name like this:
# type: Literal['your_command_name'] = 'your_command_name'
@classmethod
def get_all_subclasses(cls):
subclasses = []
toprocess = [cls]
while len(toprocess) > 0:
next = toprocess.pop(0)
next_subclasses = next.__subclasses__()
subclasses.extend(next_subclasses)
toprocess.extend(next_subclasses)
return subclasses
@classmethod
def get_commands(cls):
return tuple(BaseCommand.get_all_subclasses())
@classmethod
def get_commands_map(cls):
# Get the type strings out of the literals and into a dictionary
return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseCommand.get_all_subclasses()))
@abstractmethod
def run(self, context: CliContext) -> None:
"""Run the command. Raise ExitCli to exit."""
pass
class ExitCommand(BaseCommand):
"""Exits the CLI"""
type: Literal['exit'] = 'exit'
def run(self, context: CliContext) -> None:
raise ExitCli()
class HelpCommand(BaseCommand):
"""Shows help"""
type: Literal['help'] = 'help'
def run(self, context: CliContext) -> None:
context.parser.print_help()
def get_graph_execution_history(
graph_execution_state: GraphExecutionState,
) -> Iterable[str]:
"""Gets the history of fully-executed invocations for a graph execution"""
return (
n
for n in reversed(graph_execution_state.executed_history)
if n in graph_execution_state.graph.nodes
)
def get_invocation_command(invocation) -> str:
fields = invocation.__fields__.items()
type_hints = get_type_hints(type(invocation))
command = [invocation.type]
for name, field in fields:
if name in ["id", "type"]:
continue
# TODO: add links
# Skip image fields when serializing command
type_hint = type_hints.get(name) or None
if type_hint is ImageField or ImageField in get_args(type_hint):
continue
field_value = getattr(invocation, name)
field_default = field.default
if field_value != field_default:
if type_hint is str or str in get_args(type_hint):
command.append(f'--{name} "{field_value}"')
else:
command.append(f"--{name} {field_value}")
return " ".join(command)
class HistoryCommand(BaseCommand):
"""Shows the invocation history"""
type: Literal['history'] = 'history'
# Inputs
# fmt: off
count: int = Field(default=5, gt=0, description="The number of history entries to show")
# fmt: on
def run(self, context: CliContext) -> None:
history = list(get_graph_execution_history(context.get_session()))
for i in range(min(self.count, len(history))):
entry_id = history[-1 - i]
entry = context.get_session().graph.get_node(entry_id)
print(f"{entry_id}: {get_invocation_command(entry)}")
class SetDefaultCommand(BaseCommand):
"""Sets a default value for a field"""
type: Literal['default'] = 'default'
# Inputs
# fmt: off
field: str = Field(description="The field to set the default for")
value: str = Field(description="The value to set the default to, or None to clear the default")
# fmt: on
def run(self, context: CliContext) -> None:
if self.value is None:
if self.field in context.defaults:
del context.defaults[self.field]
else:
context.defaults[self.field] = self.value

View File

@ -5,13 +5,7 @@ import os
import shlex import shlex
import time import time
from typing import ( from typing import (
Any,
Dict,
Iterable,
Literal,
Union, Union,
get_args,
get_origin,
get_type_hints, get_type_hints,
) )
@ -19,9 +13,9 @@ from pydantic import BaseModel
from pydantic.fields import Field from pydantic.fields import Field
from ..backend import Args from ..backend import Args
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
from .invocations import * from .invocations import *
from .invocations.baseinvocation import BaseInvocation from .invocations.baseinvocation import BaseInvocation
from .invocations.image import ImageField
from .services.events import EventServiceBase from .services.events import EventServiceBase
from .services.generate_initializer import get_generate from .services.generate_initializer import get_generate
from .services.graph import EdgeConnection, GraphExecutionState from .services.graph import EdgeConnection, GraphExecutionState
@ -33,58 +27,15 @@ from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage from .services.sqlite import SqliteItemStorage
class InvocationCommand(BaseModel): class CliCommand(BaseModel):
invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
class InvalidArgs(Exception): class InvalidArgs(Exception):
pass pass
def get_invocation_parser() -> argparse.ArgumentParser: def add_invocation_args(command_parser):
# Create invocation parser
parser = argparse.ArgumentParser()
def exit(*args, **kwargs):
raise InvalidArgs
parser.exit = exit
subparsers = parser.add_subparsers(dest="type")
invocation_parsers = dict()
# Add history parser
history_parser = subparsers.add_parser(
"history", help="Shows the invocation history"
)
history_parser.add_argument(
"count",
nargs="?",
default=5,
type=int,
help="The number of history entries to show",
)
# Add default parser
default_parser = subparsers.add_parser(
"default", help="Define a default value for all inputs with a specified name"
)
default_parser.add_argument("input", type=str, help="The input field")
default_parser.add_argument("value", help="The default value")
default_parser = subparsers.add_parser(
"reset_default", help="Resets a default value"
)
default_parser.add_argument("input", type=str, help="The input field")
# Create subparsers for each invocation
invocations = BaseInvocation.get_all_subclasses()
for invocation in invocations:
hints = get_type_hints(invocation)
cmd_name = get_args(hints["type"])[0]
command_parser = subparsers.add_parser(cmd_name, help=invocation.__doc__)
invocation_parsers[cmd_name] = command_parser
# Add linking capability # Add linking capability
command_parser.add_argument( command_parser.add_argument(
"--link", "--link",
@ -101,77 +52,28 @@ def get_invocation_parser() -> argparse.ArgumentParser:
help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)", help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)",
) )
# Convert all fields to arguments
fields = invocation.__fields__
for name, field in fields.items():
if name in ["id", "type"]:
continue
if get_origin(field.type_) == Literal: def get_command_parser() -> argparse.ArgumentParser:
allowed_values = get_args(field.type_) # Create invocation parser
allowed_types = set() parser = argparse.ArgumentParser()
for val in allowed_values:
allowed_types.add(type(val))
allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
command_parser.add_argument( def exit(*args, **kwargs):
f"--{name}", raise InvalidArgs
dest=name,
type=field_type, parser.exit = exit
default=field.default, subparsers = parser.add_subparsers(dest="type")
choices=allowed_values,
help=field.field_info.description, # Create subparsers for each invocation
) invocations = BaseInvocation.get_all_subclasses()
else: add_parsers(subparsers, invocations, add_arguments=add_invocation_args)
command_parser.add_argument(
f"--{name}", # Create subparsers for each command
dest=name, commands = BaseCommand.get_all_subclasses()
type=field.type_, add_parsers(subparsers, commands, exclude_fields=["type"])
default=field.default,
help=field.field_info.description,
)
return parser return parser
def get_invocation_command(invocation) -> str:
fields = invocation.__fields__.items()
type_hints = get_type_hints(type(invocation))
command = [invocation.type]
for name, field in fields:
if name in ["id", "type"]:
continue
# TODO: add links
# Skip image fields when serializing command
type_hint = type_hints.get(name) or None
if type_hint is ImageField or ImageField in get_args(type_hint):
continue
field_value = getattr(invocation, name)
field_default = field.default
if field_value != field_default:
if type_hint is str or str in get_args(type_hint):
command.append(f'--{name} "{field_value}"')
else:
command.append(f"--{name} {field_value}")
return " ".join(command)
def get_graph_execution_history(
graph_execution_state: GraphExecutionState,
) -> Iterable[str]:
"""Gets the history of fully-executed invocations for a graph execution"""
return (
n
for n in reversed(graph_execution_state.executed_history)
if n in graph_execution_state.graph.nodes
)
def generate_matching_edges( def generate_matching_edges(
a: BaseInvocation, b: BaseInvocation a: BaseInvocation, b: BaseInvocation
) -> list[tuple[EdgeConnection, EdgeConnection]]: ) -> list[tuple[EdgeConnection, EdgeConnection]]:
@ -233,13 +135,12 @@ def invoke_cli():
invoker = Invoker(services) invoker = Invoker(services)
session: GraphExecutionState = invoker.create_execution_state() session: GraphExecutionState = invoker.create_execution_state()
parser = get_invocation_parser() parser = get_command_parser()
# Uncomment to print out previous sessions at startup # Uncomment to print out previous sessions at startup
# print(services.session_manager.list()) # print(services.session_manager.list())
# Defaults storage context = CliContext(invoker, session, parser)
defaults: Dict[str, Any] = dict()
while True: while True:
try: try:
@ -248,13 +149,6 @@ def invoke_cli():
# Ctrl-c exits # Ctrl-c exits
break break
if cmd_input in ["exit", "q"]:
break
if cmd_input in ["--help", "help", "h", "?"]:
parser.print_help()
continue
try: try:
# Refresh the state of the session # Refresh the state of the session
session = invoker.services.graph_execution_manager.get(session.id) session = invoker.services.graph_execution_manager.get(session.id)
@ -272,35 +166,23 @@ def invoke_cli():
# Parse args to create invocation # Parse args to create invocation
args = vars(parser.parse_args(shlex.split(cmd.strip()))) args = vars(parser.parse_args(shlex.split(cmd.strip())))
# Check for special commands
# TODO: These might be better as Pydantic models, similar to the invocations
if args["type"] == "history":
history_count = args["count"] or 5
for i in range(min(history_count, len(history))):
entry_id = history[-1 - i]
entry = session.graph.get_node(entry_id)
print(f"{entry_id}: {get_invocation_command(entry.invocation)}")
continue
if args["type"] == "reset_default":
if args["input"] in defaults:
del defaults[args["input"]]
continue
if args["type"] == "default":
field = args["input"]
field_value = args["value"]
defaults[field] = field_value
continue
# Override defaults # Override defaults
for field_name, field_default in defaults.items(): for field_name, field_default in context.defaults.items():
if field_name in args: if field_name in args:
args[field_name] = field_default args[field_name] = field_default
# Parse invocation # Parse invocation
args["id"] = current_id args["id"] = current_id
command = InvocationCommand(invocation=args) command = CliCommand(command=args)
# Run any CLI commands immediately
# TODO: this won't behave as expected if piping and using e.g. history,
# since invocations are gathered and then run together at the end.
# This is more efficient if the CLI is running against a distributed
# backend, so it's preferable not to change that behavior.
if isinstance(command.command, BaseCommand):
command.command.run(context)
continue
# Pipe previous command output (if there was a previous command) # Pipe previous command output (if there was a previous command)
edges = [] edges = []
@ -314,7 +196,7 @@ def invoke_cli():
else session.graph.get_node(from_id) else session.graph.get_node(from_id)
) )
matching_edges = generate_matching_edges( matching_edges = generate_matching_edges(
from_node, command.invocation from_node, command.command
) )
edges.extend(matching_edges) edges.extend(matching_edges)
@ -323,22 +205,25 @@ def invoke_cli():
for link in args["link_node"]: for link in args["link_node"]:
link_node = session.graph.get_node(link) link_node = session.graph.get_node(link)
matching_edges = generate_matching_edges( matching_edges = generate_matching_edges(
link_node, command.invocation link_node, command.command
) )
matching_destinations = [e[1] for e in matching_edges]
edges = [e for e in edges if e[1] not in matching_destinations]
edges.extend(matching_edges) edges.extend(matching_edges)
if "link" in args and args["link"]: if "link" in args and args["link"]:
for link in args["link"]: for link in args["link"]:
edges = [e for e in edges if e[1].node_id != command.command.id and e[1].field != link[2]]
edges.append( edges.append(
( (
EdgeConnection(node_id=link[1], field=link[0]), EdgeConnection(node_id=link[1], field=link[0]),
EdgeConnection( EdgeConnection(
node_id=command.invocation.id, field=link[2] node_id=command.command.id, field=link[2]
), ),
) )
) )
new_invocations.append((command.invocation, edges)) new_invocations.append((command.command, edges))
current_id = current_id + 1 current_id = current_id + 1
@ -347,13 +232,14 @@ def invoke_cli():
for invocation in new_invocations: for invocation in new_invocations:
session.add_node(invocation[0]) session.add_node(invocation[0])
for edge in invocation[1]: for edge in invocation[1]:
print(edge)
session.add_edge(edge) session.add_edge(edge)
# Execute all available invocations # Execute all available invocations
invoker.invoke(session, invoke_all=True) invoker.invoke(session, invoke_all=True)
while not session.is_complete(): while not session.is_complete():
# Wait some time # Wait some time
session = invoker.services.graph_execution_manager.get(session.id) session = context.get_session()
time.sleep(0.1) time.sleep(0.1)
# Print any errors # Print any errors
@ -366,11 +252,15 @@ def invoke_cli():
# Start a new session # Start a new session
print("Creating a new session") print("Creating a new session")
session = invoker.create_execution_state() session = invoker.create_execution_state()
context.session = session
except InvalidArgs: except InvalidArgs:
print('Invalid command, use "help" to list commands') print('Invalid command, use "help" to list commands')
continue continue
except ExitCli:
break
except SystemExit: except SystemExit:
continue continue