Remove print statement styling (#504)

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
This commit is contained in:
blessedcoolant 2022-09-12 07:47:12 +12:00 committed by GitHub
parent 4951e66103
commit b86a1deb00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -20,23 +20,24 @@ from omegaconf import OmegaConf
# Just want to get the formatting look right for now. # Just want to get the formatting look right for now.
output_cntr = 0 output_cntr = 0
def main(): def main():
"""Initialize command-line parsers and the diffusion model""" """Initialize command-line parsers and the diffusion model"""
arg_parser = create_argv_parser() arg_parser = create_argv_parser()
opt = arg_parser.parse_args() opt = arg_parser.parse_args()
if opt.laion400m: if opt.laion400m:
print('--laion400m flag has been deprecated. Please use --model laion400m instead.') print('--laion400m flag has been deprecated. Please use --model laion400m instead.')
sys.exit(-1) sys.exit(-1)
if opt.weights != 'model': if opt.weights != 'model':
print('--weights argument has been deprecated. Please configure ./configs/models.yaml, and call it using --model instead.') print('--weights argument has been deprecated. Please configure ./configs/models.yaml, and call it using --model instead.')
sys.exit(-1) sys.exit(-1)
try: try:
models = OmegaConf.load(opt.config) models = OmegaConf.load(opt.config)
width = models[opt.model].width width = models[opt.model].width
height = models[opt.model].height height = models[opt.model].height
config = models[opt.model].config config = models[opt.model].config
weights = models[opt.model].weights weights = models[opt.model].weights
except (FileNotFoundError, IOError, KeyError) as e: except (FileNotFoundError, IOError, KeyError) as e:
print(f'{e}. Aborting.') print(f'{e}. Aborting.')
@ -58,18 +59,18 @@ def main():
# additional parameters will be added (or overriden) during # additional parameters will be added (or overriden) during
# the user input loop # the user input loop
t2i = Generate( t2i = Generate(
width = width, width=width,
height = height, height=height,
sampler_name = opt.sampler_name, sampler_name=opt.sampler_name,
weights = weights, weights=weights,
full_precision = opt.full_precision, full_precision=opt.full_precision,
config = config, config=config,
grid = opt.grid, grid=opt.grid,
# this is solely for recreating the prompt # this is solely for recreating the prompt
seamless = opt.seamless, seamless=opt.seamless,
embedding_path = opt.embedding_path, embedding_path=opt.embedding_path,
device_type = opt.device, device_type=opt.device,
ignore_ctrl_c = opt.infile is None, ignore_ctrl_c=opt.infile is None,
) )
# make sure the output directory exists # make sure the output directory exists
@ -113,8 +114,8 @@ def main():
def main_loop(t2i, outdir, prompt_as_dir, parser, infile): def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
"""prompt/read/execute loop""" """prompt/read/execute loop"""
done = False done = False
path_filter = re.compile(r'[<>:"/\\|?*]') path_filter = re.compile(r'[<>:"/\\|?*]')
last_results = list() last_results = list()
# os.pathconf is not available on Windows # os.pathconf is not available on Windows
@ -134,7 +135,7 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
except KeyboardInterrupt: except KeyboardInterrupt:
done = True done = True
continue continue
# skip empty lines # skip empty lines
if not command.strip(): if not command.strip():
continue continue
@ -183,15 +184,17 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
if len(opt.prompt) == 0: if len(opt.prompt) == 0:
print('Try again with a prompt!') print('Try again with a prompt!')
continue continue
if opt.init_img is not None and re.match('^-\\d+$',opt.init_img): # retrieve previous value! # retrieve previous value!
if opt.init_img is not None and re.match('^-\\d+$', opt.init_img):
try: try:
opt.init_img = last_results[int(opt.init_img)][0] opt.init_img = last_results[int(opt.init_img)][0]
print(f'>> Reusing previous image {opt.init_img}') print(f'>> Reusing previous image {opt.init_img}')
except IndexError: except IndexError:
print(f'>> No previous initial image at position {opt.init_img} found') print(
f'>> No previous initial image at position {opt.init_img} found')
opt.init_img = None opt.init_img = None
continue continue
if opt.seed is not None and opt.seed < 0: # retrieve previous value! if opt.seed is not None and opt.seed < 0: # retrieve previous value!
try: try:
opt.seed = last_results[opt.seed][1] opt.seed = last_results[opt.seed][1]
@ -201,12 +204,12 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
opt.seed = None opt.seed = None
continue continue
do_grid = opt.grid or t2i.grid do_grid = opt.grid or t2i.grid
if opt.with_variations is not None: if opt.with_variations is not None:
# shotgun parsing, woo # shotgun parsing, woo
parts = [] parts = []
broken = False # python doesn't have labeled loops... broken = False # python doesn't have labeled loops...
for part in opt.with_variations.split(','): for part in opt.with_variations.split(','):
seed_and_weight = part.split(':') seed_and_weight = part.split(':')
if len(seed_and_weight) != 2: if len(seed_and_weight) != 2:
@ -241,7 +244,7 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
subdir = subdir[:(path_max - 27 - len(os.path.abspath(outdir)))] subdir = subdir[:(path_max - 27 - len(os.path.abspath(outdir)))]
current_outdir = os.path.join(outdir, subdir) current_outdir = os.path.join(outdir, subdir)
print ('Writing files to directory: "' + current_outdir + '"') print('Writing files to directory: "' + current_outdir + '"')
# make sure the output directory exists # make sure the output directory exists
if not os.path.exists(current_outdir): if not os.path.exists(current_outdir):
@ -253,9 +256,10 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
last_results = [] last_results = []
try: try:
file_writer = PngWriter(current_outdir) file_writer = PngWriter(current_outdir)
prefix = file_writer.unique_prefix() prefix = file_writer.unique_prefix()
results = [] # list of filename, prompt pairs results = [] # list of filename, prompt pairs
grid_images = dict() # seed -> Image, only used if `do_grid` grid_images = dict() # seed -> Image, only used if `do_grid`
def image_writer(image, seed, upscaled=False): def image_writer(image, seed, upscaled=False):
if do_grid: if do_grid:
grid_images[seed] = image grid_images[seed] = image
@ -265,35 +269,41 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
else: else:
filename = f'{prefix}.{seed}.png' filename = f'{prefix}.{seed}.png'
if opt.variation_amount > 0: if opt.variation_amount > 0:
iter_opt = argparse.Namespace(**vars(opt)) # copy iter_opt = argparse.Namespace(**vars(opt)) # copy
this_variation = [[seed, opt.variation_amount]] this_variation = [[seed, opt.variation_amount]]
if opt.with_variations is None: if opt.with_variations is None:
iter_opt.with_variations = this_variation iter_opt.with_variations = this_variation
else: else:
iter_opt.with_variations = opt.with_variations + this_variation iter_opt.with_variations = opt.with_variations + this_variation
iter_opt.variation_amount = 0 iter_opt.variation_amount = 0
normalized_prompt = PromptFormatter(t2i, iter_opt).normalize_prompt() normalized_prompt = PromptFormatter(
t2i, iter_opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{iter_opt.seed}' metadata_prompt = f'{normalized_prompt} -S{iter_opt.seed}'
elif opt.with_variations is not None: elif opt.with_variations is not None:
normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt() normalized_prompt = PromptFormatter(
metadata_prompt = f'{normalized_prompt} -S{opt.seed}' # use the original seed - the per-iteration value is the last variation-seed t2i, opt).normalize_prompt()
# use the original seed - the per-iteration value is the last variation-seed
metadata_prompt = f'{normalized_prompt} -S{opt.seed}'
else: else:
normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt() normalized_prompt = PromptFormatter(
t2i, opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{seed}' metadata_prompt = f'{normalized_prompt} -S{seed}'
path = file_writer.save_image_and_prompt_to_png(image, metadata_prompt, filename) path = file_writer.save_image_and_prompt_to_png(
image, metadata_prompt, filename)
if (not upscaled) or opt.save_original: if (not upscaled) or opt.save_original:
# only append to results if we didn't overwrite an earlier output # only append to results if we didn't overwrite an earlier output
results.append([path, metadata_prompt]) results.append([path, metadata_prompt])
last_results.append([path,seed]) last_results.append([path, seed])
t2i.prompt2image(image_callback=image_writer, **vars(opt)) t2i.prompt2image(image_callback=image_writer, **vars(opt))
if do_grid and len(grid_images) > 0: if do_grid and len(grid_images) > 0:
grid_img = make_grid(list(grid_images.values())) grid_img = make_grid(list(grid_images.values()))
first_seed = last_results[0][1] first_seed = last_results[0][1]
filename = f'{prefix}.{first_seed}.png' filename = f'{prefix}.{first_seed}.png'
# TODO better metadata for grid images # TODO better metadata for grid images
normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt() normalized_prompt = PromptFormatter(
t2i, opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{first_seed} --grid -N{len(grid_images)}' metadata_prompt = f'{normalized_prompt} -S{first_seed} --grid -N{len(grid_images)}'
path = file_writer.save_image_and_prompt_to_png( path = file_writer.save_image_and_prompt_to_png(
grid_img, metadata_prompt, filename grid_img, metadata_prompt, filename
@ -308,18 +318,16 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
print(e) print(e)
continue continue
print('\033[1mOutputs:\033[0m') print('Outputs:')
log_path = os.path.join(current_outdir, 'dream_log.txt') log_path = os.path.join(current_outdir, 'dream_log.txt')
write_log_message(results, log_path) write_log_message(results, log_path)
print('goodbye!\033[0m') print('goodbye!')
def get_next_command(infile=None) -> str: #command string def get_next_command(infile=None) -> str: # command string
if infile is None: if infile is None:
print('\033[1m') # add some boldface command = input('dream> ')
command = input('dream> ')
print('\033[0m',end='')
else: else:
command = infile.readline() command = infile.readline()
if not command: if not command:
@ -329,6 +337,7 @@ def get_next_command(infile=None) -> str: #command string
print(f'#{command}') print(f'#{command}')
return command return command
def dream_server_loop(t2i, host, port, outdir): def dream_server_loop(t2i, host, port, outdir):
print('\n* --web was specified, starting web server...') print('\n* --web was specified, starting web server...')
# Change working directory to the stable-diffusion directory # Change working directory to the stable-diffusion directory
@ -342,7 +351,8 @@ def dream_server_loop(t2i, host, port, outdir):
dream_server = ThreadingDreamServer((host, port)) dream_server = ThreadingDreamServer((host, port))
print(">> Started Stable Diffusion dream server!") print(">> Started Stable Diffusion dream server!")
if host == '0.0.0.0': if host == '0.0.0.0':
print(f"Point your browser at http://localhost:{port} or use the host's DNS name or IP address.") print(
f"Point your browser at http://localhost:{port} or use the host's DNS name or IP address.")
else: else:
print(">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address.") print(">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address.")
print(f">> Point your browser at http://{host}:{port}.") print(f">> Point your browser at http://{host}:{port}.")
@ -361,13 +371,13 @@ def write_log_message(results, log_path):
log_lines = [f'{path}: {prompt}\n' for path, prompt in results] log_lines = [f'{path}: {prompt}\n' for path, prompt in results]
for l in log_lines: for l in log_lines:
output_cntr += 1 output_cntr += 1
print(f'\033[1m[{output_cntr}]\033[0m {l}',end='') print(output_cntr)
with open(log_path, 'a', encoding='utf-8') as file: with open(log_path, 'a', encoding='utf-8') as file:
file.writelines(log_lines) file.writelines(log_lines)
SAMPLER_CHOICES=[ SAMPLER_CHOICES = [
'ddim', 'ddim',
'k_dpm_2_a', 'k_dpm_2_a',
'k_dpm_2', 'k_dpm_2',
@ -378,6 +388,7 @@ SAMPLER_CHOICES=[
'plms', 'plms',
] ]
def create_argv_parser(): def create_argv_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="""Generate images using Stable Diffusion. description="""Generate images using Stable Diffusion.
@ -518,8 +529,8 @@ def create_argv_parser():
) )
parser.add_argument( parser.add_argument(
'--config', '--config',
default ='configs/models.yaml', default='configs/models.yaml',
help ='Path to configuration file for alternate models.', help='Path to configuration file for alternate models.',
) )
return parser return parser