Allow ctrl c when using --from_file (#472)

* added ansi escapes to highlight key parts of CLI session

* adjust exception handling so that ^C will abort when reading prompts from a file
This commit is contained in:
Lincoln Stein 2022-09-09 18:49:51 -04:00 committed by GitHub
parent 75f633cda8
commit 723d074442
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 6 deletions

View File

@ -117,6 +117,7 @@ class Generate:
seamless = False, seamless = False,
embedding_path = None, embedding_path = None,
device_type = 'cuda', device_type = 'cuda',
ignore_ctrl_c = False,
): ):
self.iterations = iterations self.iterations = iterations
self.width = width self.width = width
@ -134,6 +135,7 @@ class Generate:
self.seamless = seamless self.seamless = seamless
self.embedding_path = embedding_path self.embedding_path = embedding_path
self.device_type = device_type self.device_type = device_type
self.ignore_ctrl_c = ignore_ctrl_c # note, this logic probably doesn't belong here...
self.model = None # empty for now self.model = None # empty for now
self.sampler = None self.sampler = None
self.device = None self.device = None
@ -210,7 +212,7 @@ class Generate:
**args, **args,
): # eat up additional cruft ): # eat up additional cruft
""" """
ldm.prompt2image() is the common entry point for txt2img() and img2img() ldm.generate.prompt2image() is the common entry point for txt2img() and img2img()
It takes the following arguments: It takes the following arguments:
prompt // prompt string (no default) prompt // prompt string (no default)
iterations // iterations (1); image count=iterations iterations // iterations (1); image count=iterations
@ -341,6 +343,8 @@ class Generate:
except KeyboardInterrupt: except KeyboardInterrupt:
print('*interrupted*') print('*interrupted*')
if not self.ignore_ctrl_c:
raise KeyboardInterrupt
print( print(
'>> Partial results will be returned; if --grid was requested, nothing will be returned.' '>> Partial results will be returned; if --grid was requested, nothing will be returned.'
) )

View File

@ -15,6 +15,11 @@ from ldm.dream.server import DreamServer, ThreadingDreamServer
from ldm.dream.image_util import make_grid from ldm.dream.image_util import make_grid
from omegaconf import OmegaConf from omegaconf import OmegaConf
# Placeholder to be replaced with proper class that tracks the
# outputs and associates with the prompt that generated them.
# Just want to get the formatting look right for now.
output_cntr = 0
def main(): def main():
"""Initialize command-line parsers and the diffusion model""" """Initialize command-line parsers and the diffusion model"""
arg_parser = create_argv_parser() arg_parser = create_argv_parser()
@ -63,7 +68,8 @@ def main():
# this is solely for recreating the prompt # this is solely for recreating the prompt
seamless = opt.seamless, seamless = opt.seamless,
embedding_path = opt.embedding_path, embedding_path = opt.embedding_path,
device_type = opt.device device_type = opt.device,
ignore_ctrl_c = opt.infile is None,
) )
# make sure the output directory exists # make sure the output directory exists
@ -292,16 +298,18 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
print(e) print(e)
continue continue
print('Outputs:') print('\033[1mOutputs:\033[0m')
log_path = os.path.join(current_outdir, 'dream_log.txt') log_path = os.path.join(current_outdir, 'dream_log.txt')
write_log_message(results, log_path) write_log_message(results, log_path)
print('goodbye!') print('goodbye!\033[0m')
def get_next_command(infile=None) -> str: #command string def get_next_command(infile=None) -> str: #command string
if infile is None: if infile is None:
print('\033[1m') # add some boldface
command = input('dream> ') command = input('dream> ')
print('\033[0m',end='')
else: else:
command = infile.readline() command = infile.readline()
if not command: if not command:
@ -339,8 +347,11 @@ def dream_server_loop(t2i, host, port, outdir):
def write_log_message(results, log_path): def write_log_message(results, log_path):
"""logs the name of the output image, prompt, and prompt args to the terminal and log file""" """logs the name of the output image, prompt, and prompt args to the terminal and log file"""
global output_cntr
log_lines = [f'{path}: {prompt}\n' for path, prompt in results] log_lines = [f'{path}: {prompt}\n' for path, prompt in results]
print(*log_lines, sep='') for l in log_lines:
output_cntr += 1
print(f'\033[1m[{output_cntr}]\033[0m {l}',end='')
with open(log_path, 'a', encoding='utf-8') as file: with open(log_path, 'a', encoding='utf-8') as file:
file.writelines(log_lines) file.writelines(log_lines)