""" Readline helper functions for invoke.py. You may import the global singleton `completer` to get access to the completer object itself. This is useful when you want to autocomplete seeds: from invokeai.frontend.CLI.readline import completer completer.add_seed(18247566) completer.add_seed(9281839) """ import atexit import os import re from ...backend.args import Args from ...backend.globals import Globals from ...backend.stable_diffusion import HuggingFaceConceptsLibrary # ---------------readline utilities--------------------- try: import readline readline_available = True except (ImportError, ModuleNotFoundError) as e: print(f"** An error occurred when loading the readline module: {str(e)}") readline_available = False IMG_EXTENSIONS = (".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG", ".gif", ".GIF") WEIGHT_EXTENSIONS = (".ckpt", ".vae", ".safetensors") TEXT_EXTENSIONS = (".txt", ".TXT") CONFIG_EXTENSIONS = (".yaml", ".yml") COMMANDS = ( "--steps", "-s", "--seed", "-S", "--iterations", "-n", "--width", "-W", "--height", "-H", "--cfg_scale", "-C", "--threshold", "--perlin", "--grid", "-g", "--individual", "-i", "--save_intermediates", "--init_img", "-I", "--init_mask", "-M", "--init_color", "--strength", "-f", "--variants", "-v", "--outdir", "-o", "--sampler", "-A", "-m", "--embedding_path", "--device", "--grid", "-g", "--facetool", "-ft", "--facetool_strength", "-G", "--codeformer_fidelity", "-cf", "--upscale", "-U", "-save_orig", "--save_original", "--log_tokenization", "-t", "--hires_fix", "--inpaint_replace", "-r", "--png_compression", "-z", "--text_mask", "-tm", "--h_symmetry_time_pct", "--v_symmetry_time_pct", "!fix", "!fetch", "!replay", "!history", "!search", "!clear", "!models", "!switch", "!import_model", "!optimize_model", "!convert_model", "!edit_model", "!del_model", "!mask", "!triggers", ) MODEL_COMMANDS = ( "!switch", "!edit_model", "!del_model", ) CKPT_MODEL_COMMANDS = ("!optimize_model",) WEIGHT_COMMANDS = ( "!import_model", "!convert_model", ) IMG_PATH_COMMANDS = ("--outdir[=\s]",) TEXT_PATH_COMMANDS = ("!replay",) IMG_FILE_COMMANDS = ( "!fix", "!fetch", "!mask", "--init_img[=\s]", "-I", "--init_mask[=\s]", "-M", "--init_color[=\s]", "--embedding_path[=\s]", ) path_regexp = "(" + "|".join(IMG_PATH_COMMANDS + IMG_FILE_COMMANDS) + ")\s*\S*$" weight_regexp = "(" + "|".join(WEIGHT_COMMANDS) + ")\s*\S*$" text_regexp = "(" + "|".join(TEXT_PATH_COMMANDS) + ")\s*\S*$" class Completer(object): def __init__(self, options, models={}): self.options = sorted(options) self.models = models self.seeds = set() self.matches = list() self.default_dir = None self.linebuffer = None self.auto_history_active = True self.extensions = None self.concepts = None self.embedding_terms = set() return def complete(self, text, state): """ Completes invoke command line. BUG: it doesn't correctly complete files that have spaces in the name. """ buffer = readline.get_line_buffer() if state == 0: # extensions defined, so go directly into path completion mode if self.extensions is not None: self.matches = self._path_completions(text, state, self.extensions) # looking for an image file elif re.search(path_regexp, buffer): do_shortcut = re.search("^" + "|".join(IMG_FILE_COMMANDS), buffer) self.matches = self._path_completions( text, state, IMG_EXTENSIONS, shortcut_ok=do_shortcut ) # looking for a seed elif re.search("(-S\s*|--seed[=\s])\d*$", buffer): self.matches = self._seed_completions(text, state) # looking for an embedding concept elif re.search("<[\w-]*$", buffer): self.matches = self._concept_completions(text, state) # looking for a model elif re.match("^" + "|".join(MODEL_COMMANDS), buffer): self.matches = self._model_completions(text, state) # looking for a ckpt model elif re.match("^" + "|".join(CKPT_MODEL_COMMANDS), buffer): self.matches = self._model_completions(text, state, ckpt_only=True) elif re.search(weight_regexp, buffer): self.matches = self._path_completions( text, state, WEIGHT_EXTENSIONS, default_dir=Globals.root, ) elif re.search(text_regexp, buffer): self.matches = self._path_completions(text, state, TEXT_EXTENSIONS) # This is the first time for this text, so build a match list. elif text: self.matches = [s for s in self.options if s and s.startswith(text)] else: self.matches = self.options[:] # Return the state'th item from the match list, # if we have that many. try: response = self.matches[state] except IndexError: response = None return response def complete_extensions(self, extensions: list): """ If called with a list of extensions, will force completer to do file path completions. """ self.extensions = extensions def add_history(self, line): """ Pass thru to readline """ if not self.auto_history_active: readline.add_history(line) def clear_history(self): """ Pass clear_history() thru to readline """ readline.clear_history() def search_history(self, match: str): """ Like show_history() but only shows items that contain the match string. """ self.show_history(match) def remove_history_item(self, pos): readline.remove_history_item(pos) def add_seed(self, seed): """ Add a seed to the autocomplete list for display when -S is autocompleted. """ if seed is not None: self.seeds.add(str(seed)) def set_default_dir(self, path): self.default_dir = path def set_options(self, options): self.options = options def get_line(self, index): try: line = self.get_history_item(index) except IndexError: return None return line def get_current_history_length(self): return readline.get_current_history_length() def get_history_item(self, index): return readline.get_history_item(index) def show_history(self, match=None): """ Print the session history using the pydoc pager """ import pydoc lines = list() h_len = self.get_current_history_length() if h_len < 1: print("") return for i in range(0, h_len): line = self.get_history_item(i + 1) if match and match not in line: continue lines.append(f"[{i+1}] {line}") pydoc.pager("\n".join(lines)) def set_line(self, line) -> None: """ Set the default string displayed in the next line of input. """ self.linebuffer = line readline.redisplay() def update_models(self, models: dict) -> None: """ update our list of models """ self.models = models def _seed_completions(self, text, state): m = re.search("(-S\s?|--seed[=\s]?)(\d*)", text) if m: switch = m.groups()[0] partial = m.groups()[1] else: switch = "" partial = text matches = list() for s in self.seeds: if s.startswith(partial): matches.append(switch + s) matches.sort() return matches def add_embedding_terms(self, terms: list[str]): self.embedding_terms = set(terms) if self.concepts: self.embedding_terms.update(set(self.concepts.list_concepts())) def _concept_completions(self, text, state): if self.concepts is None: # cache Concepts() instance so we can check for updates in concepts_list during runtime. self.concepts = HuggingFaceConceptsLibrary() self.embedding_terms.update(set(self.concepts.list_concepts())) else: self.embedding_terms.update(set(self.concepts.list_concepts())) partial = text[1:] # this removes the leading '<' if len(partial) == 0: return list(self.embedding_terms) # whole dump - think if user wants this! matches = list() for concept in self.embedding_terms: if concept.startswith(partial): matches.append(f"<{concept}>") matches.sort() return matches def _model_completions(self, text, state, ckpt_only=False): m = re.search("(!switch\s+)(\w*)", text) if m: switch = m.groups()[0] partial = m.groups()[1] else: switch = "" partial = text matches = list() for s in self.models: name = self.models[s]["model_name"] format = self.models[s]["format"] if format == "vae": continue if ckpt_only and format != "ckpt": continue if name.startswith(partial): matches.append(switch + name) matches.sort() return matches def _pre_input_hook(self): if self.linebuffer: readline.insert_text(self.linebuffer) readline.redisplay() self.linebuffer = None def _path_completions( self, text, state, extensions, shortcut_ok=True, default_dir: str = "" ): # separate the switch from the partial path match = re.search("^(-\w|--\w+=?)(.*)", text) if match is None: switch = None partial_path = text else: switch, partial_path = match.groups() partial_path = partial_path.lstrip() matches = list() path = os.path.expanduser(partial_path) if os.path.isdir(path): dir = path elif os.path.dirname(path) != "": dir = os.path.dirname(path) else: dir = default_dir if os.path.exists(default_dir) else "" path = os.path.join(dir, path) dir_list = os.listdir(dir or ".") if shortcut_ok and os.path.exists(self.default_dir) and dir == "": dir_list += os.listdir(self.default_dir) for node in dir_list: if node.startswith(".") and len(node) > 1: continue full_path = os.path.join(dir, node) if not (node.endswith(extensions) or os.path.isdir(full_path)): continue if path and not full_path.startswith(path): continue if switch is None: match_path = os.path.join(dir, node) matches.append( match_path + "/" if os.path.isdir(full_path) else match_path ) elif os.path.isdir(full_path): matches.append( switch + os.path.join(os.path.dirname(full_path), node) + "/" ) elif node.endswith(extensions): matches.append(switch + os.path.join(os.path.dirname(full_path), node)) return matches class DummyCompleter(Completer): def __init__(self, options): super().__init__(options) self.history = list() def add_history(self, line): self.history.append(line) def clear_history(self): self.history = list() def get_current_history_length(self): return len(self.history) def get_history_item(self, index): return self.history[index - 1] def remove_history_item(self, index): return self.history.pop(index - 1) def set_line(self, line): print(f"# {line}") def generic_completer(commands: list) -> Completer: if readline_available: completer = Completer(commands, []) readline.set_completer(completer.complete) readline.set_pre_input_hook(completer._pre_input_hook) readline.set_completer_delims(" ") readline.parse_and_bind("tab: complete") readline.parse_and_bind("set print-completions-horizontally off") readline.parse_and_bind("set page-completions on") readline.parse_and_bind("set skip-completed-text on") readline.parse_and_bind("set show-all-if-ambiguous on") else: completer = DummyCompleter(commands) return completer def get_completer(opt: Args, models=[]) -> Completer: if readline_available: completer = Completer(COMMANDS, models) readline.set_completer(completer.complete) # pyreadline3 does not have a set_auto_history() method try: readline.set_auto_history(False) completer.auto_history_active = False except: completer.auto_history_active = True readline.set_pre_input_hook(completer._pre_input_hook) readline.set_completer_delims(" ") readline.parse_and_bind("tab: complete") readline.parse_and_bind("set print-completions-horizontally off") readline.parse_and_bind("set page-completions on") readline.parse_and_bind("set skip-completed-text on") readline.parse_and_bind("set show-all-if-ambiguous on") outdir = os.path.expanduser(opt.outdir) if os.path.isabs(outdir): histfile = os.path.join(outdir, ".invoke_history") else: histfile = os.path.join(Globals.root, outdir, ".invoke_history") try: readline.read_history_file(histfile) readline.set_history_length(1000) except FileNotFoundError: pass except OSError: # file likely corrupted newname = f"{histfile}.old" print( f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}" ) os.replace(histfile, newname) atexit.register(readline.write_history_file, histfile) else: completer = DummyCompleter(COMMANDS) return completer