From 93481616004620c51ac86a4b5032d43f251de9d5 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 26 Mar 2023 00:24:27 -0400 Subject: [PATCH] add basic autocomplete functionality to node cli - Commands, invocations and their parameters will now autocomplete using introspection. - Two types of parameter *arguments* will also autocomplete: - --sampler_name will autocomplete the scheduler name - --model will autocomplete the model name - There don't seem to be commands for reading/writing image files yet, so path autocompletion is not implemented --- invokeai/app/cli/completer.py | 168 ++++++++++++++++++++++++++++++++++ invokeai/app/cli_app.py | 6 +- 2 files changed, 172 insertions(+), 2 deletions(-) create mode 100644 invokeai/app/cli/completer.py diff --git a/invokeai/app/cli/completer.py b/invokeai/app/cli/completer.py new file mode 100644 index 0000000000..cffff200bf --- /dev/null +++ b/invokeai/app/cli/completer.py @@ -0,0 +1,168 @@ +""" +Readline helper functions for cli_app.py +You may import the global singleton `completer` to get access to the +completer object. +""" +import atexit +import readline +import shlex +import sys + +from pathlib import Path +from typing import List, Dict, Literal, get_args, get_type_hints, get_origin + +from ...backend import ModelManager, Globals +from ..invocations.baseinvocation import BaseInvocation +from .commands import BaseCommand + +# singleton object, class variable +completer = None + +class Completer(object): + + def __init__(self, model_manager: ModelManager): + self.commands = self.get_commands() + self.matches = None + self.linebuffer = None + self.manager = model_manager + return + + def complete(self, text, state): + """ + Complete commands and switches fromm the node CLI command line. + Switches are determined in a context-specific manner. + """ + + buffer = readline.get_line_buffer() + if state == 0: + options = None + try: + current_command, current_switch = self.get_current_command(buffer) + options = self.get_command_options(current_command, current_switch) + except IndexError: + pass + options = options or list(self.parse_commands().keys()) + + if not text: # first time + self.matches = options + else: + self.matches = [s for s in options if s and s.startswith(text)] + + try: + match = self.matches[state] + except IndexError: + match = None + return match + + @classmethod + def get_commands(self)->List[object]: + """ + Return a list of all the client commands and invocations. + """ + return BaseCommand.get_commands() + BaseInvocation.get_invocations() + + def get_current_command(self, buffer: str)->tuple[str, str]: + """ + Parse the readline buffer to find the most recent command and its switch. + """ + if len(buffer)==0: + return None, None + tokens = shlex.split(buffer) + command = None + switch = None + for t in tokens: + if t[0].isalpha(): + if switch is None: + command = t + else: + switch = t + # don't try to autocomplete switches that are already complete + if switch and buffer.endswith(' '): + switch=None + return command or '', switch or '' + + def parse_commands(self)->Dict[str, List[str]]: + """ + Return a dict in which the keys are the command name + and the values are the parameters the command takes. + """ + result = dict() + for command in self.commands: + hints = get_type_hints(command) + name = get_args(hints['type'])[0] + result.update({name:hints}) + return result + + def get_command_options(self, command: str, switch: str)->List[str]: + """ + Return all the parameters that can be passed to the command as + command-line switches. Returns None if the command is unrecognized. + """ + parsed_commands = self.parse_commands() + if command not in parsed_commands: + return None + + # handle switches in the format "-foo=bar" + argument = None + if switch and '=' in switch: + switch, argument = switch.split('=') + + parameter = switch.strip('-') + if parameter in parsed_commands[command]: + if argument is None: + return self.get_parameter_options(parameter, parsed_commands[command][parameter]) + else: + return [f"--{parameter}={x}" for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])] + else: + return [f"--{x}" for x in parsed_commands[command].keys()] + + def get_parameter_options(self, parameter: str, typehint)->List[str]: + """ + Given a parameter type (such as Literal), offers autocompletions. + """ + if get_origin(typehint) == Literal: + return get_args(typehint) + if parameter == 'model': + return self.manager.model_names() + + def _pre_input_hook(self): + if self.linebuffer: + readline.insert_text(self.linebuffer) + readline.redisplay() + self.linebuffer = None + +def get_completer(model_manager: ModelManager) -> Completer: + global completer + + if completer: + return completer + + completer = Completer(model_manager) + + readline.set_completer(completer.complete) + # pyreadline3 does not have a set_auto_history() method + try: + readline.set_auto_history(True) + except: + pass + readline.set_pre_input_hook(completer._pre_input_hook) + readline.set_completer_delims(" ") + readline.parse_and_bind("tab: complete") + readline.parse_and_bind("set print-completions-horizontally off") + readline.parse_and_bind("set page-completions on") + readline.parse_and_bind("set skip-completed-text on") + readline.parse_and_bind("set show-all-if-ambiguous on") + + histfile = Path(Globals.root, ".invoke_history") + try: + readline.read_history_file(histfile) + readline.set_history_length(1000) + except FileNotFoundError: + pass + except OSError: # file likely corrupted + newname = f"{histfile}.old" + print( + f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}" + ) + histfile.replace(Path(newname)) + atexit.register(readline.write_history_file, histfile) diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index 6390253250..927f1954d9 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -14,6 +14,7 @@ from pydantic.fields import Field from ..backend import Args from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history +from .cli.completer import get_completer from .invocations import * from .invocations.baseinvocation import BaseInvocation from .services.events import EventServiceBase @@ -129,6 +130,7 @@ def invoke_cli(): config = Args() config.parse_args() model_manager = get_model_manager(config) + completer = get_completer(model_manager) events = EventServiceBase() @@ -162,8 +164,8 @@ def invoke_cli(): while True: try: - cmd_input = input("> ") - except KeyboardInterrupt: + cmd_input = input(f"{model_manager.current_model or '(no model)'}> ") + except (KeyboardInterrupt, EOFError): # Ctrl-c exits break