diff --git a/ldm/generate.py b/ldm/generate.py index a028cdf1ff..99d0a0b2c7 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -117,6 +117,7 @@ class Generate: seamless = False, embedding_path = None, device_type = 'cuda', + ignore_ctrl_c = False, ): self.iterations = iterations self.width = width @@ -134,6 +135,7 @@ class Generate: self.seamless = seamless self.embedding_path = embedding_path self.device_type = device_type + self.ignore_ctrl_c = ignore_ctrl_c # note, this logic probably doesn't belong here... self.model = None # empty for now self.sampler = None self.device = None @@ -212,7 +214,7 @@ class Generate: **args, ): # eat up additional cruft """ - ldm.prompt2image() is the common entry point for txt2img() and img2img() + ldm.generate.prompt2image() is the common entry point for txt2img() and img2img() It takes the following arguments: prompt // prompt string (no default) iterations // iterations (1); image count=iterations @@ -345,6 +347,8 @@ class Generate: except KeyboardInterrupt: print('*interrupted*') + if not self.ignore_ctrl_c: + raise KeyboardInterrupt print( '>> Partial results will be returned; if --grid was requested, nothing will be returned.' ) diff --git a/scripts/dream.py b/scripts/dream.py index e284c7cc4c..618f723c2f 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -15,6 +15,11 @@ from ldm.dream.server import DreamServer, ThreadingDreamServer from ldm.dream.image_util import make_grid from omegaconf import OmegaConf +# Placeholder to be replaced with proper class that tracks the +# outputs and associates with the prompt that generated them. +# Just want to get the formatting look right for now. +output_cntr = 0 + def main(): """Initialize command-line parsers and the diffusion model""" arg_parser = create_argv_parser() @@ -63,7 +68,8 @@ def main(): # this is solely for recreating the prompt seamless = opt.seamless, embedding_path = opt.embedding_path, - device_type = opt.device + device_type = opt.device, + ignore_ctrl_c = opt.infile is None, ) # make sure the output directory exists @@ -292,16 +298,18 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile): print(e) continue - print('Outputs:') + print('\033[1mOutputs:\033[0m') log_path = os.path.join(current_outdir, 'dream_log.txt') write_log_message(results, log_path) - print('goodbye!') + print('goodbye!\033[0m') def get_next_command(infile=None) -> str: #command string if infile is None: - command = input('dream> ') + print('\033[1m') # add some boldface + command = input('dream> ') + print('\033[0m',end='') else: command = infile.readline() if not command: @@ -339,8 +347,11 @@ def dream_server_loop(t2i, host, port, outdir): def write_log_message(results, log_path): """logs the name of the output image, prompt, and prompt args to the terminal and log file""" + global output_cntr log_lines = [f'{path}: {prompt}\n' for path, prompt in results] - print(*log_lines, sep='') + for l in log_lines: + output_cntr += 1 + print(f'\033[1m[{output_cntr}]\033[0m {l}',end='') with open(log_path, 'a', encoding='utf-8') as file: file.writelines(log_lines)