user's prompt is now normalized for reproducibility and written into the destination PNG file as a tEXt metadata chunk named "Dream". You can retrieve the prompt with an image editing program that supports browsing the full metadata, or with the images2prompt.py script located in 'scripts'

This commit is contained in:
Lincoln Stein 2022-08-22 00:12:16 -04:00
parent 5f352aec87
commit aa2729d868
3 changed files with 66 additions and 55 deletions

View File

@ -11,7 +11,7 @@ t2i = T2I(outdir = <path> // outputs/txt2img-samples
batch_size = <integer> // how many images to generate per sampling (1)
steps = <integer> // 50
seed = <integer> // current system time
sampler = ['ddim','plms','klms'] // klms
sampler_name= ['ddim','plms','klms'] // klms
grid = <boolean> // false
width = <integer> // image width, multiple of 64 (512)
height = <integer> // image height, multiple of 64 (512)
@ -77,7 +77,7 @@ class T2I:
batch_size
steps
seed
sampler
sampler_name
grid
individual
width
@ -104,15 +104,15 @@ The vast majority of these arguments default to reasonable values.
cfg_scale=7.5,
weights="models/ldm/stable-diffusion-v1/model.ckpt",
config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml",
sampler="klms",
sampler_name="klms",
latent_channels=4,
downsampling_factor=8,
ddim_eta=0.0, # deterministic
fixed_code=False,
precision='autocast',
full_precision=False,
strength=0.75 # default in scripts/img2img.py,
latent_diffusion_weights=False
strength=0.75, # default in scripts/img2img.py
latent_diffusion_weights=False # just to keep track of this parameter when regenerating prompt
):
self.outdir = outdir
self.batch_size = batch_size
@ -124,7 +124,7 @@ The vast majority of these arguments default to reasonable values.
self.cfg_scale = cfg_scale
self.weights = weights
self.config = config
self.sampler_name = sampler
self.sampler_name = sampler_name
self.fixed_code = fixed_code
self.latent_channels = latent_channels
self.downsampling_factor = downsampling_factor
@ -416,7 +416,7 @@ The vast majority of these arguments default to reasonable values.
if self.full_precision:
print('Using slower but more accurate full-precision math (--full_precision)')
else:
print('Using half precision math. Call with --full_precision to use full precision')
print('Using half precision math. Call with --full_precision to use slower but more accurate full precision.')
model.half()
return model

View File

@ -4,6 +4,7 @@ import shlex
import atexit
import os
import sys
from PIL import Image,PngImagePlugin
# readline unavailable on windows systems
try:
@ -48,7 +49,7 @@ def main():
height=height,
batch_size=opt.batch_size,
outdir=opt.outdir,
sampler=opt.sampler,
sampler_name=opt.sampler_name,
weights=weights,
full_precision=opt.full_precision,
config=config,
@ -130,16 +131,15 @@ def main_loop(t2i,parser,log):
def write_log_message(t2i,opt,results,logfile):
''' logs the name of the output image, its prompt and seed to both the terminal and the log file '''
_output_for_individual(t2i,opt,results,logfile)
def _output_for_individual(opt,results,logfile):
switches = _reconstruct_switches(opt)
''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata '''
switches = _reconstruct_switches(t2i,opt)
prompt_str = ' '.join(switches)
# when multiple images are produced in batch, then we keep track of where each starts
last_seed = None
img_num = 1
batch_size = opt.batch_size or t2i.batch_size
seenit = {}
for r in results:
seed = r[1]
@ -148,58 +148,43 @@ def _output_for_individual(opt,results,logfile):
if batch_size > 1:
if seed != last_seed:
img_num = 1
log_message += ' # (batch image {img_num} of {batch_size})'
log_message += f' # (batch image {img_num} of {batch_size})'
else:
img_num += 1
log_message += f' # (batch image {img_num} of {batch_size})'
last_seed = seed
print(log_message)
logfile.write(log_message+"\n")
logfile.flush()
if r[0] not in seenit:
seenit[r[0]] = True
try:
_write_prompt_to_png(r[0],f'{prompt_str} -S{seed}')
except FileNotFoundError:
print(f"Could not open file '{r[0]}' for reading")
def _reconstruct_switches(t2i,opt):
'''Normalize the prompt and switches'''
switches = (f'"{opt.prompt}"')
switches = list()
switches.append(f'"{opt.prompt}"')
switches.append(f'-s{opt.steps or t2i.steps}')
switches.append(f'-b{opt.batch_size or t2i.batch_size}')
switches.append(f'-m{opt.sampler or t2i.sampler_name}')
switches.append(f'-W{opt.width or t2i.width}')
switches.append(f'-H{opt.height or t2i.height}')
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
if opt.init_img:
switches.append(f'-I{opt.init_img}')
if opt.strength:
switches.append(f'-f{opt.strength of t2i.strength}')
if opt.full_precision or t2i.full_precision:
if opt.strength and opt.init_img is not None:
switches.append(f'-f{opt.strength or t2i.strength}')
if t2i.full_precision:
switches.append('-F')
if t2i.model.liaon400m:
switches.append('-l')
return switches
def _output_for_grid(switches,results,logfile):
first_seed = results[0][1]
log_message = " ".join([' ',str(results[0][0])+':',
f'"{switches[0]}"',
*switches[1:],f'-S {results[0][1]}'])
print(log_message)
logfile.write(log_message+"\n")
all_seeds = [row[1] for row in results]
log_message = f' seeds for individual rows: {all_seeds}'
print(log_message)
logfile.write(log_message+"\n")
def _output_for_grid(switches,results,logfile):
first_seed = results[0][1]
log_message = " ".join([' ',str(results[0][0])+':',
f'"{switches[0]}"',
*switches[1:],f'-S {results[0][1]}'])
print(log_message)
logfile.write(log_message+"\n")
all_seeds = [row[1] for row in results]
log_message = f' seeds for individual rows: {all_seeds}'
print(log_message)
logfile.write(log_message+"\n")
def _write_prompt_to_png(path,prompt):
info = PngImagePlugin.PngInfo()
info.add_text("Dream",prompt)
im = Image.open(path)
im.save(path,"PNG",pnginfo=info)
def create_argv_parser():
parser = argparse.ArgumentParser(description="Parse script's command line args")
@ -220,11 +205,12 @@ def create_argv_parser():
parser.add_argument('-b','--batch_size',
type=int,
default=1,
help="number of images to produce per iteration (currently not working properly - producing too many images)")
help="number of images to produce per iteration (faster, but doesn't generate individual seeds")
parser.add_argument('--sampler','-m',
dest="sampler_name",
choices=['plms','ddim', 'klms'],
default='klms',
help="which sampler to use (klms)")
help="which sampler to use (klms) - can only be set on command line")
parser.add_argument('-o',
'--outdir',
type=str,
@ -240,10 +226,6 @@ def create_cmd_parser():
parser.add_argument('-S','--seed',type=int,help="image seed")
parser.add_argument('-n','--iterations',type=int,default=1,help="number of samplings to perform")
parser.add_argument('-b','--batch_size',type=int,default=1,help="number of images to produce per sampling")
parser.add_argument('--sampler',
choices=['plms','ddim', 'klms'],
default='klms',
help="which sampler to use (klms)")
parser.add_argument('-W','--width',type=int,help="image width, multiple of 64")
parser.add_argument('-H','--height',type=int,help="image height, multiple of 64")
parser.add_argument('-C','--cfg_scale',default=7.5,type=float,help="prompt configuration scale")

29
scripts/images2prompt.py Normal file
View File

@ -0,0 +1,29 @@
#!/usr/bin/env python3
'''This script reads the "Dream" Stable Diffusion prompt embedded in files generated by dream.py'''
import sys
from PIL import Image,PngImagePlugin
if len(sys.argv) < 2:
print("Usage: file2prompt.py <file1.png> <file2.png> <file3.png>...")
exit(-1)
filenames = sys.argv[1:]
for f in filenames:
try:
im = Image.open(f)
try:
prompt = im.text['Dream']
except KeyError:
prompt = ''
print(f'{f}: {prompt}')
except FileNotFoundError:
sys.stderr.write(f'{f} not found\n')
continue
except PermissionError:
sys.stderr.write(f'{f} could not be opened due to inadequate permissions\n')
continue