2023-03-04 22:46:02 +00:00
|
|
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
|
|
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
import argparse
|
|
|
|
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
|
|
|
|
from pydantic import BaseModel, Field
|
2023-04-06 04:06:05 +00:00
|
|
|
import networkx as nx
|
|
|
|
import matplotlib.pyplot as plt
|
2023-04-04 01:05:15 +00:00
|
|
|
|
2023-04-14 06:41:06 +00:00
|
|
|
from ..invocations.baseinvocation import BaseInvocation
|
|
|
|
from ..invocations.image import ImageField
|
|
|
|
from ..services.graph import GraphExecutionState, LibraryGraph, GraphInvocation, Edge
|
2023-03-04 22:46:02 +00:00
|
|
|
from ..services.invoker import Invoker
|
|
|
|
|
|
|
|
|
2023-04-14 06:41:06 +00:00
|
|
|
def add_field_argument(command_parser, name: str, field, default_override=None):
    """Register a ``--<name>`` option on ``command_parser`` for a model field.

    Args:
        command_parser: The ``argparse.ArgumentParser`` (or subparser) to extend.
        name: Option name; also used as the ``dest`` attribute on the result.
        field: The model field; ``field.type_``, ``field.default``,
            ``field.default_factory`` and ``field.field_info.description`` are read.
        default_override: When not ``None``, used as the option default instead
            of the field's own default (or the result of its default factory).

    ``Literal[...]`` fields get ``choices`` restricted to the literal values.
    When all literal values share one type, that type converts the raw string.
    When they mix types (e.g. ``Literal["auto", 1]``), each type is tried in the
    literal's declaration order, and the first conversion that yields one of
    the allowed values wins.
    """
    if default_override is not None:
        default = default_override
    elif field.default_factory is None:
        default = field.default
    else:
        default = field.default_factory()

    if get_origin(field.type_) == Literal:
        allowed_values = get_args(field.type_)
        # dict.fromkeys de-duplicates while preserving declaration order, so the
        # conversion order below is deterministic (a set's order is not).
        allowed_types = list(dict.fromkeys(type(val) for val in allowed_values))

        if len(allowed_types) == 1:
            field_type = allowed_types[0]
        else:
            # argparse needs a callable for `type=`; a typing.Union is not one.
            def field_type(raw: str):
                for candidate_type in allowed_types:
                    try:
                        candidate = candidate_type(raw)
                    except (TypeError, ValueError):
                        continue
                    if candidate in allowed_values:
                        return candidate
                choices_text = ", ".join(map(repr, allowed_values))
                raise argparse.ArgumentTypeError(
                    f"invalid choice: {raw!r} (choose from {choices_text})"
                )

        command_parser.add_argument(
            f"--{name}",
            dest=name,
            type=field_type,
            default=default,
            choices=allowed_values,
            help=field.field_info.description,
        )
    else:
        command_parser.add_argument(
            f"--{name}",
            dest=name,
            type=field.type_,
            default=default,
            help=field.field_info.description,
        )
|
|
|
|
|
|
|
|
|
2023-03-04 22:46:02 +00:00
|
|
|
def add_parsers(
|
|
|
|
subparsers,
|
|
|
|
commands: list[type],
|
|
|
|
command_field: str = "type",
|
|
|
|
exclude_fields: list[str] = ["id", "type"],
|
|
|
|
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
|
|
|
|
):
|
|
|
|
"""Adds parsers for each command to the subparsers"""
|
|
|
|
|
|
|
|
# Create subparsers for each command
|
|
|
|
for command in commands:
|
|
|
|
hints = get_type_hints(command)
|
|
|
|
cmd_name = get_args(hints[command_field])[0]
|
|
|
|
command_parser = subparsers.add_parser(cmd_name, help=command.__doc__)
|
|
|
|
|
|
|
|
if add_arguments is not None:
|
|
|
|
add_arguments(command_parser)
|
|
|
|
|
|
|
|
# Convert all fields to arguments
|
|
|
|
fields = command.__fields__ # type: ignore
|
|
|
|
for name, field in fields.items():
|
|
|
|
if name in exclude_fields:
|
|
|
|
continue
|
|
|
|
|
2023-04-14 06:41:06 +00:00
|
|
|
add_field_argument(command_parser, name, field)
|
|
|
|
|
|
|
|
|
|
|
|
def add_graph_parsers(
    subparsers,
    graphs: list[LibraryGraph],
    add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
):
    """Add one subparser per library graph, with an option per exposed input.

    The subparser is named after ``graph.name`` and uses ``graph.description``
    as its help. For each exposed input, the option is named by its ``alias``,
    and its default is the current value of the field on the referenced node.
    """
    for library_graph in graphs:
        graph_parser = subparsers.add_parser(library_graph.name, help=library_graph.description)

        if add_arguments is not None:
            add_arguments(graph_parser)

        for exposed in library_graph.exposed_inputs:
            target_node = library_graph.graph.get_node(exposed.node_path)
            node_field = target_node.__fields__[exposed.field]
            current_value = getattr(target_node, exposed.field)
            add_field_argument(graph_parser, exposed.alias, node_field, current_value)
|
2023-03-04 22:46:02 +00:00
|
|
|
|
|
|
|
|
|
|
|
class CliContext:
    """State shared by CLI commands: the invoker, the active execution session,
    the argument parser, user-set defaults, and per-session node bookkeeping."""

    invoker: Invoker
    session: GraphExecutionState
    parser: argparse.ArgumentParser
    # Field defaults set by SetDefaultCommand; kept across reset().
    defaults: dict[str, Any]
    # Cleared whenever the session is reset.
    graph_nodes: dict[str, str]
    # Ids of nodes added through add_node() since the last reset.
    nodes_added: list[str]

    def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
        self.invoker = invoker
        self.session = session
        self.parser = parser
        self.defaults = {}
        self.graph_nodes = {}
        self.nodes_added = []

    def get_session(self):
        """Reload the session from the graph execution manager by its id,
        store it on the context, and return it."""
        manager = self.invoker.services.graph_execution_manager
        self.session = manager.get(self.session.id)
        return self.session

    def reset(self):
        """Start a new execution state and clear the per-session node bookkeeping."""
        self.session = self.invoker.create_execution_state()
        self.graph_nodes = {}
        self.nodes_added = []
        # `defaults` is intentionally left as-is.

    def add_node(self, node: BaseInvocation):
        """Add ``node`` to the session graph, record its id, then persist the session."""
        current = self.get_session()
        current.graph.add_node(node)
        self.nodes_added.append(node.id)
        self.invoker.services.graph_execution_manager.set(current)

    def add_edge(self, edge: Edge):
        """Add ``edge`` to the session, then persist the session."""
        current = self.get_session()
        current.add_edge(edge)
        self.invoker.services.graph_execution_manager.set(current)
|
|
|
|
2023-03-04 22:46:02 +00:00
|
|
|
|
|
|
|
class ExitCli(Exception):
    """Raised (e.g. from a command's ``run``) to request that the CLI exit."""
|
|
|
|
|
|
|
|
|
|
|
|
class BaseCommand(ABC, BaseModel):
    """A CLI command"""

    # Every concrete command must declare its CLI name as a Literal-typed
    # `type` field with a matching default, for example:
    #     type: Literal['your_command_name'] = 'your_command_name'

    @classmethod
    def get_all_subclasses(cls):
        """Return every transitive subclass of ``cls`` in breadth-first order.

        A class reachable through more than one parent appears once per parent.
        """
        collected = []
        frontier = [cls]
        while frontier:
            # Expand one whole inheritance level at a time.
            frontier = [child for parent in frontier for child in parent.__subclasses__()]
            collected.extend(frontier)
        return collected

    @classmethod
    def get_commands(cls):
        """Return all BaseCommand subclasses as a tuple."""
        return tuple(BaseCommand.get_all_subclasses())

    @classmethod
    def get_commands_map(cls):
        """Map each command's `type` literal value to its class."""
        commands_by_name = {}
        for command_cls in BaseCommand.get_all_subclasses():
            command_name = get_args(get_type_hints(command_cls)['type'])[0]
            commands_by_name[command_name] = command_cls
        return commands_by_name

    @abstractmethod
    def run(self, context: CliContext) -> None:
        """Run the command. Raise ExitCli to exit."""
|
|
|
|
|
|
|
|
|
|
|
|
class ExitCommand(BaseCommand):
    """Exits the CLI"""

    # Literal value used as this command's subcommand name.
    type: Literal['exit'] = 'exit'

    def run(self, context: CliContext) -> None:
        # The context is unused; raising the class instantiates it implicitly.
        raise ExitCli
|
|
|
|
|
|
|
|
|
|
|
|
class HelpCommand(BaseCommand):
    """Shows help"""

    # Literal value used as this command's subcommand name.
    type: Literal['help'] = 'help'

    def run(self, context: CliContext) -> None:
        # Print the help text of the parser held by the context.
        parser = context.parser
        parser.print_help()
|
|
|
|
|
|
|
|
|
|
|
|
def get_graph_execution_history(
    graph_execution_state: GraphExecutionState,
) -> Iterable[str]:
    """Lazily yield executed invocation ids that are still nodes of the graph.

    Ids come out in reverse ``executed_history`` order (most recent first,
    assuming the history is appended to as nodes finish — TODO confirm).
    """
    return filter(
        lambda node_id: node_id in graph_execution_state.graph.nodes,
        reversed(graph_execution_state.executed_history),
    )
|
|
|
|
|
|
|
|
|
|
|
|
def get_invocation_command(invocation) -> str:
    """Render ``invocation`` as a CLI command string.

    The result starts with ``invocation.type`` and appends ``--<field> <value>``
    for each field whose value differs from its default. ``id`` and ``type`` are
    skipped, as are fields typed as ``ImageField`` (or a union containing it).
    String-typed values are wrapped in double quotes, with embedded backslashes
    and double quotes escaped.
    """
    type_hints = get_type_hints(type(invocation))
    command = [invocation.type]

    for name, field in invocation.__fields__.items():
        if name in ("id", "type"):
            continue

        # TODO: add links

        # Skip image fields when serializing command
        type_hint = type_hints.get(name)
        if type_hint is ImageField or ImageField in get_args(type_hint):
            continue

        field_value = getattr(invocation, name)
        # Resolve the default the same way add_field_argument does, so fields
        # with a default_factory are compared against the factory's value.
        if field.default_factory is None:
            field_default = field.default
        else:
            field_default = field.default_factory()

        if field_value != field_default:
            if type_hint is str or str in get_args(type_hint):
                # Escape so the quoted value survives POSIX shell-style
                # tokenizing (assumes the CLI splits input like shlex — confirm).
                escaped = f"{field_value}".replace("\\", "\\\\").replace('"', '\\"')
                command.append(f'--{name} "{escaped}"')
            else:
                command.append(f"--{name} {field_value}")

    return " ".join(command)
|
|
|
|
|
|
|
|
|
|
|
|
class HistoryCommand(BaseCommand):
    """Shows the invocation history"""

    type: Literal['history'] = 'history'

    # Inputs
    # fmt: off
    count: int = Field(default=5, gt=0, description="The number of history entries to show")
    # fmt: on

    def run(self, context: CliContext) -> None:
        # Fetch the session once so every printed entry comes from one snapshot.
        session = context.get_session()
        # get_graph_execution_history yields ids in reverse executed order
        # (most recent first), so the first `count` ids are the most recent
        # ones. They are printed oldest-first, like a shell history listing.
        recent = list(get_graph_execution_history(session))[: self.count]
        for entry_id in reversed(recent):
            entry = session.graph.get_node(entry_id)
            print(f"{entry_id}: {get_invocation_command(entry)}")
|
|
|
|
|
|
|
|
|
|
|
|
class SetDefaultCommand(BaseCommand):
    """Sets a default value for a field"""

    type: Literal['default'] = 'default'

    # Inputs
    # fmt: off
    field: str = Field(description="The field to set the default for")
    # Optional so that omitting --value can clear the default; with a plain
    # required `str`, validation rejected None and the clear path was dead.
    value: str | None = Field(default=None, description="The value to set the default to, or None to clear the default")
    # fmt: on

    def run(self, context: CliContext) -> None:
        # The help text says "None to clear", so the literal string typed at
        # the prompt clears the default too.
        if self.value is None or self.value == "None":
            context.defaults.pop(self.field, None)
        else:
            context.defaults[self.field] = self.value
|
2023-04-06 04:06:05 +00:00
|
|
|
|
|
|
|
|
|
|
|
class DrawGraphCommand(BaseCommand):
    """Debugs a graph"""

    type: Literal['draw_graph'] = 'draw_graph'

    def run(self, context: CliContext) -> None:
        # Look the session up by id without refreshing context.session itself.
        manager = context.invoker.services.graph_execution_manager
        session: GraphExecutionState = manager.get(context.session.id)
        flat_graph = session.graph.nx_graph_flat()

        # Render the flattened graph with a spectral layout on a 20x20 figure
        # with axes hidden.
        plt.figure(figsize=(20, 20))
        layout = nx.spectral_layout(flat_graph)
        nx.draw_networkx_nodes(flat_graph, layout, node_size=1000)
        nx.draw_networkx_edges(flat_graph, layout, width=2)
        nx.draw_networkx_labels(flat_graph, layout, font_size=20, font_family="sans-serif")
        plt.axis("off")
        plt.show()
|
|
|
|
|
|
|
|
|
|
|
|
class DrawExecutionGraphCommand(BaseCommand):
    """Debugs an execution graph"""

    type: Literal['draw_xgraph'] = 'draw_xgraph'

    def run(self, context: CliContext) -> None:
        # Look the session up by id without refreshing context.session itself.
        manager = context.invoker.services.graph_execution_manager
        session: GraphExecutionState = manager.get(context.session.id)
        flat_execution_graph = session.execution_graph.nx_graph_flat()

        # Render the flattened execution graph with a spectral layout on a
        # 20x20 figure with axes hidden.
        plt.figure(figsize=(20, 20))
        layout = nx.spectral_layout(flat_execution_graph)
        nx.draw_networkx_nodes(flat_execution_graph, layout, node_size=1000)
        nx.draw_networkx_edges(flat_execution_graph, layout, width=2)
        nx.draw_networkx_labels(flat_execution_graph, layout, font_size=20, font_family="sans-serif")
        plt.axis("off")
        plt.show()
|