InvokeAI/invokeai/frontend/CLI/readline.py
2023-03-03 01:02:00 -05:00

498 lines
14 KiB
Python

"""
Readline helper functions for invoke.py.
You may import the global singleton `completer` to get access to the
completer object itself. This is useful when you want to autocomplete
seeds:
from invokeai.frontend.CLI.readline import completer
completer.add_seed(18247566)
completer.add_seed(9281839)
"""
import atexit
import os
import re
from ...backend.args import Args
from ...backend.globals import Globals
from ...backend.stable_diffusion import HuggingFaceConceptsLibrary
# ---------------readline utilities---------------------
# Import readline if the platform provides it; `readline_available` tells
# the rest of the module whether real line editing can be used.
try:
    import readline
except ImportError as err:
    # ModuleNotFoundError is a subclass of ImportError, so this single
    # clause covers both.
    readline_available = False
    print(f"** An error occurred when loading the readline module: {err}")
else:
    readline_available = True
# File-name suffixes accepted by the various path completions.  These are
# tuples because they are passed directly to str.endswith().
IMG_EXTENSIONS = (".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG", ".gif", ".GIF")
WEIGHT_EXTENSIONS = (".ckpt", ".vae", ".safetensors")
TEXT_EXTENSIONS = (".txt", ".TXT")
CONFIG_EXTENSIONS = (".yaml", ".yml")

# Every switch and "!" command offered by plain (non-path) completion.
# Each entry appears exactly once.
COMMANDS = (
    "--steps",
    "-s",
    "--seed",
    "-S",
    "--iterations",
    "-n",
    "--width",
    "-W",
    "--height",
    "-H",
    "--cfg_scale",
    "-C",
    "--threshold",
    "--perlin",
    "--grid",
    "-g",
    "--individual",
    "-i",
    "--save_intermediates",
    "--init_img",
    "-I",
    "--init_mask",
    "-M",
    "--init_color",
    "--strength",
    "-f",
    "--variants",
    "-v",
    "--outdir",
    "-o",
    "--sampler",
    "-A",
    "-m",
    "--embedding_path",
    "--device",
    "--facetool",
    "-ft",
    "--facetool_strength",
    "-G",
    "--codeformer_fidelity",
    "-cf",
    "--upscale",
    "-U",
    "-save_orig",
    "--save_original",
    "--log_tokenization",
    "-t",
    "--hires_fix",
    "--inpaint_replace",
    "-r",
    "--png_compression",
    "-z",
    "--text_mask",
    "-tm",
    "--h_symmetry_time_pct",
    "--v_symmetry_time_pct",
    "!fix",
    "!fetch",
    "!replay",
    "!history",
    "!search",
    "!clear",
    "!models",
    "!switch",
    "!import_model",
    "!optimize_model",
    "!convert_model",
    "!edit_model",
    "!del_model",
    "!mask",
    "!triggers",
)

# Commands whose argument is completed from the list of known models.
MODEL_COMMANDS = (
    "!switch",
    "!edit_model",
    "!del_model",
)
# Commands whose argument is completed from "ckpt"-format models only.
CKPT_MODEL_COMMANDS = ("!optimize_model",)
# Commands whose argument is completed as a path to a weights file.
WEIGHT_COMMANDS = (
    "!import_model",
    "!convert_model",
)

# The entries below are regular-expression fragments, not literal strings,
# so they are written as raw strings to keep `\s` a regex escape rather
# than an (invalid) Python string escape.
IMG_PATH_COMMANDS = (r"--outdir[=\s]",)
TEXT_PATH_COMMANDS = ("!replay",)
IMG_FILE_COMMANDS = (
    "!fix",
    "!fetch",
    "!mask",
    r"--init_img[=\s]",
    "-I",
    r"--init_mask[=\s]",
    "-M",
    r"--init_color[=\s]",
    r"--embedding_path[=\s]",
)

# Each pattern matches "<command><optional whitespace><partial word>" at the
# end of the line buffer, i.e. the word currently being completed.
path_regexp = "(" + "|".join(IMG_PATH_COMMANDS + IMG_FILE_COMMANDS) + r")\s*\S*$"
weight_regexp = "(" + "|".join(WEIGHT_COMMANDS) + r")\s*\S*$"
text_regexp = "(" + "|".join(TEXT_PATH_COMMANDS) + r")\s*\S*$"
class Completer(object):
    """
    Readline completion helper for the invoke command line.

    Completes switches and "!" commands, file paths, previously used
    seeds, embedding concepts written as ``<name>``, and model names.
    It also wraps the readline history functions so that a subclass can
    replace them when readline is not available.
    """

    def __init__(self, options, models=None):
        """
        options: iterable of completable command/switch strings.
        models:  mapping of model name -> info dict with a "format" key.
                 Defaults to a fresh empty dict per instance.
        """
        self.options = sorted(options)
        self.models = {} if models is None else models
        self.seeds = set()
        self.matches = list()
        self.default_dir = None
        self.linebuffer = None
        self.auto_history_active = True
        self.extensions = None
        self.concepts = None
        self.embedding_terms = set()

    def complete(self, text, state):
        """
        Completes invoke command line.
        BUG: it doesn't correctly complete files that have spaces in the name.

        readline calls this repeatedly with state = 0, 1, 2, ...; the match
        list is built on the state == 0 call and then served one item at a
        time.  Returns None once the matches are exhausted.
        """
        if state == 0:
            buffer = readline.get_line_buffer()
            # extensions defined, so go directly into path completion mode
            if self.extensions is not None:
                self.matches = self._path_completions(text, state, self.extensions)
            # looking for an image file
            elif re.search(path_regexp, buffer):
                # NOTE(review): the "^" binds only to the first alternative,
                # so the remaining commands match anywhere in the buffer.
                # Presumably intended, as switches follow the prompt text.
                do_shortcut = re.search("^" + "|".join(IMG_FILE_COMMANDS), buffer)
                self.matches = self._path_completions(
                    text, state, IMG_EXTENSIONS, shortcut_ok=do_shortcut
                )
            # looking for a seed
            elif re.search(r"(-S\s*|--seed[=\s])\d*$", buffer):
                self.matches = self._seed_completions(text, state)
            # looking for an embedding concept
            elif re.search(r"<[\w-]*$", buffer):
                self.matches = self._concept_completions(text, state)
            # looking for a model
            elif re.match("^" + "|".join(MODEL_COMMANDS), buffer):
                self.matches = self._model_completions(text, state)
            # looking for a ckpt model
            elif re.match("^" + "|".join(CKPT_MODEL_COMMANDS), buffer):
                self.matches = self._model_completions(text, state, ckpt_only=True)
            # looking for a weights file
            elif re.search(weight_regexp, buffer):
                self.matches = self._path_completions(
                    text,
                    state,
                    WEIGHT_EXTENSIONS,
                    default_dir=Globals.root,
                )
            # looking for a text file
            elif re.search(text_regexp, buffer):
                self.matches = self._path_completions(text, state, TEXT_EXTENSIONS)
            # This is the first time for this text, so build a match list.
            elif text:
                self.matches = [s for s in self.options if s and s.startswith(text)]
            else:
                self.matches = self.options[:]

        # Return the state'th item from the match list,
        # if we have that many.
        try:
            return self.matches[state]
        except IndexError:
            return None

    def complete_extensions(self, extensions: list):
        """
        If called with a list of extensions, will force completer
        to do file path completions.  Pass None to return to the
        normal context-sensitive completion.
        """
        self.extensions = extensions

    def add_history(self, line):
        """
        Pass thru to readline.  Does nothing while readline's automatic
        history is active, so that lines are not recorded twice.
        """
        if not self.auto_history_active:
            readline.add_history(line)

    def clear_history(self):
        """
        Pass clear_history() thru to readline
        """
        readline.clear_history()

    def search_history(self, match: str):
        """
        Like show_history() but only shows items that
        contain the match string.
        """
        self.show_history(match)

    def remove_history_item(self, pos):
        """
        Pass remove_history_item() thru to readline
        """
        readline.remove_history_item(pos)

    def add_seed(self, seed):
        """
        Add a seed to the autocomplete list for display when -S is autocompleted.
        A seed of None is ignored.
        """
        if seed is not None:
            self.seeds.add(str(seed))

    def set_default_dir(self, path):
        """
        Set the directory whose contents are also offered when completing
        a bare file name (see _path_completions).
        """
        self.default_dir = path

    def set_options(self, options):
        """
        Replace the list of completable commands.  Stored sorted, as in
        the constructor.
        """
        self.options = sorted(options)

    def get_line(self, index):
        """
        Return the history item at `index`, or None if the lookup
        raises IndexError.
        """
        try:
            return self.get_history_item(index)
        except IndexError:
            return None

    def get_current_history_length(self):
        """
        Pass get_current_history_length() thru to readline
        """
        return readline.get_current_history_length()

    def get_history_item(self, index):
        """
        Pass get_history_item() thru to readline
        """
        return readline.get_history_item(index)

    def show_history(self, match=None):
        """
        Print the session history using the pydoc pager.
        If `match` is given, only lines containing it are shown.
        """
        import pydoc

        h_len = self.get_current_history_length()
        if h_len < 1:
            print("<empty history>")
            return

        lines = list()
        # history items are fetched with indexes 1..h_len
        for number in range(1, h_len + 1):
            line = self.get_history_item(number)
            if match and match not in line:
                continue
            lines.append(f"[{number}] {line}")
        pydoc.pager("\n".join(lines))

    def set_line(self, line) -> None:
        """
        Set the default string displayed in the next line of input.
        """
        self.linebuffer = line
        readline.redisplay()

    def update_models(self, models: dict) -> None:
        """
        update our list of models
        """
        self.models = models

    def _seed_completions(self, text, state):
        """
        Return the sorted remembered seeds that start with the digits in
        `text`, each prefixed with the -S/--seed switch if `text` has one.
        """
        m = re.search(r"(-S\s?|--seed[=\s]?)(\d*)", text)
        if m:
            switch, partial = m.groups()
        else:
            switch, partial = "", text
        return sorted(switch + s for s in self.seeds if s.startswith(partial))

    def add_embedding_terms(self, terms: list[str]):
        """
        Replace the set of embedding terms with `terms`, plus the concepts
        of the cached concepts library if one has been created.
        """
        self.embedding_terms = set(terms)
        if self.concepts:
            self.embedding_terms.update(set(self.concepts.list_concepts()))

    def _concept_completions(self, text, state):
        """
        Return the sorted embedding terms matching `text` (which starts
        with '<'), each formatted as "<term>".  A bare "<" lists them all.
        """
        if self.concepts is None:
            # cache Concepts() instance so we can check for updates in concepts_list during runtime.
            self.concepts = HuggingFaceConceptsLibrary()
        self.embedding_terms.update(set(self.concepts.list_concepts()))

        partial = text[1:]  # this removes the leading '<'
        # An empty partial matches every term, so a bare "<" yields the
        # whole list in the same "<term>" form as a partial match.
        return sorted(
            f"<{concept}>"
            for concept in self.embedding_terms
            if concept.startswith(partial)
        )

    def _model_completions(self, text, state, ckpt_only=False):
        """
        Return the sorted model names that start with `text`.  Models whose
        format is "vae" are never offered; with ckpt_only=True only models
        whose format is "ckpt" are offered.
        """
        m = re.search(r"(!switch\s+)(\w*)", text)
        if m:
            switch, partial = m.groups()
        else:
            switch, partial = "", text

        matches = list()
        for name in self.models:
            model_format = self.models[name]["format"]
            if model_format == "vae":
                continue
            if ckpt_only and model_format != "ckpt":
                continue
            if name.startswith(partial):
                matches.append(switch + name)
        matches.sort()
        return matches

    def _pre_input_hook(self):
        """
        readline pre-input hook: insert the text stored by set_line(), if
        any, into the new input line, then forget it.
        """
        if self.linebuffer:
            readline.insert_text(self.linebuffer)
            readline.redisplay()
            self.linebuffer = None

    def _path_completions(
        self, text, state, extensions, shortcut_ok=True, default_dir: str = ""
    ):
        """
        Return the directories, and the files ending in one of `extensions`,
        that complete the partial path in `text`.

        text:        word being completed, optionally prefixed by a switch
                     such as "-I" or "--init_img=", which is kept on each match.
        extensions:  a suffix string or any iterable of suffix strings.
        shortcut_ok: if true and the partial path names no directory, the
                     entries of self.default_dir are offered as well.
        default_dir: directory to search when the partial path names none.

        Directories are returned with a trailing "/".  An unreadable or
        missing directory yields no matches rather than an exception.
        """
        # str.endswith() only accepts a str or a tuple of str
        if isinstance(extensions, str):
            extensions = (extensions,)
        else:
            extensions = tuple(extensions)

        # separate the switch from the partial path
        switch_match = re.search(r"^(-\w|--\w+=?)(.*)", text)
        if switch_match is None:
            switch = None
            partial_path = text
        else:
            switch, partial_path = switch_match.groups()
        partial_path = partial_path.lstrip()

        path = os.path.expanduser(partial_path)
        if os.path.isdir(path):
            search_dir = path
        elif os.path.dirname(path) != "":
            search_dir = os.path.dirname(path)
        else:
            # default_dir may be None or empty; os.path.exists(None) raises
            if default_dir and os.path.exists(default_dir):
                search_dir = default_dir
            else:
                search_dir = ""
            path = os.path.join(search_dir, path)

        try:
            dir_list = os.listdir(search_dir or ".")
        except OSError:
            # e.g. the user typed a directory that does not exist
            return []

        # self.default_dir starts out as None, so check it before use
        if (
            shortcut_ok
            and search_dir == ""
            and self.default_dir
            and os.path.isdir(self.default_dir)
        ):
            try:
                dir_list += os.listdir(self.default_dir)
            except OSError:
                pass

        matches = list()
        for node in dir_list:
            # skip hidden entries, but not "." itself
            if node.startswith(".") and len(node) > 1:
                continue
            full_path = os.path.join(search_dir, node)
            is_dir = os.path.isdir(full_path)
            if not (is_dir or node.endswith(extensions)):
                continue
            if path and not full_path.startswith(path):
                continue

            if switch is None:
                matches.append(full_path + "/" if is_dir else full_path)
            else:
                completed = switch + os.path.join(os.path.dirname(full_path), node)
                matches.append(completed + "/" if is_dir else completed)
        return matches
class DummyCompleter(Completer):
    """
    Completer used when readline is unavailable: history is kept in a
    plain Python list instead of being delegated to readline.
    """

    def __init__(self, options):
        super().__init__(options)
        # in-memory history, oldest entry first
        self.history = []

    def add_history(self, line):
        """Append `line` to the in-memory history."""
        self.history.append(line)

    def clear_history(self):
        """Discard the in-memory history."""
        self.history = []

    def get_current_history_length(self):
        """Return the number of stored history lines."""
        return len(self.history)

    def get_history_item(self, index):
        """Return the history line at `index`, counting from 1."""
        position = index - 1
        return self.history[position]

    def remove_history_item(self, index):
        """Remove and return the history line at `index`, counting from 1."""
        position = index - 1
        return self.history.pop(position)

    def set_line(self, line):
        """Print `line` as a comment, since no input line can be pre-filled."""
        print(f"# {line}")
def generic_completer(commands: list) -> Completer:
    """
    Build a completer for an arbitrary list of commands.

    With readline available, a Completer is created and installed as the
    readline completer with tab completion bound; otherwise a
    DummyCompleter is returned.
    """
    if not readline_available:
        return DummyCompleter(commands)

    completer = Completer(commands, [])
    readline.set_completer(completer.complete)
    readline.set_pre_input_hook(completer._pre_input_hook)
    # split words on spaces only, so switches and paths stay intact
    readline.set_completer_delims(" ")
    init_lines = (
        "tab: complete",
        "set print-completions-horizontally off",
        "set page-completions on",
        "set skip-completed-text on",
        "set show-all-if-ambiguous on",
    )
    for init_line in init_lines:
        readline.parse_and_bind(init_line)
    return completer
def get_completer(opt: Args, models=None) -> Completer:
    """
    Build the completer for the invoke command line.

    opt:    parsed arguments; only `opt.outdir` is read here, to locate
            the history file ".invoke_history".
    models: models to offer for completion (defaults to none).

    With readline available, a Completer is installed as the readline
    completer, the history file is loaded and registered to be written
    back at interpreter exit.  Otherwise a DummyCompleter is returned.
    """
    if models is None:
        models = []

    if not readline_available:
        return DummyCompleter(COMMANDS)

    completer = Completer(COMMANDS, models)
    readline.set_completer(completer.complete)

    # pyreadline3 does not have a set_auto_history() method
    try:
        readline.set_auto_history(False)
        completer.auto_history_active = False
    except AttributeError:
        completer.auto_history_active = True

    readline.set_pre_input_hook(completer._pre_input_hook)
    readline.set_completer_delims(" ")
    readline.parse_and_bind("tab: complete")
    readline.parse_and_bind("set print-completions-horizontally off")
    readline.parse_and_bind("set page-completions on")
    readline.parse_and_bind("set skip-completed-text on")
    readline.parse_and_bind("set show-all-if-ambiguous on")

    # a relative outdir is resolved against the InvokeAI root directory
    outdir = os.path.expanduser(opt.outdir)
    if os.path.isabs(outdir):
        histfile = os.path.join(outdir, ".invoke_history")
    else:
        histfile = os.path.join(Globals.root, outdir, ".invoke_history")

    try:
        readline.read_history_file(histfile)
    except FileNotFoundError:
        # first run: there is no history yet
        pass
    except OSError:  # file likely corrupted
        newname = f"{histfile}.old"
        print(
            f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
        )
        os.replace(histfile, newname)

    # Cap the history whether or not a file was loaded, so the file written
    # at exit is bounded on a first run too.
    readline.set_history_length(1000)
    atexit.register(readline.write_history_file, histfile)
    return completer