Remove print statement styling (#504)

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
This commit is contained in:
blessedcoolant
2022-09-12 07:47:12 +12:00
committed by GitHub
parent 4951e66103
commit b86a1deb00

View File

@ -20,6 +20,7 @@ from omegaconf import OmegaConf
# Just want to get the formatting look right for now. # Just want to get the formatting look right for now.
output_cntr = 0 output_cntr = 0
def main(): def main():
"""Initialize command-line parsers and the diffusion model""" """Initialize command-line parsers and the diffusion model"""
arg_parser = create_argv_parser() arg_parser = create_argv_parser()
@ -58,18 +59,18 @@ def main():
# additional parameters will be added (or overriden) during # additional parameters will be added (or overriden) during
# the user input loop # the user input loop
t2i = Generate( t2i = Generate(
width = width, width=width,
height = height, height=height,
sampler_name = opt.sampler_name, sampler_name=opt.sampler_name,
weights = weights, weights=weights,
full_precision = opt.full_precision, full_precision=opt.full_precision,
config = config, config=config,
grid = opt.grid, grid=opt.grid,
# this is solely for recreating the prompt # this is solely for recreating the prompt
seamless = opt.seamless, seamless=opt.seamless,
embedding_path = opt.embedding_path, embedding_path=opt.embedding_path,
device_type = opt.device, device_type=opt.device,
ignore_ctrl_c = opt.infile is None, ignore_ctrl_c=opt.infile is None,
) )
# make sure the output directory exists # make sure the output directory exists
@ -183,12 +184,14 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
if len(opt.prompt) == 0: if len(opt.prompt) == 0:
print('Try again with a prompt!') print('Try again with a prompt!')
continue continue
if opt.init_img is not None and re.match('^-\\d+$',opt.init_img): # retrieve previous value! # retrieve previous value!
if opt.init_img is not None and re.match('^-\\d+$', opt.init_img):
try: try:
opt.init_img = last_results[int(opt.init_img)][0] opt.init_img = last_results[int(opt.init_img)][0]
print(f'>> Reusing previous image {opt.init_img}') print(f'>> Reusing previous image {opt.init_img}')
except IndexError: except IndexError:
print(f'>> No previous initial image at position {opt.init_img} found') print(
f'>> No previous initial image at position {opt.init_img} found')
opt.init_img = None opt.init_img = None
continue continue
@ -241,7 +244,7 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
subdir = subdir[:(path_max - 27 - len(os.path.abspath(outdir)))] subdir = subdir[:(path_max - 27 - len(os.path.abspath(outdir)))]
current_outdir = os.path.join(outdir, subdir) current_outdir = os.path.join(outdir, subdir)
print ('Writing files to directory: "' + current_outdir + '"') print('Writing files to directory: "' + current_outdir + '"')
# make sure the output directory exists # make sure the output directory exists
if not os.path.exists(current_outdir): if not os.path.exists(current_outdir):
@ -256,6 +259,7 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
prefix = file_writer.unique_prefix() prefix = file_writer.unique_prefix()
results = [] # list of filename, prompt pairs results = [] # list of filename, prompt pairs
grid_images = dict() # seed -> Image, only used if `do_grid` grid_images = dict() # seed -> Image, only used if `do_grid`
def image_writer(image, seed, upscaled=False): def image_writer(image, seed, upscaled=False):
if do_grid: if do_grid:
grid_images[seed] = image grid_images[seed] = image
@ -272,19 +276,24 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
else: else:
iter_opt.with_variations = opt.with_variations + this_variation iter_opt.with_variations = opt.with_variations + this_variation
iter_opt.variation_amount = 0 iter_opt.variation_amount = 0
normalized_prompt = PromptFormatter(t2i, iter_opt).normalize_prompt() normalized_prompt = PromptFormatter(
t2i, iter_opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{iter_opt.seed}' metadata_prompt = f'{normalized_prompt} -S{iter_opt.seed}'
elif opt.with_variations is not None: elif opt.with_variations is not None:
normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt() normalized_prompt = PromptFormatter(
metadata_prompt = f'{normalized_prompt} -S{opt.seed}' # use the original seed - the per-iteration value is the last variation-seed t2i, opt).normalize_prompt()
# use the original seed - the per-iteration value is the last variation-seed
metadata_prompt = f'{normalized_prompt} -S{opt.seed}'
else: else:
normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt() normalized_prompt = PromptFormatter(
t2i, opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{seed}' metadata_prompt = f'{normalized_prompt} -S{seed}'
path = file_writer.save_image_and_prompt_to_png(image, metadata_prompt, filename) path = file_writer.save_image_and_prompt_to_png(
image, metadata_prompt, filename)
if (not upscaled) or opt.save_original: if (not upscaled) or opt.save_original:
# only append to results if we didn't overwrite an earlier output # only append to results if we didn't overwrite an earlier output
results.append([path, metadata_prompt]) results.append([path, metadata_prompt])
last_results.append([path,seed]) last_results.append([path, seed])
t2i.prompt2image(image_callback=image_writer, **vars(opt)) t2i.prompt2image(image_callback=image_writer, **vars(opt))
@ -293,7 +302,8 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
first_seed = last_results[0][1] first_seed = last_results[0][1]
filename = f'{prefix}.{first_seed}.png' filename = f'{prefix}.{first_seed}.png'
# TODO better metadata for grid images # TODO better metadata for grid images
normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt() normalized_prompt = PromptFormatter(
t2i, opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{first_seed} --grid -N{len(grid_images)}' metadata_prompt = f'{normalized_prompt} -S{first_seed} --grid -N{len(grid_images)}'
path = file_writer.save_image_and_prompt_to_png( path = file_writer.save_image_and_prompt_to_png(
grid_img, metadata_prompt, filename grid_img, metadata_prompt, filename
@ -308,18 +318,16 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
print(e) print(e)
continue continue
print('\033[1mOutputs:\033[0m') print('Outputs:')
log_path = os.path.join(current_outdir, 'dream_log.txt') log_path = os.path.join(current_outdir, 'dream_log.txt')
write_log_message(results, log_path) write_log_message(results, log_path)
print('goodbye!\033[0m') print('goodbye!')
def get_next_command(infile=None) -> str: #command string def get_next_command(infile=None) -> str: # command string
if infile is None: if infile is None:
print('\033[1m') # add some boldface
command = input('dream> ') command = input('dream> ')
print('\033[0m',end='')
else: else:
command = infile.readline() command = infile.readline()
if not command: if not command:
@ -329,6 +337,7 @@ def get_next_command(infile=None) -> str: #command string
print(f'#{command}') print(f'#{command}')
return command return command
def dream_server_loop(t2i, host, port, outdir): def dream_server_loop(t2i, host, port, outdir):
print('\n* --web was specified, starting web server...') print('\n* --web was specified, starting web server...')
# Change working directory to the stable-diffusion directory # Change working directory to the stable-diffusion directory
@ -342,7 +351,8 @@ def dream_server_loop(t2i, host, port, outdir):
dream_server = ThreadingDreamServer((host, port)) dream_server = ThreadingDreamServer((host, port))
print(">> Started Stable Diffusion dream server!") print(">> Started Stable Diffusion dream server!")
if host == '0.0.0.0': if host == '0.0.0.0':
print(f"Point your browser at http://localhost:{port} or use the host's DNS name or IP address.") print(
f"Point your browser at http://localhost:{port} or use the host's DNS name or IP address.")
else: else:
print(">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address.") print(">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address.")
print(f">> Point your browser at http://{host}:{port}.") print(f">> Point your browser at http://{host}:{port}.")
@ -361,13 +371,13 @@ def write_log_message(results, log_path):
log_lines = [f'{path}: {prompt}\n' for path, prompt in results] log_lines = [f'{path}: {prompt}\n' for path, prompt in results]
for l in log_lines: for l in log_lines:
output_cntr += 1 output_cntr += 1
print(f'\033[1m[{output_cntr}]\033[0m {l}',end='') print(output_cntr)
with open(log_path, 'a', encoding='utf-8') as file: with open(log_path, 'a', encoding='utf-8') as file:
file.writelines(log_lines) file.writelines(log_lines)
SAMPLER_CHOICES=[ SAMPLER_CHOICES = [
'ddim', 'ddim',
'k_dpm_2_a', 'k_dpm_2_a',
'k_dpm_2', 'k_dpm_2',
@ -378,6 +388,7 @@ SAMPLER_CHOICES=[
'plms', 'plms',
] ]
def create_argv_parser(): def create_argv_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="""Generate images using Stable Diffusion. description="""Generate images using Stable Diffusion.
@ -518,8 +529,8 @@ def create_argv_parser():
) )
parser.add_argument( parser.add_argument(
'--config', '--config',
default ='configs/models.yaml', default='configs/models.yaml',
help ='Path to configuration file for alternate models.', help='Path to configuration file for alternate models.',
) )
return parser return parser