user's prompt is now normalized for reproducibility and written into the destination PNG file as a tEXt metadata chunk named "Dream". You can retrieve the prompt with an image editing program that supports browsing the full metadata, or with the images2prompt.py script located in 'scripts'

This commit is contained in:
Lincoln Stein
2022-08-22 00:12:16 -04:00
parent 5f352aec87
commit aa2729d868
3 changed files with 66 additions and 55 deletions

View File

@ -4,6 +4,7 @@ import shlex
import atexit
import os
import sys
from PIL import Image,PngImagePlugin
# readline unavailable on windows systems
try:
@ -48,7 +49,7 @@ def main():
height=height,
batch_size=opt.batch_size,
outdir=opt.outdir,
sampler=opt.sampler,
sampler_name=opt.sampler_name,
weights=weights,
full_precision=opt.full_precision,
config=config,
@ -130,16 +131,15 @@ def main_loop(t2i,parser,log):
def write_log_message(t2i,opt,results,logfile):
''' logs the name of the output image, its prompt and seed to both the terminal and the log file '''
_output_for_individual(t2i,opt,results,logfile)
def _output_for_individual(opt,results,logfile):
switches = _reconstruct_switches(opt)
''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata '''
switches = _reconstruct_switches(t2i,opt)
prompt_str = ' '.join(switches)
# when multiple images are produced in batch, then we keep track of where each starts
last_seed = None
img_num = 1
batch_size = opt.batch_size or t2i.batch_size
seenit = {}
for r in results:
seed = r[1]
@ -148,59 +148,44 @@ def _output_for_individual(opt,results,logfile):
if batch_size > 1:
if seed != last_seed:
img_num = 1
log_message += ' # (batch image {img_num} of {batch_size})'
log_message += f' # (batch image {img_num} of {batch_size})'
else:
img_num += 1
log_message += f' # (batch image {img_num} of {batch_size})'
last_seed = seed
print(log_message)
logfile.write(log_message+"\n")
logfile.flush()
if r[0] not in seenit:
seenit[r[0]] = True
try:
_write_prompt_to_png(r[0],f'{prompt_str} -S{seed}')
except FileNotFoundError:
print(f"Could not open file '{r[0]}' for reading")
def _reconstruct_switches(t2i,opt):
'''Normalize the prompt and switches'''
switches = (f'"{opt.prompt}"')
switches = list()
switches.append(f'"{opt.prompt}"')
switches.append(f'-s{opt.steps or t2i.steps}')
switches.append(f'-b{opt.batch_size or t2i.batch_size}')
switches.append(f'-m{opt.sampler or t2i.sampler_name}')
switches.append(f'-W{opt.width or t2i.width}')
switches.append(f'-H{opt.height or t2i.height}')
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
switches.append(f'-W{opt.width or t2i.width}')
switches.append(f'-H{opt.height or t2i.height}')
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
if opt.init_img:
switches.append(f'-I{opt.init_img}')
if opt.strength:
switches.append(f'-f{opt.strength of t2i.strength}')
if opt.full_precision or t2i.full_precision:
if opt.strength and opt.init_img is not None:
switches.append(f'-f{opt.strength or t2i.strength}')
if t2i.full_precision:
switches.append('-F')
if t2i.model.liaon400m:
switches.append('-l')
return switches
def _output_for_grid(switches,results,logfile):
first_seed = results[0][1]
log_message = " ".join([' ',str(results[0][0])+':',
f'"{switches[0]}"',
*switches[1:],f'-S {results[0][1]}'])
print(log_message)
logfile.write(log_message+"\n")
all_seeds = [row[1] for row in results]
log_message = f' seeds for individual rows: {all_seeds}'
print(log_message)
logfile.write(log_message+"\n")
def _output_for_grid(switches,results,logfile):
first_seed = results[0][1]
log_message = " ".join([' ',str(results[0][0])+':',
f'"{switches[0]}"',
*switches[1:],f'-S {results[0][1]}'])
print(log_message)
logfile.write(log_message+"\n")
all_seeds = [row[1] for row in results]
log_message = f' seeds for individual rows: {all_seeds}'
print(log_message)
logfile.write(log_message+"\n")
def _write_prompt_to_png(path,prompt):
info = PngImagePlugin.PngInfo()
info.add_text("Dream",prompt)
im = Image.open(path)
im.save(path,"PNG",pnginfo=info)
def create_argv_parser():
parser = argparse.ArgumentParser(description="Parse script's command line args")
parser.add_argument("--laion400m",
@ -220,11 +205,12 @@ def create_argv_parser():
parser.add_argument('-b','--batch_size',
type=int,
default=1,
help="number of images to produce per iteration (currently not working properly - producing too many images)")
help="number of images to produce per iteration (faster, but doesn't generate individual seeds")
parser.add_argument('--sampler','-m',
dest="sampler_name",
choices=['plms','ddim', 'klms'],
default='klms',
help="which sampler to use (klms)")
help="which sampler to use (klms) - can only be set on command line")
parser.add_argument('-o',
'--outdir',
type=str,
@ -240,10 +226,6 @@ def create_cmd_parser():
parser.add_argument('-S','--seed',type=int,help="image seed")
parser.add_argument('-n','--iterations',type=int,default=1,help="number of samplings to perform")
parser.add_argument('-b','--batch_size',type=int,default=1,help="number of images to produce per sampling")
parser.add_argument('--sampler',
choices=['plms','ddim', 'klms'],
default='klms',
help="which sampler to use (klms)")
parser.add_argument('-W','--width',type=int,help="image width, multiple of 64")
parser.add_argument('-H','--height',type=int,help="image height, multiple of 64")
parser.add_argument('-C','--cfg_scale',default=7.5,type=float,help="prompt configuration scale")

29
scripts/images2prompt.py Normal file
View File

@ -0,0 +1,29 @@
#!/usr/bin/env python3
'''This script reads the "Dream" Stable Diffusion prompt embedded in files generated by dream.py'''
import sys
from PIL import Image,PngImagePlugin
if len(sys.argv) < 2:
print("Usage: file2prompt.py <file1.png> <file2.png> <file3.png>...")
exit(-1)
filenames = sys.argv[1:]
for f in filenames:
try:
im = Image.open(f)
try:
prompt = im.text['Dream']
except KeyError:
prompt = ''
print(f'{f}: {prompt}')
except FileNotFoundError:
sys.stderr.write(f'{f} not found\n')
continue
except PermissionError:
sys.stderr.write(f'{f} could not be opened due to inadequate permissions\n')
continue