implement history viewing & replaying in CLI

- Enhance tab completion functionality
- Each of the switches that read a filepath (e.g. --init_img) will trigger file path completion. The
  -S switch will display a list of recently-used seeds.
- Added new !fetch command to retrieve the metadata from a previously-generated image and populate the
  readline linebuffer with the appropriate editable command to regenerate.
- Added new !history command to display previous commands and reload them for modification.
- The !fetch and !fix commands both autocomplete *and* search automatically through the current
  outdir for files.
- The completer maintains a list of recently used seeds and will try to autocomplete them.
This commit is contained in:
Lincoln Stein 2022-09-27 14:27:55 -04:00
parent 36c9a7d39c
commit fe00a8c05c
3 changed files with 353 additions and 137 deletions

View File

@ -89,6 +89,7 @@ import os
import re import re
import copy import copy
import base64 import base64
import functools
import ldm.dream.pngwriter import ldm.dream.pngwriter
from ldm.dream.conditioning import split_weighted_subprompts from ldm.dream.conditioning import split_weighted_subprompts
@ -221,9 +222,15 @@ class Args(object):
# outpainting parameters # outpainting parameters
if a['out_direction']: if a['out_direction']:
switches.append(f'-D {" ".join([str(u) for u in a["out_direction"]])}') switches.append(f'-D {" ".join([str(u) for u in a["out_direction"]])}')
# LS: slight semantic drift which needs addressing in the future:
# 1. Variations come out of the stored metadata as a packed string with the keyword "variations"
# 2. However, they come out of the CLI (and probably web) with the keyword "with_variations" and
# in broken-out form. Variation (1) should be changed to comply with (2)
if a['with_variations']: if a['with_variations']:
formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in (a["with_variations"])) formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in (a["variations"]))
switches.append(f'-V {formatted_variations}') switches.append(f'-V {a["formatted_variations"]}')
if 'variations' in a:
switches.append(f'-V {a["variations"]}')
return ' '.join(switches) return ' '.join(switches)
def __getattribute__(self,name): def __getattribute__(self,name):
@ -750,6 +757,7 @@ def metadata_dumps(opt,
return metadata return metadata
@functools.lru_cache(maxsize=50)
def metadata_from_png(png_file_path) -> Args: def metadata_from_png(png_file_path) -> Args:
''' '''
Given the path to a PNG file created by dream.py, retrieves Given the path to a PNG file created by dream.py, retrieves
@ -762,7 +770,11 @@ def metadata_from_png(png_file_path) -> Args:
else: else:
return legacy_metadata_load(meta,png_file_path) return legacy_metadata_load(meta,png_file_path)
def metadata_loads(metadata) ->list: def dream_cmd_from_png(png_file_path):
opt = metadata_from_png(png_file_path)
return opt.dream_prompt_str()
def metadata_loads(metadata) -> list:
''' '''
Takes the dictionary corresponding to RFC266 (https://github.com/lstein/stable-diffusion/issues/266) Takes the dictionary corresponding to RFC266 (https://github.com/lstein/stable-diffusion/issues/266)
and returns a series of opt objects for each of the images described in the dictionary. Note that this and returns a series of opt objects for each of the images described in the dictionary. Note that this

View File

@ -1,38 +1,92 @@
""" """
Readline helper functions for dream.py (linux and mac only). Readline helper functions for dream.py (linux and mac only).
You may import the global singleton `completer` to get access to the
completer object itself. This is useful when you want to autocomplete
seeds:
from ldm.dream.readline import completer
completer.add_seed(18247566)
completer.add_seed(9281839)
""" """
import os import os
import re import re
import atexit import atexit
completer = None
# ---------------readline utilities--------------------- # ---------------readline utilities---------------------
try: try:
import readline import readline
readline_available = True readline_available = True
except: except:
readline_available = False readline_available = False
IMG_EXTENSIONS = ('.png','.jpg','.jpeg')
COMMANDS = (
'--steps','-s',
'--seed','-S',
'--iterations','-n',
'--width','-W','--height','-H',
'--cfg_scale','-C',
'--grid','-g',
'--individual','-i',
'--init_img','-I',
'--init_mask','-M',
'--init_color',
'--strength','-f',
'--variants','-v',
'--outdir','-o',
'--sampler','-A','-m',
'--embedding_path',
'--device',
'--grid','-g',
'--gfpgan_strength','-G',
'--upscale','-U',
'-save_orig','--save_original',
'--skip_normalize','-x',
'--log_tokenization','-t',
'!fix','!fetch',
)
IMG_PATH_COMMANDS = (
'--init_img[=\s]','-I',
'--init_mask[=\s]','-M',
'--init_color[=\s]',
'--embedding_path[=\s]',
'--outdir[=\s]'
)
IMG_FILE_COMMANDS=(
'!fix',
'!fetch',
)
path_regexp = '('+'|'.join(IMG_PATH_COMMANDS+IMG_FILE_COMMANDS) + ')\s*\S*$'
class Completer: class Completer:
def __init__(self, options): def __init__(self, options):
self.options = sorted(options) self.options = sorted(options)
self.seeds = set()
self.matches = list()
self.default_dir = None
self.linebuffer = None
return return
def complete(self, text, state): def complete(self, text, state):
'''
Completes dream command line.
BUG: it doesn't correctly complete files that have spaces in the name.
'''
buffer = readline.get_line_buffer() buffer = readline.get_line_buffer()
if text.startswith(('-I', '--init_img','-M','--init_mask',
'--init_color')):
return self._path_completions(text, state, ('.png','.jpg','.jpeg'))
if buffer.strip().endswith('pp') or text.startswith(('.', '/')):
return self._path_completions(text, state, ('.png','.jpg','.jpeg'))
response = None
if state == 0: if state == 0:
if re.search(path_regexp,buffer):
do_shortcut = re.search('^'+'|'.join(IMG_FILE_COMMANDS),buffer)
self.matches = self._path_completions(text, state, IMG_EXTENSIONS,shortcut_ok=do_shortcut)
# looking for a seed
elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
self.matches= self._seed_completions(text,state)
# This is the first time for this text, so build a match list. # This is the first time for this text, so build a match list.
if text: elif text:
self.matches = [ self.matches = [
s for s in self.options if s and s.startswith(text) s for s in self.options if s and s.startswith(text)
] ]
@ -47,81 +101,158 @@ class Completer:
response = None response = None
return response return response
def _path_completions(self, text, state, extensions): def add_to_history(self,line):
# get the path so far '''
# TODO: replace this mess with a regular expression match This is a no-op; readline handles this automatically. But we provide it
if text.startswith('-I'): for DummyReadline compatibility.
path = text.replace('-I', '', 1).lstrip() '''
elif text.startswith('--init_img='): pass
path = text.replace('--init_img=', '', 1).lstrip()
elif text.startswith('--init_mask='): def add_seed(self, seed):
path = text.replace('--init_mask=', '', 1).lstrip() '''
elif text.startswith('-M'): Add a seed to the autocomplete list for display when -S is autocompleted.
path = text.replace('-M', '', 1).lstrip() '''
elif text.startswith('--init_color='): if seed is not None:
path = text.replace('--init_color=', '', 1).lstrip() self.seeds.add(str(seed))
def set_default_dir(self, path):
self.default_dir=path
def get_line(self,index):
try:
line = self.get_history_item(index)
except IndexError:
return None
return line
def get_current_history_length(self):
return readline.get_current_history_length()
def get_history_item(self,index):
return readline.get_history_item(index)
def show_history(self):
'''
Print the session history using the pydoc pager
'''
import pydoc
lines = list()
h_len = self.get_current_history_length()
if h_len < 1:
print('<empty history>')
return
for i in range(0,h_len):
lines.append(f'[{i+1}] {self.get_history_item(i+1)}')
pydoc.pager('\n'.join(lines))
def set_line(self,line)->None:
self.linebuffer = line
readline.redisplay()
def _seed_completions(self, text, state):
m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
if m:
switch = m.groups()[0]
partial = m.groups()[1]
else: else:
path = text switch = ''
partial = text
matches = list() matches = list()
for s in self.seeds:
if s.startswith(partial):
matches.append(switch+s)
matches.sort()
return matches
path = os.path.expanduser(path) def _pre_input_hook(self):
if len(path) == 0: if self.linebuffer:
matches.append(text + './') readline.insert_text(self.linebuffer)
readline.redisplay()
self.linebuffer = None
def _path_completions(self, text, state, extensions, shortcut_ok=False):
# separate the switch from the partial path
match = re.search('^(-\w|--\w+=?)(.*)',text)
if match is None:
switch = None
partial_path = text
else: else:
switch,partial_path = match.groups()
partial_path = partial_path.lstrip()
matches = list()
path = os.path.expanduser(partial_path)
if os.path.isdir(path):
dir = path
elif os.path.dirname(path) != '':
dir = os.path.dirname(path) dir = os.path.dirname(path)
dir_list = os.listdir(dir) else:
for n in dir_list: dir = ''
if n.startswith('.') and len(n) > 1: path= os.path.join(dir,path)
continue
full_path = os.path.join(dir, n)
if full_path.startswith(path):
if os.path.isdir(full_path):
matches.append(
os.path.join(os.path.dirname(text), n) + '/'
)
elif n.endswith(extensions):
matches.append(os.path.join(os.path.dirname(text), n))
try: dir_list = os.listdir(dir or '.')
response = matches[state] if shortcut_ok and os.path.exists(self.default_dir) and dir=='':
except IndexError: dir_list += os.listdir(self.default_dir)
response = None
return response
for node in dir_list:
if node.startswith('.') and len(node) > 1:
continue
full_path = os.path.join(dir, node)
if not (node.endswith(extensions) or os.path.isdir(full_path)):
continue
if not full_path.startswith(path):
continue
if switch is None:
match_path = os.path.join(dir,node)
matches.append(match_path+'/' if os.path.isdir(full_path) else match_path)
elif os.path.isdir(full_path):
matches.append(
switch+os.path.join(os.path.dirname(full_path), node) + '/'
)
elif node.endswith(extensions):
matches.append(
switch+os.path.join(os.path.dirname(full_path), node)
)
return matches
class DummyCompleter(Completer):
def __init__(self,options):
super().__init__(options)
self.history = list()
def add_to_history(self,line):
self.history.append(line)
def get_current_history_length(self):
return len(self.history)
def get_history_item(self,index):
return self.history[index-1]
def set_line(self,line):
print(f'# {line}')
if readline_available: if readline_available:
completer = Completer(COMMANDS)
readline.set_completer( readline.set_completer(
Completer( completer.complete
[
'--steps','-s',
'--seed','-S',
'--iterations','-n',
'--width','-W','--height','-H',
'--cfg_scale','-C',
'--grid','-g',
'--individual','-i',
'--init_img','-I',
'--init_mask','-M',
'--init_color',
'--strength','-f',
'--variants','-v',
'--outdir','-o',
'--sampler','-A','-m',
'--embedding_path',
'--device',
'--grid','-g',
'--gfpgan_strength','-G',
'--upscale','-U',
'-save_orig','--save_original',
'--skip_normalize','-x',
'--log_tokenization','t',
]
).complete
) )
readline.set_pre_input_hook(completer._pre_input_hook)
readline.set_completer_delims(' ') readline.set_completer_delims(' ')
readline.parse_and_bind('tab: complete') readline.parse_and_bind('tab: complete')
readline.parse_and_bind('set print-completions-horizontally off')
readline.parse_and_bind('set page-completions on')
readline.parse_and_bind('set skip-completed-text on')
readline.parse_and_bind('set bell-style visible')
readline.parse_and_bind('set show-all-if-ambiguous on')
histfile = os.path.join(os.path.expanduser('~'), '.dream_history') histfile = os.path.join(os.path.expanduser('~'), '.dream_history')
try: try:
readline.read_history_file(histfile) readline.read_history_file(histfile)
@ -129,3 +260,6 @@ if readline_available:
except FileNotFoundError: except FileNotFoundError:
pass pass
atexit.register(readline.write_history_file, histfile) atexit.register(readline.write_history_file, histfile)
else:
completer = DummyCompleter(COMMANDS)

View File

@ -9,8 +9,8 @@ import copy
import warnings import warnings
import time import time
sys.path.append('.') # corrects a weird problem on Macs sys.path.append('.') # corrects a weird problem on Macs
import ldm.dream.readline from ldm.dream.readline import completer
from ldm.dream.args import Args, metadata_dumps, metadata_from_png from ldm.dream.args import Args, metadata_dumps, metadata_from_png, dream_cmd_from_png
from ldm.dream.pngwriter import PngWriter from ldm.dream.pngwriter import PngWriter
from ldm.dream.image_util import make_grid from ldm.dream.image_util import make_grid
from ldm.dream.log import write_log from ldm.dream.log import write_log
@ -142,7 +142,10 @@ def main_loop(gen, opt, infile):
while not done: while not done:
operation = 'generate' # default operation, alternative is 'postprocess' operation = 'generate' # default operation, alternative is 'postprocess'
if completer:
completer.set_default_dir(opt.outdir)
try: try:
command = get_next_command(infile) command = get_next_command(infile)
except EOFError: except EOFError:
@ -160,16 +163,28 @@ def main_loop(gen, opt, infile):
done = True done = True
break break
if command.startswith( if command.startswith('!dream'): # in case a stored prompt still contains the !dream command
'!dream'
): # in case a stored prompt still contains the !dream command
command = command.replace('!dream ','',1) command = command.replace('!dream ','',1)
if command.startswith( if command.startswith('!fix'):
'!fix'
):
command = command.replace('!fix ','',1) command = command.replace('!fix ','',1)
operation = 'postprocess' operation = 'postprocess'
if command.startswith('!fetch'):
file_path = command.replace('!fetch ','',1)
retrieve_dream_command(opt,file_path)
continue
if command == '!history':
completer.show_history()
continue
match = re.match('^!(\d+)',command)
if match:
command_no = match.groups()[0]
command = completer.get_line(int(command_no))
completer.set_line(command)
continue
if opt.parse_cmd(command) is None: if opt.parse_cmd(command) is None:
continue continue
@ -220,37 +235,15 @@ def main_loop(gen, opt, infile):
opt.strength = 0.75 if opt.out_direction is None else 0.83 opt.strength = 0.75 if opt.out_direction is None else 0.83
if opt.with_variations is not None: if opt.with_variations is not None:
# shotgun parsing, woo opt.with_variations = split_variations(opt.with_variations)
parts = []
broken = False # python doesn't have labeled loops...
for part in opt.with_variations.split(','):
seed_and_weight = part.split(':')
if len(seed_and_weight) != 2:
print(f'could not parse with_variation part "{part}"')
broken = True
break
try:
seed = int(seed_and_weight[0])
weight = float(seed_and_weight[1])
except ValueError:
print(f'could not parse with_variation part "{part}"')
broken = True
break
parts.append([seed, weight])
if broken:
continue
if len(parts) > 0:
opt.with_variations = parts
else:
opt.with_variations = None
if opt.prompt_as_dir: if opt.prompt_as_dir:
# sanitize the prompt to a valid folder name # sanitize the prompt to a valid folder name
subdir = path_filter.sub('_', opt.prompt)[:name_max].rstrip(' .') subdir = path_filter.sub('_', opt.prompt)[:name_max].rstrip(' .')
# truncate path to maximum allowed length # truncate path to maximum allowed length
# 27 is the length of '######.##########.##.png', plus two separators and a NUL # 39 is the length of '######.##########.##########-##.png', plus two separators and a NUL
subdir = subdir[:(path_max - 27 - len(os.path.abspath(opt.outdir)))] subdir = subdir[:(path_max - 39 - len(os.path.abspath(opt.outdir)))]
current_outdir = os.path.join(opt.outdir, subdir) current_outdir = os.path.join(opt.outdir, subdir)
print('Writing files to directory: "' + current_outdir + '"') print('Writing files to directory: "' + current_outdir + '"')
@ -281,23 +274,17 @@ def main_loop(gen, opt, infile):
if opt.grid: if opt.grid:
grid_images[seed] = image grid_images[seed] = image
else: else:
if operation == 'postprocess': postprocessed = upscaled if upscaled else operation=='postprocess'
filename = choose_postprocess_name(opt.prompt) filename, formatted_dream_prompt = prepare_image_metadata(
elif upscaled and opt.save_original: opt,
filename = f'{prefix}.{seed}.postprocessed.png' prefix,
else: seed,
filename = f'{prefix}.{seed}.png' operation,
if opt.variation_amount > 0: prior_variations,
first_seed = first_seed or seed postprocessed,
this_variation = [[seed, opt.variation_amount]] first_seed
opt.with_variations = prior_variations + this_variation )
formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed)
elif len(prior_variations) > 0:
formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed)
elif operation == 'postprocess':
formatted_dream_prompt = '!fix '+opt.dream_prompt_str(seed=seed)
else:
formatted_dream_prompt = opt.dream_prompt_str(seed=seed)
path = file_writer.save_image_and_prompt_to_png( path = file_writer.save_image_and_prompt_to_png(
image = image, image = image,
dream_prompt = formatted_dream_prompt, dream_prompt = formatted_dream_prompt,
@ -311,10 +298,15 @@ def main_loop(gen, opt, infile):
if (not upscaled) or opt.save_original: if (not upscaled) or opt.save_original:
# only append to results if we didn't overwrite an earlier output # only append to results if we didn't overwrite an earlier output
results.append([path, formatted_dream_prompt]) results.append([path, formatted_dream_prompt])
# so that the seed autocompletes (on linux|mac when -S or --seed specified
if completer:
completer.add_seed(seed)
completer.add_seed(first_seed)
last_results.append([path, seed]) last_results.append([path, seed])
if operation == 'generate': if operation == 'generate':
catch_ctrl_c = infile is None # if running interactively, we catch keyboard interrupts catch_ctrl_c = infile is None # if running interactively, we catch keyboard interrupts
opt.last_operation='generate'
gen.prompt2image( gen.prompt2image(
image_callback=image_writer, image_callback=image_writer,
catch_interrupts=catch_ctrl_c, catch_interrupts=catch_ctrl_c,
@ -322,7 +314,7 @@ def main_loop(gen, opt, infile):
) )
elif operation == 'postprocess': elif operation == 'postprocess':
print(f'>> fixing {opt.prompt}') print(f'>> fixing {opt.prompt}')
do_postprocess(gen,opt,image_writer) opt.last_operation = do_postprocess(gen,opt,image_writer)
if opt.grid and len(grid_images) > 0: if opt.grid and len(grid_images) > 0:
grid_img = make_grid(list(grid_images.values())) grid_img = make_grid(list(grid_images.values()))
@ -357,6 +349,7 @@ def main_loop(gen, opt, infile):
global output_cntr global output_cntr
output_cntr = write_log(results, log_path ,('txt', 'md'), output_cntr) output_cntr = write_log(results, log_path ,('txt', 'md'), output_cntr)
print() print()
completer.add_to_history(command)
print('goodbye!') print('goodbye!')
@ -378,7 +371,8 @@ def do_postprocess (gen, opt, callback):
elif opt.out_direction: elif opt.out_direction:
tool = 'outpaint' tool = 'outpaint'
opt.save_original = True # do not overwrite old image! opt.save_original = True # do not overwrite old image!
return gen.apply_postprocessor( opt.last_operation = f'postprocess:{tool}'
gen.apply_postprocessor(
image_path = opt.prompt, image_path = opt.prompt,
tool = tool, tool = tool,
gfpgan_strength = opt.gfpgan_strength, gfpgan_strength = opt.gfpgan_strength,
@ -389,18 +383,54 @@ def do_postprocess (gen, opt, callback):
callback = callback, callback = callback,
opt = opt, opt = opt,
) )
return opt.last_operation
def choose_postprocess_name(original_filename): def prepare_image_metadata(
basename,_ = os.path.splitext(os.path.basename(original_filename)) opt,
if re.search('\d+\.\d+$',basename): prefix,
return f'{basename}.fixed.png' seed,
match = re.search('(\d+\.\d+)\.fixed(-(\d+))?$',basename) operation='generate',
if match: prior_variations=[],
counter = match.group(3) or 0 postprocessed=False,
return '{prefix}-{counter:02d}.png'.format(prefix=match.group(1), counter=int(counter)+1) first_seed=None
):
if postprocessed and opt.save_original:
filename = choose_postprocess_name(opt,prefix,seed)
else: else:
return f'{basename}.fixed.png' filename = f'{prefix}.{seed}.png'
if opt.variation_amount > 0:
first_seed = first_seed or seed
this_variation = [[seed, opt.variation_amount]]
opt.with_variations = prior_variations + this_variation
formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed)
elif len(prior_variations) > 0:
formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed)
elif operation == 'postprocess':
formatted_dream_prompt = '!fix '+opt.dream_prompt_str(seed=seed)
else:
formatted_dream_prompt = opt.dream_prompt_str(seed=seed)
return filename,formatted_dream_prompt
def choose_postprocess_name(opt,prefix,seed) -> str:
match = re.search('postprocess:(\w+)',opt.last_operation)
if match:
modifier = match.group(1) # will look like "gfpgan", "upscale", "outpaint" or "embiggen"
else:
modifier = 'postprocessed'
counter = 0
filename = None
available = False
while not available:
if counter > 0:
filename = f'{prefix}.{seed}.{modifier}.png'
else:
filename = f'{prefix}.{seed}.{modifier}-{counter:02d}.png'
available = not os.path.exists(os.path.join(opt.outdir,filename))
counter += 1
return filename
def get_next_command(infile=None) -> str: # command string def get_next_command(infile=None) -> str: # command string
if infile is None: if infile is None:
@ -430,6 +460,46 @@ def invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan):
pass pass
def split_variations(variations_string) -> list:
# shotgun parsing, woo
parts = []
broken = False # python doesn't have labeled loops...
for part in variations_string.split(','):
seed_and_weight = part.split(':')
if len(seed_and_weight) != 2:
print(f'** Could not parse with_variation part "{part}"')
broken = True
break
try:
seed = int(seed_and_weight[0])
weight = float(seed_and_weight[1])
except ValueError:
print(f'** Could not parse with_variation part "{part}"')
broken = True
break
parts.append([seed, weight])
if broken:
return None
elif len(parts) == 0:
return None
else:
return parts
def retrieve_dream_command(opt,file_path):
'''
Given a full or partial path to a previously-generated image file,
will retrieve and format the dream command used to generate the image,
and pop it into the readline buffer (linux, Mac), or print out a comment
for cut-and-paste (windows)
'''
dir,basename = os.path.split(file_path)
if len(dir) == 0:
path = os.path.join(opt.outdir,basename)
else:
path = file_path
cmd = dream_cmd_from_png(path)
completer.set_line(cmd)
def write_log_message(results, log_path): def write_log_message(results, log_path):
"""logs the name of the output image, prompt, and prompt args to the terminal and log file""" """logs the name of the output image, prompt, and prompt args to the terminal and log file"""
global output_cntr global output_cntr