From 5f352aec87ef167f5de8160f279b1b3de3e067e4 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 21 Aug 2022 22:48:40 -0400 Subject: [PATCH] test of normalization of prompt --- ldm/simplet2i.py | 8 +++-- scripts/dream.py | 78 +++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 70 insertions(+), 16 deletions(-) diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index e99660a8ab..63ed8c4583 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -88,6 +88,8 @@ class T2I: downsampling_factor precision strength + +The vast majority of these arguments default to reasonable values. """ def __init__(self, outdir="outputs/txt2img-samples", @@ -109,7 +111,8 @@ class T2I: fixed_code=False, precision='autocast', full_precision=False, - strength=0.75 # default in scripts/img2img.py + strength=0.75 # default in scripts/img2img.py, + latent_diffusion_weights=False ): self.outdir = outdir self.batch_size = batch_size @@ -119,7 +122,7 @@ class T2I: self.grid = grid self.steps = steps self.cfg_scale = cfg_scale - self.weights = weights + self.weights = weights self.config = config self.sampler_name = sampler self.fixed_code = fixed_code @@ -131,6 +134,7 @@ class T2I: self.strength = strength self.model = None # empty for now self.sampler = None + self.latent_diffusion_weights=latent_diffusion_weights if seed is None: self.seed = self._new_seed() else: diff --git a/scripts/dream.py b/scripts/dream.py index b8abb780fd..00e5263108 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -51,7 +51,9 @@ def main(): sampler=opt.sampler, weights=weights, full_precision=opt.full_precision, - config=config) + config=config, + latent_diffusion_weights=opt.laion400m # this is solely for recreating the prompt + ) # make sure the output directory exists if not os.path.exists(opt.outdir): @@ -119,7 +121,7 @@ def main_loop(t2i,parser,log): else: results = t2i.img2img(**vars(opt)) print("Outputs:") - write_log_message(opt,switches,results,log) + write_log_message(t2i,opt,results,log) except KeyboardInterrupt: print('*interrupted*') continue @@ -127,22 +129,66 @@ def main_loop(t2i,parser,log): print("goodbye!") -def write_log_message(opt,switches,results,logfile): +def write_log_message(t2i,opt,results,logfile): ''' logs the name of the output image, its prompt and seed to both the terminal and the log file ''' - if opt.grid: - _output_for_grid(switches,results,logfile) - else: - _output_for_individual(switches,results,logfile) + _output_for_individual(t2i,opt,results,logfile) -def _output_for_individual(switches,results,logfile): +def _output_for_individual(opt,results,logfile): + switches = _reconstruct_switches(opt) + prompt_str = ' '.join(switches) + + last_seed = None + img_num = 1 + batch_size = opt.batch_size or t2i.batch_size + for r in results: - log_message = " ".join([' ',str(r[0])+':', - f'"{switches[0]}"', - *switches[1:],f'-S {r[1]}']) + seed = r[1] + log_message = (f'{r[0]}: {prompt_str} -S{seed}') + + if batch_size > 1: + if seed != last_seed: + img_num = 1 + log_message += ' # (batch image {img_num} of {batch_size})' + else: + img_num += 1 + log_message += f' # (batch image {img_num} of {batch_size})' + last_seed = seed + print(log_message) logfile.write(log_message+"\n") logfile.flush() +def _reconstruct_switches(t2i,opt): + '''Normalize the prompt and switches''' + switches = (f'"{opt.prompt}"') + switches.append(f'-s{opt.steps or t2i.steps}') + switches.append(f'-b{opt.batch_size or t2i.batch_size}') + switches.append(f'-m{opt.sampler or t2i.sampler_name}') + switches.append(f'-W{opt.width or t2i.width}') + switches.append(f'-H{opt.height or t2i.height}') + switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}') + if opt.init_img: + switches.append(f'-I{opt.init_img}') + if opt.strength: + switches.append(f'-f{opt.strength of t2i.strength}') + if opt.full_precision or t2i.full_precision: + switches.append('-F') + if t2i.model.liaon400m: + switches.append('-l') + return switches + +def _output_for_grid(switches,results,logfile): + first_seed = results[0][1] + log_message = " ".join([' ',str(results[0][0])+':', + f'"{switches[0]}"', + *switches[1:],f'-S {results[0][1]}']) + print(log_message) + logfile.write(log_message+"\n") + all_seeds = [row[1] for row in results] + log_message = f' seeds for individual rows: {all_seeds}' + print(log_message) + logfile.write(log_message+"\n") + def _output_for_grid(switches,results,logfile): first_seed = results[0][1] log_message = " ".join([' ',str(results[0][0])+':', @@ -162,7 +208,7 @@ def create_argv_parser(): "-l", dest='laion400m', action='store_true', - help="fallback to the latent diffusion (LAION4400M) weights and config") + help="fallback to the latent diffusion (laion400m) weights and config") parser.add_argument('-n','--iterations', type=int, default=1, @@ -175,7 +221,7 @@ def create_argv_parser(): type=int, default=1, help="number of images to produce per iteration (currently not working properly - producing too many images)") - parser.add_argument('--sampler', + parser.add_argument('--sampler','-m', choices=['plms','ddim', 'klms'], default='klms', help="which sampler to use (klms)") @@ -193,7 +239,11 @@ def create_cmd_parser(): parser.add_argument('-s','--steps',type=int,help="number of steps") parser.add_argument('-S','--seed',type=int,help="image seed") parser.add_argument('-n','--iterations',type=int,default=1,help="number of samplings to perform") - parser.add_argument('-b','--batch_size',type=int,default=1,help="number of images to produce per sampling (currently broken)") + parser.add_argument('-b','--batch_size',type=int,default=1,help="number of images to produce per sampling") + parser.add_argument('--sampler', + choices=['plms','ddim', 'klms'], + default='klms', + help="which sampler to use (klms)") parser.add_argument('-W','--width',type=int,help="image width, multiple of 64") parser.add_argument('-H','--height',type=int,help="image height, multiple of 64") parser.add_argument('-C','--cfg_scale',default=7.5,type=float,help="prompt configuration scale")