test of normalization of prompt

This commit is contained in:
Lincoln Stein 2022-08-21 22:48:40 -04:00
parent c4c4974b39
commit 5f352aec87
2 changed files with 70 additions and 16 deletions

View File

@ -88,6 +88,8 @@ class T2I:
downsampling_factor downsampling_factor
precision precision
strength strength
The vast majority of these arguments default to reasonable values.
""" """
def __init__(self, def __init__(self,
outdir="outputs/txt2img-samples", outdir="outputs/txt2img-samples",
@ -109,7 +111,8 @@ class T2I:
fixed_code=False, fixed_code=False,
precision='autocast', precision='autocast',
full_precision=False, full_precision=False,
strength=0.75 # default in scripts/img2img.py strength=0.75 # default in scripts/img2img.py,
latent_diffusion_weights=False
): ):
self.outdir = outdir self.outdir = outdir
self.batch_size = batch_size self.batch_size = batch_size
@ -119,7 +122,7 @@ class T2I:
self.grid = grid self.grid = grid
self.steps = steps self.steps = steps
self.cfg_scale = cfg_scale self.cfg_scale = cfg_scale
self.weights = weights self.weights = weights
self.config = config self.config = config
self.sampler_name = sampler self.sampler_name = sampler
self.fixed_code = fixed_code self.fixed_code = fixed_code
@ -131,6 +134,7 @@ class T2I:
self.strength = strength self.strength = strength
self.model = None # empty for now self.model = None # empty for now
self.sampler = None self.sampler = None
self.latent_diffusion_weights=latent_diffusion_weights
if seed is None: if seed is None:
self.seed = self._new_seed() self.seed = self._new_seed()
else: else:

View File

@ -51,7 +51,9 @@ def main():
sampler=opt.sampler, sampler=opt.sampler,
weights=weights, weights=weights,
full_precision=opt.full_precision, full_precision=opt.full_precision,
config=config) config=config,
latent_diffusion_weights=opt.laion400m # this is solely for recreating the prompt
)
# make sure the output directory exists # make sure the output directory exists
if not os.path.exists(opt.outdir): if not os.path.exists(opt.outdir):
@ -119,7 +121,7 @@ def main_loop(t2i,parser,log):
else: else:
results = t2i.img2img(**vars(opt)) results = t2i.img2img(**vars(opt))
print("Outputs:") print("Outputs:")
write_log_message(opt,switches,results,log) write_log_message(t2i,opt,results,log)
except KeyboardInterrupt: except KeyboardInterrupt:
print('*interrupted*') print('*interrupted*')
continue continue
@ -127,22 +129,66 @@ def main_loop(t2i,parser,log):
print("goodbye!") print("goodbye!")
def write_log_message(opt,switches,results,logfile): def write_log_message(t2i,opt,results,logfile):
''' logs the name of the output image, its prompt and seed to both the terminal and the log file ''' ''' logs the name of the output image, its prompt and seed to both the terminal and the log file '''
if opt.grid: _output_for_individual(t2i,opt,results,logfile)
_output_for_grid(switches,results,logfile)
else: def _output_for_individual(opt,results,logfile):
_output_for_individual(switches,results,logfile) switches = _reconstruct_switches(opt)
prompt_str = ' '.join(switches)
last_seed = None
img_num = 1
batch_size = opt.batch_size or t2i.batch_size
def _output_for_individual(switches,results,logfile):
for r in results: for r in results:
log_message = " ".join([' ',str(r[0])+':', seed = r[1]
f'"{switches[0]}"', log_message = (f'{r[0]}: {prompt_str} -S{seed}')
*switches[1:],f'-S {r[1]}'])
if batch_size > 1:
if seed != last_seed:
img_num = 1
log_message += ' # (batch image {img_num} of {batch_size})'
else:
img_num += 1
log_message += f' # (batch image {img_num} of {batch_size})'
last_seed = seed
print(log_message) print(log_message)
logfile.write(log_message+"\n") logfile.write(log_message+"\n")
logfile.flush() logfile.flush()
def _reconstruct_switches(t2i,opt):
    '''Normalize the prompt and switches.

    Rebuilds the command-line switches that describe how an image was
    generated, so the logged line can be pasted back in to reproduce it.
    Each per-prompt value in ``opt`` falls back to the corresponding
    default held by ``t2i`` when it is unset (falsy).

    t2i -- generator object supplying defaults (steps, batch_size,
           sampler_name, width, height, cfg_scale, strength,
           full_precision, latent_diffusion_weights)
    opt -- parsed per-prompt options (prompt, steps, batch_size, sampler,
           width, height, cfg_scale, init_img, strength, full_precision)

    Returns a list of strings: the quoted prompt followed by the switches.
    '''
    # Must be a list (not a parenthesized string) so switches can be appended.
    switches = [f'"{opt.prompt}"']
    switches.append(f'-s{opt.steps or t2i.steps}')
    switches.append(f'-b{opt.batch_size or t2i.batch_size}')
    switches.append(f'-m{opt.sampler or t2i.sampler_name}')
    switches.append(f'-W{opt.width or t2i.width}')
    switches.append(f'-H{opt.height or t2i.height}')
    switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
    if opt.init_img:
        switches.append(f'-I{opt.init_img}')
    if opt.strength:
        switches.append(f'-f{opt.strength or t2i.strength}')
    if opt.full_precision or t2i.full_precision:
        switches.append('-F')
    # The flag is stored on the T2I object itself; t2i.model may still be
    # None and has no such attribute.
    if t2i.latent_diffusion_weights:
        switches.append('-l')
    return switches
def _output_for_grid(switches,results,logfile):
    '''Log a grid image: one summary line, then the seeds of its rows.

    switches -- list whose first element is the prompt text and whose
                remaining elements are the already-formatted switches
    results  -- sequence of rows; row[0] is the output name and row[1]
                the seed
    logfile  -- writable text stream that receives the same lines that
                are printed to the terminal

    Does nothing when ``results`` is empty.
    '''
    if not results:
        # Nothing was produced, so there is no first row to report.
        return
    first_name, first_seed = results[0][0], results[0][1]
    log_message = " ".join([' ',str(first_name)+':',
                            f'"{switches[0]}"',
                            *switches[1:],f'-S {first_seed}'])
    print(log_message)
    logfile.write(log_message+"\n")
    all_seeds = [row[1] for row in results]
    log_message = f' seeds for individual rows: {all_seeds}'
    print(log_message)
    logfile.write(log_message+"\n")
def _output_for_grid(switches,results,logfile): def _output_for_grid(switches,results,logfile):
first_seed = results[0][1] first_seed = results[0][1]
log_message = " ".join([' ',str(results[0][0])+':', log_message = " ".join([' ',str(results[0][0])+':',
@ -162,7 +208,7 @@ def create_argv_parser():
"-l", "-l",
dest='laion400m', dest='laion400m',
action='store_true', action='store_true',
help="fallback to the latent diffusion (LAION4400M) weights and config") help="fallback to the latent diffusion (laion400m) weights and config")
parser.add_argument('-n','--iterations', parser.add_argument('-n','--iterations',
type=int, type=int,
default=1, default=1,
@ -175,7 +221,7 @@ def create_argv_parser():
type=int, type=int,
default=1, default=1,
help="number of images to produce per iteration (currently not working properly - producing too many images)") help="number of images to produce per iteration (currently not working properly - producing too many images)")
parser.add_argument('--sampler', parser.add_argument('--sampler','-m',
choices=['plms','ddim', 'klms'], choices=['plms','ddim', 'klms'],
default='klms', default='klms',
help="which sampler to use (klms)") help="which sampler to use (klms)")
@ -193,7 +239,11 @@ def create_cmd_parser():
parser.add_argument('-s','--steps',type=int,help="number of steps") parser.add_argument('-s','--steps',type=int,help="number of steps")
parser.add_argument('-S','--seed',type=int,help="image seed") parser.add_argument('-S','--seed',type=int,help="image seed")
parser.add_argument('-n','--iterations',type=int,default=1,help="number of samplings to perform") parser.add_argument('-n','--iterations',type=int,default=1,help="number of samplings to perform")
parser.add_argument('-b','--batch_size',type=int,default=1,help="number of images to produce per sampling (currently broken)") parser.add_argument('-b','--batch_size',type=int,default=1,help="number of images to produce per sampling")
parser.add_argument('--sampler',
choices=['plms','ddim', 'klms'],
default='klms',
help="which sampler to use (klms)")
parser.add_argument('-W','--width',type=int,help="image width, multiple of 64") parser.add_argument('-W','--width',type=int,help="image width, multiple of 64")
parser.add_argument('-H','--height',type=int,help="image height, multiple of 64") parser.add_argument('-H','--height',type=int,help="image height, multiple of 64")
parser.add_argument('-C','--cfg_scale',default=7.5,type=float,help="prompt configuration scale") parser.add_argument('-C','--cfg_scale',default=7.5,type=float,help="prompt configuration scale")