Allow ctrl c when using --from_file (#472)

* added ansi escapes to highlight key parts of CLI session

* adjust exception handling so that ^C will abort when reading prompts from a file
This commit is contained in:
Lincoln Stein
2022-09-09 18:49:51 -04:00
committed by GitHub
parent 75f633cda8
commit 723d074442
2 changed files with 21 additions and 6 deletions

View File

@ -15,6 +15,11 @@ from ldm.dream.server import DreamServer, ThreadingDreamServer
from ldm.dream.image_util import make_grid
from omegaconf import OmegaConf
# Placeholder to be replaced with proper class that tracks the
# outputs and associates with the prompt that generated them.
# Just want to get the formatting look right for now.
output_cntr = 0
def main():
"""Initialize command-line parsers and the diffusion model"""
arg_parser = create_argv_parser()
@ -63,7 +68,8 @@ def main():
# this is solely for recreating the prompt
seamless = opt.seamless,
embedding_path = opt.embedding_path,
device_type = opt.device
device_type = opt.device,
ignore_ctrl_c = opt.infile is None,
)
# make sure the output directory exists
@ -292,16 +298,18 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
print(e)
continue
print('Outputs:')
print('\033[1mOutputs:\033[0m')
log_path = os.path.join(current_outdir, 'dream_log.txt')
write_log_message(results, log_path)
print('goodbye!')
print('goodbye!\033[0m')
def get_next_command(infile=None) -> str: #command string
if infile is None:
command = input('dream> ')
print('\033[1m') # add some boldface
command = input('dream> ')
print('\033[0m',end='')
else:
command = infile.readline()
if not command:
@ -339,8 +347,11 @@ def dream_server_loop(t2i, host, port, outdir):
def write_log_message(results, log_path):
"""logs the name of the output image, prompt, and prompt args to the terminal and log file"""
global output_cntr
log_lines = [f'{path}: {prompt}\n' for path, prompt in results]
print(*log_lines, sep='')
for l in log_lines:
output_cntr += 1
print(f'\033[1m[{output_cntr}]\033[0m {l}',end='')
with open(log_path, 'a', encoding='utf-8') as file:
file.writelines(log_lines)