mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
test of normalization of prompt
This commit is contained in:
parent
c4c4974b39
commit
5f352aec87
@ -88,6 +88,8 @@ class T2I:
|
|||||||
downsampling_factor
|
downsampling_factor
|
||||||
precision
|
precision
|
||||||
strength
|
strength
|
||||||
|
|
||||||
|
The vast majority of these arguments default to reasonable values.
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
outdir="outputs/txt2img-samples",
|
outdir="outputs/txt2img-samples",
|
||||||
@ -109,7 +111,8 @@ class T2I:
|
|||||||
fixed_code=False,
|
fixed_code=False,
|
||||||
precision='autocast',
|
precision='autocast',
|
||||||
full_precision=False,
|
full_precision=False,
|
||||||
strength=0.75 # default in scripts/img2img.py
|
strength=0.75 # default in scripts/img2img.py,
|
||||||
|
latent_diffusion_weights=False
|
||||||
):
|
):
|
||||||
self.outdir = outdir
|
self.outdir = outdir
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
@ -119,7 +122,7 @@ class T2I:
|
|||||||
self.grid = grid
|
self.grid = grid
|
||||||
self.steps = steps
|
self.steps = steps
|
||||||
self.cfg_scale = cfg_scale
|
self.cfg_scale = cfg_scale
|
||||||
self.weights = weights
|
self.weights = weights
|
||||||
self.config = config
|
self.config = config
|
||||||
self.sampler_name = sampler
|
self.sampler_name = sampler
|
||||||
self.fixed_code = fixed_code
|
self.fixed_code = fixed_code
|
||||||
@ -131,6 +134,7 @@ class T2I:
|
|||||||
self.strength = strength
|
self.strength = strength
|
||||||
self.model = None # empty for now
|
self.model = None # empty for now
|
||||||
self.sampler = None
|
self.sampler = None
|
||||||
|
self.latent_diffusion_weights=latent_diffusion_weights
|
||||||
if seed is None:
|
if seed is None:
|
||||||
self.seed = self._new_seed()
|
self.seed = self._new_seed()
|
||||||
else:
|
else:
|
||||||
|
@ -51,7 +51,9 @@ def main():
|
|||||||
sampler=opt.sampler,
|
sampler=opt.sampler,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
full_precision=opt.full_precision,
|
full_precision=opt.full_precision,
|
||||||
config=config)
|
config=config,
|
||||||
|
latent_diffusion_weights=opt.laion400m # this is solely for recreating the prompt
|
||||||
|
)
|
||||||
|
|
||||||
# make sure the output directory exists
|
# make sure the output directory exists
|
||||||
if not os.path.exists(opt.outdir):
|
if not os.path.exists(opt.outdir):
|
||||||
@ -119,7 +121,7 @@ def main_loop(t2i,parser,log):
|
|||||||
else:
|
else:
|
||||||
results = t2i.img2img(**vars(opt))
|
results = t2i.img2img(**vars(opt))
|
||||||
print("Outputs:")
|
print("Outputs:")
|
||||||
write_log_message(opt,switches,results,log)
|
write_log_message(t2i,opt,results,log)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print('*interrupted*')
|
print('*interrupted*')
|
||||||
continue
|
continue
|
||||||
@ -127,22 +129,66 @@ def main_loop(t2i,parser,log):
|
|||||||
print("goodbye!")
|
print("goodbye!")
|
||||||
|
|
||||||
|
|
||||||
def write_log_message(opt,switches,results,logfile):
|
def write_log_message(t2i,opt,results,logfile):
|
||||||
''' logs the name of the output image, its prompt and seed to both the terminal and the log file '''
|
''' logs the name of the output image, its prompt and seed to both the terminal and the log file '''
|
||||||
if opt.grid:
|
_output_for_individual(t2i,opt,results,logfile)
|
||||||
_output_for_grid(switches,results,logfile)
|
|
||||||
else:
|
|
||||||
_output_for_individual(switches,results,logfile)
|
|
||||||
|
|
||||||
def _output_for_individual(switches,results,logfile):
|
def _output_for_individual(opt,results,logfile):
|
||||||
|
switches = _reconstruct_switches(opt)
|
||||||
|
prompt_str = ' '.join(switches)
|
||||||
|
|
||||||
|
last_seed = None
|
||||||
|
img_num = 1
|
||||||
|
batch_size = opt.batch_size or t2i.batch_size
|
||||||
|
|
||||||
for r in results:
|
for r in results:
|
||||||
log_message = " ".join([' ',str(r[0])+':',
|
seed = r[1]
|
||||||
f'"{switches[0]}"',
|
log_message = (f'{r[0]}: {prompt_str} -S{seed}')
|
||||||
*switches[1:],f'-S {r[1]}'])
|
|
||||||
|
if batch_size > 1:
|
||||||
|
if seed != last_seed:
|
||||||
|
img_num = 1
|
||||||
|
log_message += ' # (batch image {img_num} of {batch_size})'
|
||||||
|
else:
|
||||||
|
img_num += 1
|
||||||
|
log_message += f' # (batch image {img_num} of {batch_size})'
|
||||||
|
last_seed = seed
|
||||||
|
|
||||||
print(log_message)
|
print(log_message)
|
||||||
logfile.write(log_message+"\n")
|
logfile.write(log_message+"\n")
|
||||||
logfile.flush()
|
logfile.flush()
|
||||||
|
|
||||||
|
def _reconstruct_switches(t2i,opt):
|
||||||
|
'''Normalize the prompt and switches'''
|
||||||
|
switches = (f'"{opt.prompt}"')
|
||||||
|
switches.append(f'-s{opt.steps or t2i.steps}')
|
||||||
|
switches.append(f'-b{opt.batch_size or t2i.batch_size}')
|
||||||
|
switches.append(f'-m{opt.sampler or t2i.sampler_name}')
|
||||||
|
switches.append(f'-W{opt.width or t2i.width}')
|
||||||
|
switches.append(f'-H{opt.height or t2i.height}')
|
||||||
|
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
|
||||||
|
if opt.init_img:
|
||||||
|
switches.append(f'-I{opt.init_img}')
|
||||||
|
if opt.strength:
|
||||||
|
switches.append(f'-f{opt.strength of t2i.strength}')
|
||||||
|
if opt.full_precision or t2i.full_precision:
|
||||||
|
switches.append('-F')
|
||||||
|
if t2i.model.liaon400m:
|
||||||
|
switches.append('-l')
|
||||||
|
return switches
|
||||||
|
|
||||||
|
def _output_for_grid(switches,results,logfile):
|
||||||
|
first_seed = results[0][1]
|
||||||
|
log_message = " ".join([' ',str(results[0][0])+':',
|
||||||
|
f'"{switches[0]}"',
|
||||||
|
*switches[1:],f'-S {results[0][1]}'])
|
||||||
|
print(log_message)
|
||||||
|
logfile.write(log_message+"\n")
|
||||||
|
all_seeds = [row[1] for row in results]
|
||||||
|
log_message = f' seeds for individual rows: {all_seeds}'
|
||||||
|
print(log_message)
|
||||||
|
logfile.write(log_message+"\n")
|
||||||
|
|
||||||
def _output_for_grid(switches,results,logfile):
|
def _output_for_grid(switches,results,logfile):
|
||||||
first_seed = results[0][1]
|
first_seed = results[0][1]
|
||||||
log_message = " ".join([' ',str(results[0][0])+':',
|
log_message = " ".join([' ',str(results[0][0])+':',
|
||||||
@ -162,7 +208,7 @@ def create_argv_parser():
|
|||||||
"-l",
|
"-l",
|
||||||
dest='laion400m',
|
dest='laion400m',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help="fallback to the latent diffusion (LAION4400M) weights and config")
|
help="fallback to the latent diffusion (laion400m) weights and config")
|
||||||
parser.add_argument('-n','--iterations',
|
parser.add_argument('-n','--iterations',
|
||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=1,
|
||||||
@ -175,7 +221,7 @@ def create_argv_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=1,
|
||||||
help="number of images to produce per iteration (currently not working properly - producing too many images)")
|
help="number of images to produce per iteration (currently not working properly - producing too many images)")
|
||||||
parser.add_argument('--sampler',
|
parser.add_argument('--sampler','-m',
|
||||||
choices=['plms','ddim', 'klms'],
|
choices=['plms','ddim', 'klms'],
|
||||||
default='klms',
|
default='klms',
|
||||||
help="which sampler to use (klms)")
|
help="which sampler to use (klms)")
|
||||||
@ -193,7 +239,11 @@ def create_cmd_parser():
|
|||||||
parser.add_argument('-s','--steps',type=int,help="number of steps")
|
parser.add_argument('-s','--steps',type=int,help="number of steps")
|
||||||
parser.add_argument('-S','--seed',type=int,help="image seed")
|
parser.add_argument('-S','--seed',type=int,help="image seed")
|
||||||
parser.add_argument('-n','--iterations',type=int,default=1,help="number of samplings to perform")
|
parser.add_argument('-n','--iterations',type=int,default=1,help="number of samplings to perform")
|
||||||
parser.add_argument('-b','--batch_size',type=int,default=1,help="number of images to produce per sampling (currently broken)")
|
parser.add_argument('-b','--batch_size',type=int,default=1,help="number of images to produce per sampling")
|
||||||
|
parser.add_argument('--sampler',
|
||||||
|
choices=['plms','ddim', 'klms'],
|
||||||
|
default='klms',
|
||||||
|
help="which sampler to use (klms)")
|
||||||
parser.add_argument('-W','--width',type=int,help="image width, multiple of 64")
|
parser.add_argument('-W','--width',type=int,help="image width, multiple of 64")
|
||||||
parser.add_argument('-H','--height',type=int,help="image height, multiple of 64")
|
parser.add_argument('-H','--height',type=int,help="image height, multiple of 64")
|
||||||
parser.add_argument('-C','--cfg_scale',default=7.5,type=float,help="prompt configuration scale")
|
parser.add_argument('-C','--cfg_scale',default=7.5,type=float,help="prompt configuration scale")
|
||||||
|
Loading…
Reference in New Issue
Block a user