resolve numerous small merge bugs

- This merges PR #882

Coauthor: ArDiouscuros
This commit is contained in:
Lincoln Stein 2022-10-21 12:54:13 -04:00
parent 55db9dba0a
commit c9f9eed04e
2 changed files with 82 additions and 19 deletions

View File

@ -22,6 +22,7 @@ except (ImportError,ModuleNotFoundError):
IMG_EXTENSIONS = ('.png','.jpg','.jpeg','.PNG','.JPG','.JPEG','.gif','.GIF') IMG_EXTENSIONS = ('.png','.jpg','.jpeg','.PNG','.JPG','.JPEG','.gif','.GIF')
WEIGHT_EXTENSIONS = ('.ckpt','.bae') WEIGHT_EXTENSIONS = ('.ckpt','.bae')
TEXT_EXTENSIONS = ('.txt','.TXT')
CONFIG_EXTENSIONS = ('.yaml','.yml') CONFIG_EXTENSIONS = ('.yaml','.yml')
COMMANDS = ( COMMANDS = (
'--steps','-s', '--steps','-s',
@ -69,6 +70,9 @@ WEIGHT_COMMANDS = (
IMG_PATH_COMMANDS = ( IMG_PATH_COMMANDS = (
'--outdir[=\s]', '--outdir[=\s]',
) )
TEXT_PATH_COMMANDS=(
'!replay',
)
IMG_FILE_COMMANDS=( IMG_FILE_COMMANDS=(
'!fix', '!fix',
'!fetch', '!fetch',
@ -78,8 +82,9 @@ IMG_FILE_COMMANDS=(
'--init_color[=\s]', '--init_color[=\s]',
'--embedding_path[=\s]', '--embedding_path[=\s]',
) )
path_regexp = '('+'|'.join(IMG_PATH_COMMANDS+IMG_FILE_COMMANDS) + ')\s*\S*$' path_regexp = '(' + '|'.join(IMG_PATH_COMMANDS+IMG_FILE_COMMANDS) + ')\s*\S*$'
weight_regexp = '('+'|'.join(WEIGHT_COMMANDS) + ')\s*\S*$' weight_regexp = '(' + '|'.join(WEIGHT_COMMANDS) + ')\s*\S*$'
text_regexp = '(' + '|'.join(TEXT_PATH_COMMANDS) + ')\s*\S*$'
class Completer(object): class Completer(object):
def __init__(self, options, models=[]): def __init__(self, options, models=[]):
@ -122,6 +127,9 @@ class Completer(object):
elif re.search(weight_regexp,buffer): elif re.search(weight_regexp,buffer):
self.matches = self._path_completions(text, state, WEIGHT_EXTENSIONS) self.matches = self._path_completions(text, state, WEIGHT_EXTENSIONS)
elif re.search(text_regexp,buffer):
self.matches = self._path_completions(text, state, TEXT_EXTENSIONS)
# This is the first time for this text, so build a match list. # This is the first time for this text, so build a match list.
elif text: elif text:
self.matches = [ self.matches = [

View File

@ -17,9 +17,15 @@ from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
from ldm.invoke.image_util import make_grid from ldm.invoke.image_util import make_grid
from ldm.invoke.log import write_log from ldm.invoke.log import write_log
from omegaconf import OmegaConf from omegaconf import OmegaConf
from pathlib import Path
# global used in multiple functions (fix)
infile = None
def main(): def main():
"""Initialize command-line parsers and the diffusion model""" """Initialize command-line parsers and the diffusion model"""
global infile
opt = Args() opt = Args()
args = opt.parse_args() args = opt.parse_args()
if not args: if not args:
@ -48,7 +54,6 @@ def main():
os.makedirs(opt.outdir) os.makedirs(opt.outdir)
# load the infile as a list of lines # load the infile as a list of lines
infile = None
if opt.infile: if opt.infile:
try: try:
if os.path.isfile(opt.infile): if os.path.isfile(opt.infile):
@ -96,14 +101,16 @@ def main():
) )
try: try:
main_loop(gen, opt, infile) main_loop(gen, opt)
except KeyboardInterrupt: except KeyboardInterrupt:
print("\ngoodbye!") print("\ngoodbye!")
# TODO: main_loop() has gotten busy. Needs to be refactored. # TODO: main_loop() has gotten busy. Needs to be refactored.
def main_loop(gen, opt, infile): def main_loop(gen, opt):
"""prompt/read/execute loop""" """prompt/read/execute loop"""
global infile
done = False done = False
doneAfterInFile = infile is not None
path_filter = re.compile(r'[<>:"/\\|?*]') path_filter = re.compile(r'[<>:"/\\|?*]')
last_results = list() last_results = list()
model_config = OmegaConf.load(opt.conf) model_config = OmegaConf.load(opt.conf)
@ -130,7 +137,8 @@ def main_loop(gen, opt, infile):
try: try:
command = get_next_command(infile) command = get_next_command(infile)
except EOFError: except EOFError:
done = True done = infile is None or doneAfterInFile
infile = None
continue continue
# skip empty lines # skip empty lines
@ -368,7 +376,10 @@ def main_loop(gen, opt, infile):
print('goodbye!') print('goodbye!')
# TO DO: remove repetitive code and the awkward command.replace() trope
# Just do a simple parse of the command!
def do_command(command:str, gen, opt:Args, completer) -> tuple: def do_command(command:str, gen, opt:Args, completer) -> tuple:
global infile
operation = 'generate' # default operation, alternative is 'postprocess' operation = 'generate' # default operation, alternative is 'postprocess'
if command.startswith('!dream'): # in case a stored prompt still contains the !dream command if command.startswith('!dream'): # in case a stored prompt still contains the !dream command
@ -414,8 +425,16 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
operation = None operation = None
elif command.startswith('!fetch'): elif command.startswith('!fetch'):
file_path = command.replace('!fetch ','',1) file_path = command.replace('!fetch','',1).strip()
retrieve_dream_command(opt,file_path,completer) retrieve_dream_command(opt,file_path,completer)
completer.add_history(command)
operation = None
elif command.startswith('!replay'):
file_path = command.replace('!replay','',1).strip()
if infile is None and os.path.isfile(file_path):
infile = open(file_path, 'r', encoding='utf-8')
completer.add_history(command)
operation = None operation = None
elif command.startswith('!history'): elif command.startswith('!history'):
@ -423,7 +442,7 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
operation = None operation = None
elif command.startswith('!search'): elif command.startswith('!search'):
search_str = command.replace('!search ','',1) search_str = command.replace('!search','',1).strip()
completer.show_history(search_str) completer.show_history(search_str)
operation = None operation = None
@ -723,27 +742,63 @@ def make_step_callback(gen, opt, prefix):
image.save(filename,'PNG') image.save(filename,'PNG')
return callback return callback
def retrieve_dream_command(opt,file_path,completer): def retrieve_dream_command(opt,command,completer):
''' '''
Given a full or partial path to a previously-generated image file, Given a full or partial path to a previously-generated image file,
will retrieve and format the dream command used to generate the image, will retrieve and format the dream command used to generate the image,
and pop it into the readline buffer (linux, Mac), or print out a comment and pop it into the readline buffer (linux, Mac), or print out a comment
for cut-and-paste (windows) for cut-and-paste (windows)
Given a wildcard path to a folder with image png files,
will retrieve and format the dream command used to generate the images,
and save them to a file commands.txt for further processing
''' '''
if len(command) == 0:
return
tokens = command.split()
if len(tokens) > 1:
outfilepath = tokens[1]
else:
outfilepath = "commands.txt"
file_path = tokens[0]
dir,basename = os.path.split(file_path) dir,basename = os.path.split(file_path)
if len(dir) == 0: if len(dir) == 0:
path = os.path.join(opt.outdir,basename) dir = opt.outdir
else:
path = file_path outdir,outname = os.path.split(outfilepath)
if len(outdir) == 0:
outfilepath = os.path.join(dir,outname)
try: try:
cmd = dream_cmd_from_png(path) paths = list(Path(dir).glob(basename))
except OSError: except ValueError:
print(f'** {path}: file could not be read') print(f'## "{basename}": unacceptable pattern')
return return
except (KeyError, AttributeError):
print(f'** {path}: file has no metadata') commands = []
return for path in paths:
completer.set_line(cmd) try:
cmd = dream_cmd_from_png(path)
except OSError:
print(f'## {path}: file could not be read')
continue
except (KeyError, AttributeError, IndexError):
print(f'## {path}: file has no metadata')
continue
except:
print(f'## {path}: file could not be processed')
continue
commands.append(f'# {path}')
commands.append(cmd)
with open(outfilepath, 'w', encoding='utf-8') as f:
f.write('\n'.join(commands))
print(f'>> File {outfilepath} with commands created')
if len(commands) == 2:
completer.set_line(commands[1])
######################################
if __name__ == '__main__': if __name__ == '__main__':
main() main()