From ebc4b52f41b6e04b7676135c263a6d30f28ea00c Mon Sep 17 00:00:00 2001 From: Kyle Schouviller Date: Sat, 4 Mar 2023 14:46:02 -0800 Subject: [PATCH] [cli] Update CLI to define commands as Pydantic objects --- invokeai/app/cli/__init__.py | 0 invokeai/app/cli/commands.py | 202 +++++++++++++++++++++++++++++++++ invokeai/app/cli_app.py | 214 +++++++++-------------------------- 3 files changed, 254 insertions(+), 162 deletions(-) create mode 100644 invokeai/app/cli/__init__.py create mode 100644 invokeai/app/cli/commands.py diff --git a/invokeai/app/cli/__init__.py b/invokeai/app/cli/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/invokeai/app/cli/commands.py b/invokeai/app/cli/commands.py new file mode 100644 index 0000000000..21e65291e9 --- /dev/null +++ b/invokeai/app/cli/commands.py @@ -0,0 +1,202 @@ +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) + +from abc import ABC, abstractmethod +import argparse +from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints +from pydantic import BaseModel, Field + +from ..invocations.image import ImageField +from ..services.graph import GraphExecutionState +from ..services.invoker import Invoker + + +def add_parsers( + subparsers, + commands: list[type], + command_field: str = "type", + exclude_fields: list[str] = ["id", "type"], + add_arguments: Callable[[argparse.ArgumentParser], None]|None = None + ): + """Adds parsers for each command to the subparsers""" + + # Create subparsers for each command + for command in commands: + hints = get_type_hints(command) + cmd_name = get_args(hints[command_field])[0] + command_parser = subparsers.add_parser(cmd_name, help=command.__doc__) + + if add_arguments is not None: + add_arguments(command_parser) + + # Convert all fields to arguments + fields = command.__fields__ # type: ignore + for name, field in fields.items(): + if name in exclude_fields: + continue + + if get_origin(field.type_) == Literal: + allowed_values = get_args(field.type_) + allowed_types = set() + for val in allowed_values: + allowed_types.add(type(val)) + allowed_types_list = list(allowed_types) + field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore + + command_parser.add_argument( + f"--{name}", + dest=name, + type=field_type, + default=field.default, + choices=allowed_values, + help=field.field_info.description, + ) + else: + command_parser.add_argument( + f"--{name}", + dest=name, + type=field.type_, + default=field.default, + help=field.field_info.description, + ) + + +class CliContext: + invoker: Invoker + session: GraphExecutionState + parser: argparse.ArgumentParser + defaults: dict[str, Any] + + def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser): + self.invoker = invoker + self.session = session + self.parser = parser + self.defaults = dict() + + def get_session(self): + self.session = self.invoker.services.graph_execution_manager.get(self.session.id) + return self.session + + +class ExitCli(Exception): + """Exception to exit the CLI""" + pass + + +class BaseCommand(ABC, BaseModel): + """A CLI command""" + + # All commands must include a type name like this: + # type: Literal['your_command_name'] = 'your_command_name' + + @classmethod + def get_all_subclasses(cls): + subclasses = [] + toprocess = [cls] + while len(toprocess) > 0: + next = toprocess.pop(0) + next_subclasses = next.__subclasses__() + subclasses.extend(next_subclasses) + toprocess.extend(next_subclasses) + return subclasses + + @classmethod + def get_commands(cls): + return tuple(BaseCommand.get_all_subclasses()) + + @classmethod + def get_commands_map(cls): + # Get the type strings out of the literals and into a dictionary + return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseCommand.get_all_subclasses())) + + @abstractmethod + def run(self, context: CliContext) -> None: + """Run the command. Raise ExitCli to exit.""" + pass + + +class ExitCommand(BaseCommand): + """Exits the CLI""" + type: Literal['exit'] = 'exit' + + def run(self, context: CliContext) -> None: + raise ExitCli() + + +class HelpCommand(BaseCommand): + """Shows help""" + type: Literal['help'] = 'help' + + def run(self, context: CliContext) -> None: + context.parser.print_help() + + +def get_graph_execution_history( + graph_execution_state: GraphExecutionState, +) -> Iterable[str]: + """Gets the history of fully-executed invocations for a graph execution""" + return ( + n + for n in reversed(graph_execution_state.executed_history) + if n in graph_execution_state.graph.nodes + ) + + +def get_invocation_command(invocation) -> str: + fields = invocation.__fields__.items() + type_hints = get_type_hints(type(invocation)) + command = [invocation.type] + for name, field in fields: + if name in ["id", "type"]: + continue + + # TODO: add links + + # Skip image fields when serializing command + type_hint = type_hints.get(name) or None + if type_hint is ImageField or ImageField in get_args(type_hint): + continue + + field_value = getattr(invocation, name) + field_default = field.default + if field_value != field_default: + if type_hint is str or str in get_args(type_hint): + command.append(f'--{name} "{field_value}"') + else: + command.append(f"--{name} {field_value}") + + return " ".join(command) + + +class HistoryCommand(BaseCommand): + """Shows the invocation history""" + type: Literal['history'] = 'history' + + # Inputs + # fmt: off + count: int = Field(default=5, gt=0, description="The number of history entries to show") + # fmt: on + + def run(self, context: CliContext) -> None: + history = list(get_graph_execution_history(context.get_session())) + for i in range(min(self.count, len(history))): + entry_id = history[-1 - i] + entry = context.get_session().graph.get_node(entry_id) + print(f"{entry_id}: {get_invocation_command(entry)}") + + +class SetDefaultCommand(BaseCommand): + """Sets a default value for a field""" + type: Literal['default'] = 'default' + + # Inputs + # fmt: off + field: str = Field(description="The field to set the default for") + value: str = Field(description="The value to set the default to, or None to clear the default") + # fmt: on + + def run(self, context: CliContext) -> None: + if self.value is None: + if self.field in context.defaults: + del context.defaults[self.field] + else: + context.defaults[self.field] = self.value diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index 2f20cfde58..721760b222 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -5,13 +5,7 @@ import os import shlex import time from typing import ( - Any, - Dict, - Iterable, - Literal, Union, - get_args, - get_origin, get_type_hints, ) @@ -19,9 +13,9 @@ from pydantic import BaseModel from pydantic.fields import Field from ..backend import Args +from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history from .invocations import * from .invocations.baseinvocation import BaseInvocation -from .invocations.image import ImageField from .services.events import EventServiceBase from .services.generate_initializer import get_generate from .services.graph import EdgeConnection, GraphExecutionState @@ -33,15 +27,33 @@ from .services.processor import DefaultInvocationProcessor from .services.sqlite import SqliteItemStorage -class InvocationCommand(BaseModel): - invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore +class CliCommand(BaseModel): + command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore class InvalidArgs(Exception): pass -def get_invocation_parser() -> argparse.ArgumentParser: +def add_invocation_args(command_parser): + # Add linking capability + command_parser.add_argument( + "--link", + "-l", + action="append", + nargs=3, + help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)", + ) + + command_parser.add_argument( + "--link_node", + "-ln", + action="append", + help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)", + ) + + +def get_command_parser() -> argparse.ArgumentParser: # Create invocation parser parser = argparse.ArgumentParser() @@ -49,129 +61,19 @@ def get_invocation_parser() -> argparse.ArgumentParser: raise InvalidArgs parser.exit = exit - subparsers = parser.add_subparsers(dest="type") - invocation_parsers = dict() - - # Add history parser - history_parser = subparsers.add_parser( - "history", help="Shows the invocation history" - ) - history_parser.add_argument( - "count", - nargs="?", - default=5, - type=int, - help="The number of history entries to show", - ) - - # Add default parser - default_parser = subparsers.add_parser( - "default", help="Define a default value for all inputs with a specified name" - ) - default_parser.add_argument("input", type=str, help="The input field") - default_parser.add_argument("value", help="The default value") - - default_parser = subparsers.add_parser( - "reset_default", help="Resets a default value" - ) - default_parser.add_argument("input", type=str, help="The input field") # Create subparsers for each invocation invocations = BaseInvocation.get_all_subclasses() - for invocation in invocations: - hints = get_type_hints(invocation) - cmd_name = get_args(hints["type"])[0] - command_parser = subparsers.add_parser(cmd_name, help=invocation.__doc__) - invocation_parsers[cmd_name] = command_parser + add_parsers(subparsers, invocations, add_arguments=add_invocation_args) - # Add linking capability - command_parser.add_argument( - "--link", - "-l", - action="append", - nargs=3, - help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)", - ) - - command_parser.add_argument( - "--link_node", - "-ln", - action="append", - help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)", - ) - - # Convert all fields to arguments - fields = invocation.__fields__ - for name, field in fields.items(): - if name in ["id", "type"]: - continue - - if get_origin(field.type_) == Literal: - allowed_values = get_args(field.type_) - allowed_types = set() - for val in allowed_values: - allowed_types.add(type(val)) - allowed_types_list = list(allowed_types) - field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore - - command_parser.add_argument( - f"--{name}", - dest=name, - type=field_type, - default=field.default, - choices=allowed_values, - help=field.field_info.description, - ) - else: - command_parser.add_argument( - f"--{name}", - dest=name, - type=field.type_, - default=field.default, - help=field.field_info.description, - ) + # Create subparsers for each command + commands = BaseCommand.get_all_subclasses() + add_parsers(subparsers, commands, exclude_fields=["type"]) return parser -def get_invocation_command(invocation) -> str: - fields = invocation.__fields__.items() - type_hints = get_type_hints(type(invocation)) - command = [invocation.type] - for name, field in fields: - if name in ["id", "type"]: - continue - - # TODO: add links - - # Skip image fields when serializing command - type_hint = type_hints.get(name) or None - if type_hint is ImageField or ImageField in get_args(type_hint): - continue - - field_value = getattr(invocation, name) - field_default = field.default - if field_value != field_default: - if type_hint is str or str in get_args(type_hint): - command.append(f'--{name} "{field_value}"') - else: - command.append(f"--{name} {field_value}") - - return " ".join(command) - - -def get_graph_execution_history( - graph_execution_state: GraphExecutionState, -) -> Iterable[str]: - """Gets the history of fully-executed invocations for a graph execution""" - return ( - n - for n in reversed(graph_execution_state.executed_history) - if n in graph_execution_state.graph.nodes - ) - - def generate_matching_edges( a: BaseInvocation, b: BaseInvocation ) -> list[tuple[EdgeConnection, EdgeConnection]]: @@ -233,13 +135,12 @@ def invoke_cli(): invoker = Invoker(services) session: GraphExecutionState = invoker.create_execution_state() - parser = get_invocation_parser() + parser = get_command_parser() # Uncomment to print out previous sessions at startup # print(services.session_manager.list()) - # Defaults storage - defaults: Dict[str, Any] = dict() + context = CliContext(invoker, session, parser) while True: try: @@ -248,13 +149,6 @@ def invoke_cli(): # Ctrl-c exits break - if cmd_input in ["exit", "q"]: - break - - if cmd_input in ["--help", "help", "h", "?"]: - parser.print_help() - continue - try: # Refresh the state of the session session = invoker.services.graph_execution_manager.get(session.id) @@ -272,35 +166,23 @@ def invoke_cli(): # Parse args to create invocation args = vars(parser.parse_args(shlex.split(cmd.strip()))) - # Check for special commands - # TODO: These might be better as Pydantic models, similar to the invocations - if args["type"] == "history": - history_count = args["count"] or 5 - for i in range(min(history_count, len(history))): - entry_id = history[-1 - i] - entry = session.graph.get_node(entry_id) - print(f"{entry_id}: {get_invocation_command(entry.invocation)}") - continue - - if args["type"] == "reset_default": - if args["input"] in defaults: - del defaults[args["input"]] - continue - - if args["type"] == "default": - field = args["input"] - field_value = args["value"] - defaults[field] = field_value - continue - # Override defaults - for field_name, field_default in defaults.items(): + for field_name, field_default in context.defaults.items(): if field_name in args: args[field_name] = field_default # Parse invocation args["id"] = current_id - command = InvocationCommand(invocation=args) + command = CliCommand(command=args) + + # Run any CLI commands immediately + # TODO: this won't behave as expected if piping and using e.g. history, + # since invocations are gathered and then run together at the end. + # This is more efficient if the CLI is running against a distributed + # backend, so it's preferable not to change that behavior. + if isinstance(command.command, BaseCommand): + command.command.run(context) + continue # Pipe previous command output (if there was a previous command) edges = [] @@ -314,7 +196,7 @@ def invoke_cli(): else session.graph.get_node(from_id) ) matching_edges = generate_matching_edges( - from_node, command.invocation + from_node, command.command ) edges.extend(matching_edges) @@ -323,22 +205,25 @@ def invoke_cli(): for link in args["link_node"]: link_node = session.graph.get_node(link) matching_edges = generate_matching_edges( - link_node, command.invocation + link_node, command.command ) + matching_destinations = [e[1] for e in matching_edges] + edges = [e for e in edges if e[1] not in matching_destinations] edges.extend(matching_edges) if "link" in args and args["link"]: for link in args["link"]: + edges = [e for e in edges if e[1].node_id != command.command.id and e[1].field != link[2]] edges.append( ( EdgeConnection(node_id=link[1], field=link[0]), EdgeConnection( - node_id=command.invocation.id, field=link[2] + node_id=command.command.id, field=link[2] ), ) ) - new_invocations.append((command.invocation, edges)) + new_invocations.append((command.command, edges)) current_id = current_id + 1 @@ -347,13 +232,14 @@ def invoke_cli(): for invocation in new_invocations: session.add_node(invocation[0]) for edge in invocation[1]: + print(edge) session.add_edge(edge) # Execute all available invocations invoker.invoke(session, invoke_all=True) while not session.is_complete(): # Wait some time - session = invoker.services.graph_execution_manager.get(session.id) + session = context.get_session() time.sleep(0.1) # Print any errors @@ -366,11 +252,15 @@ def invoke_cli(): # Start a new session print("Creating a new session") session = invoker.create_execution_state() + context.session = session except InvalidArgs: print('Invalid command, use "help" to list commands') continue + except ExitCli: + break + except SystemExit: continue