mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into enhance/support-another-embedding-format-main
This commit is contained in:
commit
44843be4c8
167
invokeai/app/cli/completer.py
Normal file
167
invokeai/app/cli/completer.py
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
"""
|
||||||
|
Readline helper functions for cli_app.py
|
||||||
|
You may import the global singleton `completer` to get access to the
|
||||||
|
completer object.
|
||||||
|
"""
|
||||||
|
import atexit
|
||||||
|
import readline
|
||||||
|
import shlex
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
|
||||||
|
|
||||||
|
from ...backend import ModelManager, Globals
|
||||||
|
from ..invocations.baseinvocation import BaseInvocation
|
||||||
|
from .commands import BaseCommand
|
||||||
|
|
||||||
|
# singleton object, class variable
|
||||||
|
completer = None
|
||||||
|
|
||||||
|
class Completer(object):
|
||||||
|
|
||||||
|
def __init__(self, model_manager: ModelManager):
|
||||||
|
self.commands = self.get_commands()
|
||||||
|
self.matches = None
|
||||||
|
self.linebuffer = None
|
||||||
|
self.manager = model_manager
|
||||||
|
return
|
||||||
|
|
||||||
|
def complete(self, text, state):
|
||||||
|
"""
|
||||||
|
Complete commands and switches fromm the node CLI command line.
|
||||||
|
Switches are determined in a context-specific manner.
|
||||||
|
"""
|
||||||
|
|
||||||
|
buffer = readline.get_line_buffer()
|
||||||
|
if state == 0:
|
||||||
|
options = None
|
||||||
|
try:
|
||||||
|
current_command, current_switch = self.get_current_command(buffer)
|
||||||
|
options = self.get_command_options(current_command, current_switch)
|
||||||
|
except IndexError:
|
||||||
|
pass
|
||||||
|
options = options or list(self.parse_commands().keys())
|
||||||
|
|
||||||
|
if not text: # first time
|
||||||
|
self.matches = options
|
||||||
|
else:
|
||||||
|
self.matches = [s for s in options if s and s.startswith(text)]
|
||||||
|
|
||||||
|
try:
|
||||||
|
match = self.matches[state]
|
||||||
|
except IndexError:
|
||||||
|
match = None
|
||||||
|
return match
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_commands(self)->List[object]:
|
||||||
|
"""
|
||||||
|
Return a list of all the client commands and invocations.
|
||||||
|
"""
|
||||||
|
return BaseCommand.get_commands() + BaseInvocation.get_invocations()
|
||||||
|
|
||||||
|
def get_current_command(self, buffer: str)->tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Parse the readline buffer to find the most recent command and its switch.
|
||||||
|
"""
|
||||||
|
if len(buffer)==0:
|
||||||
|
return None, None
|
||||||
|
tokens = shlex.split(buffer)
|
||||||
|
command = None
|
||||||
|
switch = None
|
||||||
|
for t in tokens:
|
||||||
|
if t[0].isalpha():
|
||||||
|
if switch is None:
|
||||||
|
command = t
|
||||||
|
else:
|
||||||
|
switch = t
|
||||||
|
# don't try to autocomplete switches that are already complete
|
||||||
|
if switch and buffer.endswith(' '):
|
||||||
|
switch=None
|
||||||
|
return command or '', switch or ''
|
||||||
|
|
||||||
|
def parse_commands(self)->Dict[str, List[str]]:
|
||||||
|
"""
|
||||||
|
Return a dict in which the keys are the command name
|
||||||
|
and the values are the parameters the command takes.
|
||||||
|
"""
|
||||||
|
result = dict()
|
||||||
|
for command in self.commands:
|
||||||
|
hints = get_type_hints(command)
|
||||||
|
name = get_args(hints['type'])[0]
|
||||||
|
result.update({name:hints})
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_command_options(self, command: str, switch: str)->List[str]:
|
||||||
|
"""
|
||||||
|
Return all the parameters that can be passed to the command as
|
||||||
|
command-line switches. Returns None if the command is unrecognized.
|
||||||
|
"""
|
||||||
|
parsed_commands = self.parse_commands()
|
||||||
|
if command not in parsed_commands:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# handle switches in the format "-foo=bar"
|
||||||
|
argument = None
|
||||||
|
if switch and '=' in switch:
|
||||||
|
switch, argument = switch.split('=')
|
||||||
|
|
||||||
|
parameter = switch.strip('-')
|
||||||
|
if parameter in parsed_commands[command]:
|
||||||
|
if argument is None:
|
||||||
|
return self.get_parameter_options(parameter, parsed_commands[command][parameter])
|
||||||
|
else:
|
||||||
|
return [f"--{parameter}={x}" for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])]
|
||||||
|
else:
|
||||||
|
return [f"--{x}" for x in parsed_commands[command].keys()]
|
||||||
|
|
||||||
|
def get_parameter_options(self, parameter: str, typehint)->List[str]:
|
||||||
|
"""
|
||||||
|
Given a parameter type (such as Literal), offers autocompletions.
|
||||||
|
"""
|
||||||
|
if get_origin(typehint) == Literal:
|
||||||
|
return get_args(typehint)
|
||||||
|
if parameter == 'model':
|
||||||
|
return self.manager.model_names()
|
||||||
|
|
||||||
|
def _pre_input_hook(self):
|
||||||
|
if self.linebuffer:
|
||||||
|
readline.insert_text(self.linebuffer)
|
||||||
|
readline.redisplay()
|
||||||
|
self.linebuffer = None
|
||||||
|
|
||||||
|
def set_autocompleter(model_manager: ModelManager) -> Completer:
|
||||||
|
global completer
|
||||||
|
|
||||||
|
if completer:
|
||||||
|
return completer
|
||||||
|
|
||||||
|
completer = Completer(model_manager)
|
||||||
|
|
||||||
|
readline.set_completer(completer.complete)
|
||||||
|
# pyreadline3 does not have a set_auto_history() method
|
||||||
|
try:
|
||||||
|
readline.set_auto_history(True)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
readline.set_pre_input_hook(completer._pre_input_hook)
|
||||||
|
readline.set_completer_delims(" ")
|
||||||
|
readline.parse_and_bind("tab: complete")
|
||||||
|
readline.parse_and_bind("set print-completions-horizontally off")
|
||||||
|
readline.parse_and_bind("set page-completions on")
|
||||||
|
readline.parse_and_bind("set skip-completed-text on")
|
||||||
|
readline.parse_and_bind("set show-all-if-ambiguous on")
|
||||||
|
|
||||||
|
histfile = Path(Globals.root, ".invoke_history")
|
||||||
|
try:
|
||||||
|
readline.read_history_file(histfile)
|
||||||
|
readline.set_history_length(1000)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
except OSError: # file likely corrupted
|
||||||
|
newname = f"{histfile}.old"
|
||||||
|
print(
|
||||||
|
f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
|
||||||
|
)
|
||||||
|
histfile.replace(Path(newname))
|
||||||
|
atexit.register(readline.write_history_file, histfile)
|
@ -14,6 +14,7 @@ from pydantic.fields import Field
|
|||||||
|
|
||||||
from ..backend import Args
|
from ..backend import Args
|
||||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
|
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
|
||||||
|
from .cli.completer import set_autocompleter
|
||||||
from .invocations import *
|
from .invocations import *
|
||||||
from .invocations.baseinvocation import BaseInvocation
|
from .invocations.baseinvocation import BaseInvocation
|
||||||
from .services.events import EventServiceBase
|
from .services.events import EventServiceBase
|
||||||
@ -130,6 +131,12 @@ def invoke_cli():
|
|||||||
config.parse_args()
|
config.parse_args()
|
||||||
model_manager = get_model_manager(config)
|
model_manager = get_model_manager(config)
|
||||||
|
|
||||||
|
# This initializes the autocompleter and returns it.
|
||||||
|
# Currently nothing is done with the returned Completer
|
||||||
|
# object, but the object can be used to change autocompletion
|
||||||
|
# behavior on the fly, if desired.
|
||||||
|
completer = set_autocompleter(model_manager)
|
||||||
|
|
||||||
events = EventServiceBase()
|
events = EventServiceBase()
|
||||||
|
|
||||||
output_folder = os.path.abspath(
|
output_folder = os.path.abspath(
|
||||||
@ -162,8 +169,8 @@ def invoke_cli():
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
cmd_input = input("> ")
|
cmd_input = input("invoke> ")
|
||||||
except KeyboardInterrupt:
|
except (KeyboardInterrupt, EOFError):
|
||||||
# Ctrl-c exits
|
# Ctrl-c exits
|
||||||
break
|
break
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user