diff --git a/scripts/dream.py b/scripts/dream.py index 8175221cb3..ef7f6f3c56 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -64,13 +64,16 @@ def main(): # gets rid of annoying messages about random seed logging.getLogger('pytorch_lightning').setLevel(logging.ERROR) + # load the infile as a list of lines infile = None - try: - if opt.infile: - infile = open(opt.infile, 'r') - except FileNotFoundError as e: - print(e) - exit(-1) + if opt.infile: + if os.path.isfile(opt.infile): + with open(opt.infile, "r") as file: + infile = file.read() + infile = infile.split("\n") + else: + print(f"WARNING: '{opt.infile}' not found. Aborting.") + sys.exit(-1) # exit does not work on every os, sys.exit does afaik # preload the model t2i.load_model() @@ -119,8 +122,6 @@ def main(): cmd_parser = create_cmd_parser() main_loop(t2i, opt.outdir, cmd_parser, log, infile) log.close() - if infile: - infile.close() def main_loop(t2i, outdir, parser, log, infile): @@ -129,15 +130,19 @@ def main_loop(t2i, outdir, parser, log, infile): last_seeds = [] while not done: - try: - command = infile.readline() if infile else input('dream> ') - except EOFError: - done = True - break + if not infile: + command = input("dream> ") + else: + try: + # get the next line of the infile + command = infile.pop(0) + except IndexError: + done = True + break - if infile and len(command) == 0: - done = True - break + # skip empty lines + if not command.strip(): + continue if command.startswith(('#', '//')): continue @@ -152,9 +157,6 @@ def main_loop(t2i, outdir, parser, log, infile): print(str(e)) continue - if len(elements) == 0: - continue - if elements[0] == 'q': done = True break