Code is reorganized and mostly functional. Grid output still needs to be brought back online, as does naming of the img2img variants (currently the variants are written to disk but not logged).

Lincoln Stein 2022-08-24 19:47:59 -04:00
parent b12955c963
commit b978536385
3 changed files with 101 additions and 128 deletions
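For orientation, the pieces below now fit together roughly as follows: the driver script normalizes the prompt with PromptFormatter, builds a PngWriter, and passes the writer's write_image method to T2I.prompt2image() as a per-image callback. A minimal sketch of that wiring (the T2I module path and the option values are assumptions for illustration; the real opt comes from the argparse parsers in scripts/dream.py):

from argparse import Namespace
from ldm.dream_util import PngWriter, PromptFormatter
from ldm.simplet2i import T2I   # module path assumed; the T2I class is defined in the second file below

t2i = T2I(width=512, height=512, sampler_name='klms')
opt = Namespace(prompt='a sunlit forest', steps=50, seed=None, batch_size=1,
                iterations=1, cfg_scale=7.5, width=512, height=512,
                init_img=None, strength=None, variants=None,
                skip_normalize=False, ddim_eta=0.0)

normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
writer = PngWriter('outputs/img-samples', opt, normalized_prompt)   # output directory is illustrative

# prompt2image() hands each finished image to the callback along with its seed;
# write_image() picks a unique filename and embeds the prompt in a "Dream" PNG text chunk.
t2i.prompt2image(image_callback=writer.write_image, **vars(opt))
results = writer.files_written   # list of [filepath, seed] pairs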

View File

@@ -1,6 +1,7 @@
'''Utilities for dealing with PNG images and their path names'''
import os
import atexit
import re
from PIL import Image,PngImagePlugin
# ---------------readline utilities---------------------
@@ -94,40 +95,43 @@ if readline_available:
# -------------------image generation utils-----
class PngWriter:
def __init__(self,opt):
self.opt = opt
self.filepath = None
self.files_written = []
def __init__(self,outdir,opt,prompt):
self.outdir = outdir
self.opt = opt
self.prompt = prompt
self.filepath = None
self.files_written = []
def write_image(self,image,seed):
self.filepath = self.unique_filename(self,opt,seed,self.filepath) # will increment name in some sensible way
self.filepath = self.unique_filename(seed,self.filepath) # will increment name in some sensible way
try:
image.save(self.filename)
prompt = f'{self.prompt} -S{seed}'
self.save_image_and_prompt_to_png(image,prompt,self.filepath)
except IOError as e:
print(e)
self.files_written.append([self.filepath,seed])
def unique_filename(self,opt,seed,previouspath):
def unique_filename(self,seed,previouspath):
revision = 1
if previouspath is None:
# sort reverse alphabetically until we find max+1
dirlist = sorted(os.listdir(outdir),reverse=True)
dirlist = sorted(os.listdir(self.outdir),reverse=True)
# find the first filename that matches our pattern or return 000000.0.png
filename = next((f for f in dirlist if re.match('^(\d+)\..*\.png',f)),'0000000.0.png')
basecount = int(filename.split('.',1)[0])
basecount += 1
if opt.batch_size > 1:
if self.opt.batch_size > 1:
filename = f'{basecount:06}.{seed}.01.png'
else:
filename = f'{basecount:06}.{seed}.png'
return os.path.join(outdir,filename)
return os.path.join(self.outdir,filename)
else:
basename = os.path.basename(previouspath)
x = re.match('^(\d+)\..*\.png',basename)
if not x:
return self.unique_filename(opt,seed,previouspath)
return self.unique_filename(seed,previouspath)
basecount = int(x.groups()[0])
series = 0
@@ -135,9 +139,41 @@ class PngWriter:
while not finished:
series += 1
filename = f'{basecount:06}.{seed}.png'
if isbatch or os.path.exists(os.path.join(outdir,filename)):
if self.opt.batch_size>1 or os.path.exists(os.path.join(self.outdir,filename)):
filename = f'{basecount:06}.{seed}.{series:02}.png'
finished = not os.path.exists(os.path.join(outdir,filename))
return os.path.join(outdir,filename)
finished = not os.path.exists(os.path.join(self.outdir,filename))
return os.path.join(self.outdir,filename)
def save_image_and_prompt_to_png(self,image,prompt,path):
info = PngImagePlugin.PngInfo()
info.add_text("Dream",prompt)
image.save(path,"PNG",pnginfo=info)
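save_image_and_prompt_to_png() stores the reconstructed command line in a "Dream" text chunk, so the exact invocation can be recovered later from the file itself. A quick way to read it back with Pillow (the path is made up for illustration):

from PIL import Image

im = Image.open('outputs/img-samples/000001.1924650361.png')   # illustrative path
print(im.info.get('Dream'))
# -> "a sunlit forest" -s50 -b1 -W512 -H512 -C7.5 -mklms -S1924650361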
class PromptFormatter():
def __init__(self,t2i,opt):
self.t2i = t2i
self.opt = opt
def normalize_prompt(self):
'''Normalize the prompt and switches'''
t2i = self.t2i
opt = self.opt
switches = list()
switches.append(f'"{opt.prompt}"')
switches.append(f'-s{opt.steps or t2i.steps}')
switches.append(f'-b{opt.batch_size or t2i.batch_size}')
switches.append(f'-W{opt.width or t2i.width}')
switches.append(f'-H{opt.height or t2i.height}')
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
switches.append(f'-m{t2i.sampler_name}')
if opt.variants:
switches.append(f'-v{opt.variants}')
if opt.init_img:
switches.append(f'-I{opt.init_img}')
if opt.strength and opt.init_img is not None:
switches.append(f'-f{opt.strength or t2i.strength}')
if t2i.full_precision:
switches.append('-F')
return ' '.join(switches)
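PromptFormatter centralizes the switch reconstruction that previously lived in scripts/dream.py (the removed _reconstruct_switches() near the end of this commit). Continuing the sketch above, the normalized string it returns looks roughly like this (values illustrative):

print(PromptFormatter(t2i, opt).normalize_prompt())
# -> "a sunlit forest" -s50 -b1 -W512 -H512 -C7.5 -mklms
# PngWriter.write_image() appends the per-image seed (e.g. -S1924650361) before
# embedding the string in the PNG and logging it.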

View File

@@ -99,13 +99,13 @@ The vast majority of these arguments default to reasonable values.
def __init__(self,
batch_size=1,
iterations = 1,
grid=False,
individual=None, # redundant
steps=50,
seed=None,
cfg_scale=7.5,
weights="models/ldm/stable-diffusion-v1/model.ckpt",
config = "configs/stable-diffusion/v1-inference.yaml",
width=512,
height=512,
sampler_name="klms",
latent_channels=4,
downsampling_factor=8,
@@ -121,7 +121,6 @@ The vast majority of these arguments default to reasonable values.
self.iterations = iterations
self.width = width
self.height = height
self.grid = grid
self.steps = steps
self.cfg_scale = cfg_scale
self.weights = weights
@@ -143,25 +142,26 @@ The vast majority of these arguments default to reasonable values.
else:
self.seed = seed
def generate(self,
# these are common
prompt,
batch_size=None,
iterations=None,
steps=None,
seed=None,
cfg_scale=None,
ddim_eta=None,
skip_normalize=False,
image_callback=None,
# these are specific to txt2img
width=None,
height=None,
# these are specific to img2img
init_img=None,
strength=None,
variants=None):
'''ldm.generate() is the common entry point for txt2img() and img2img()'''
def prompt2image(self,
# these are common
prompt,
batch_size=None,
iterations=None,
steps=None,
seed=None,
cfg_scale=None,
ddim_eta=None,
skip_normalize=False,
image_callback=None,
# these are specific to txt2img
width=None,
height=None,
# these are specific to img2img
init_img=None,
strength=None,
variants=None,
**args): # eat up additional cruft
'''ldm.prompt2image() is the common entry point for txt2img() and img2img()'''
steps = steps or self.steps
seed = seed or self.seed
width = width or self.width
@@ -178,10 +178,6 @@ The vast majority of these arguments default to reasonable values.
data = [batch_size * [prompt]]
scope = autocast if self.precision=="autocast" else nullcontext
if grid:
callback = self.image2png
else:
callback = None
tic = time.time()
if init_img:
@@ -212,7 +208,7 @@ The vast majority of these arguments default to reasonable values.
steps,seed,cfg_scale,ddim_eta,
skip_normalize,
width,height,
callback=callback): # the callback is called each time a new Image is generated
callback): # the callback is called each time a new Image is generated
"""
Generate an image from the prompt, writing iteration images into the outdir
The output is a list of lists in the format: [[image1,seed1], [image2,seed2],...]
@@ -224,14 +220,14 @@ The vast majority of these arguments default to reasonable values.
# Gawd. Too many levels of indent here. Need to refactor into smaller routines!
try:
with precision_scope(self.device.type), model.ema_scope():
with precision_scope(self.device.type), self.model.ema_scope():
all_samples = list()
for n in trange(iterations, desc="Sampling"):
seed_everything(seed)
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
uc = None
if cfg_scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
uc = self.model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
@@ -247,9 +243,9 @@ The vast majority of these arguments default to reasonable values.
weight = weights[i]
if not skip_normalize:
weight = weight / totalWeight
c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
else: # just standard 1 prompt
c = model.get_learned_conditioning(prompts)
c = self.model.get_learned_conditioning(prompts)
shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
samples_ddim, _ = sampler.sample(S=steps,
@@ -261,7 +257,7 @@ The vast majority of these arguments default to reasonable values.
unconditional_conditioning=uc,
eta=ddim_eta)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = self.model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
@@ -277,8 +273,6 @@ The vast majority of these arguments default to reasonable values.
except RuntimeError as e:
print(str(e))
toc = time.time()
print(f'{image_count} images generated in',"%4.2fs"% (toc-tic))
return images
@torch.no_grad()
@@ -297,14 +291,14 @@ The vast majority of these arguments default to reasonable values.
# PLMS sampler not supported yet, so ignore previous sampler
if self.sampler_name!='ddim':
print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler")
sampler = DDIMSampler(model, device=self.device)
sampler = DDIMSampler(self.model, device=self.device)
else:
sampler = self.sampler
init_image = self._load_img(init_img).to(self.device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
with precision_scope(self.device.type):
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image)) # move to latent space
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
@@ -314,14 +308,14 @@ The vast majority of these arguments default to reasonable values.
images = list()
try:
with precision_scope(self.device.type), model.ema_scope():
with precision_scope(self.device.type), self.model.ema_scope():
all_samples = list()
for n in trange(iterations, desc="Sampling"):
seed_everything(seed)
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
uc = None
if cfg_scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
uc = self.model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
@@ -337,9 +331,9 @@ The vast majority of these arguments default to reasonable values.
weight = weights[i]
if not skip_normalize:
weight = weight / totalWeight
c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
else: # just standard 1 prompt
c = model.get_learned_conditioning(prompts)
c = self.model.get_learned_conditioning(prompts)
# encode (scaled latent)
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
@@ -347,7 +341,7 @@ The vast majority of these arguments default to reasonable values.
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,)
x_samples = model.decode_first_stage(samples)
x_samples = self.model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
for x_sample in x_samples:

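Both the txt2img and img2img paths above report results the same way: each finished image is passed to the image_callback together with the seed that produced it, and the accumulated image/seed pairs are returned (the [[image, seed], ...] shape is taken from the docstring above, and the callback signature is inferred from PngWriter.write_image). A minimal custom callback, under those assumptions:

from ldm.simplet2i import T2I   # module path assumed, as in the earlier sketch

t2i = T2I()

def preview_callback(image, seed):
    # called once per generated image; image is a PIL.Image, seed is the value used to make it
    image.save(f'/tmp/preview.{seed}.png')

results = t2i.prompt2image(prompt='a sunlit forest', iterations=2,
                           image_callback=preview_callback)
for image, seed in results:
    print(f'seed {seed} -> {image.size}')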
View File

@@ -6,7 +6,7 @@ import shlex
import os
import sys
import copy
from ldm.dream_util import Completer,PngWriter
from ldm.dream_util import Completer,PngWriter,PromptFormatter
debugging = False
@@ -27,10 +27,6 @@ def main():
config = "configs/stable-diffusion/v1-inference.yaml"
weights = "models/ldm/stable-diffusion-v1/model.ckpt"
# command line history will be stored in a file called "~/.dream_history"
if readline_available:
setup_readline()
print("* Initializing, be patient...\n")
sys.path.append('.')
from pytorch_lightning import logging
@@ -46,8 +42,6 @@ def main():
# the user input loop
t2i = T2I(width=width,
height=height,
batch_size=opt.batch_size,
outdir=opt.outdir,
sampler_name=opt.sampler_name,
weights=weights,
full_precision=opt.full_precision,
@@ -79,13 +73,13 @@ def main():
log_path = os.path.join(opt.outdir,'dream_log.txt')
with open(log_path,'a') as log:
cmd_parser = create_cmd_parser()
main_loop(t2i,cmd_parser,log,infile)
main_loop(t2i,opt.outdir,cmd_parser,log,infile)
log.close()
if infile:
infile.close()
def main_loop(t2i,parser,log,infile):
def main_loop(t2i,outdir,parser,log,infile):
''' prompt/read/execute loop '''
done = False
@@ -123,13 +117,13 @@ def main_loop(t2i,parser,log,infile):
if elements[0]=='cd' and len(elements)>1:
if os.path.exists(elements[1]):
print(f"setting image output directory to {elements[1]}")
opt.outdir=elements[1]
outdir=elements[1]
else:
print(f"directory {elements[1]} does not exist")
continue
if elements[0]=='pwd':
print(f"current output directory is {opt.outdir}")
print(f"current output directory is {outdir}")
continue
if elements[0].startswith('!dream'): # in case a stored prompt still contains the !dream command
@@ -158,88 +152,41 @@ def main_loop(t2i,parser,log,infile):
print("Try again with a prompt!")
continue
normalized_prompt = PromptFormatter(t2i,opt).normalize_prompt()
try:
file_writer = PngWriter(opt)
opt.callback = file_writer(write_image)
run_generator(**vars(opt))
file_writer = PngWriter(outdir,opt,normalized_prompt)
callback = file_writer.write_image
t2i.prompt2image(image_callback=callback,
**vars(opt))
results = file_writer.files_written
except AssertionError as e:
print(e)
continue
print("Outputs:")
write_log_message(t2i,opt,results,log)
write_log_message(t2i,normalized_prompt,results,log)
print("goodbye!")
def write_log_message(t2i,opt,results,logfile):
def write_log_message(t2i,prompt,results,logfile):
''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata '''
switches = _reconstruct_switches(t2i,opt)
prompt_str = ' '.join(switches)
# when multiple images are produced in batch, then we keep track of where each starts
last_seed = None
img_num = 1
batch_size = opt.batch_size or t2i.batch_size
seenit = {}
seeds = [a[1] for a in results]
if batch_size > 1:
seeds = f"(seeds for each batch row: {seeds})"
else:
seeds = f"(seeds for individual images: {seeds})"
seeds = f"(seeds for individual images: {seeds})"
for r in results:
seed = r[1]
log_message = (f'{r[0]}: {prompt_str} -S{seed}')
log_message = (f'{r[0]}: {prompt} -S{seed}')
if batch_size > 1:
if seed != last_seed:
img_num = 1
log_message += f' # (batch image {img_num} of {batch_size})'
else:
img_num += 1
log_message += f' # (batch image {img_num} of {batch_size})'
last_seed = seed
print(log_message)
logfile.write(log_message+"\n")
logfile.flush()
if r[0] not in seenit:
seenit[r[0]] = True
try:
if opt.grid:
_write_prompt_to_png(r[0],f'{prompt_str} -g -S{seed} {seeds}')
else:
_write_prompt_to_png(r[0],f'{prompt_str} -S{seed}')
except FileNotFoundError:
print(f"Could not open file '{r[0]}' for reading")
def _reconstruct_switches(t2i,opt):
'''Normalize the prompt and switches'''
switches = list()
switches.append(f'"{opt.prompt}"')
switches.append(f'-s{opt.steps or t2i.steps}')
switches.append(f'-b{opt.batch_size or t2i.batch_size}')
switches.append(f'-W{opt.width or t2i.width}')
switches.append(f'-H{opt.height or t2i.height}')
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
switches.append(f'-m{t2i.sampler_name}')
if opt.variants:
switches.append(f'-v{opt.variants}')
if opt.init_img:
switches.append(f'-I{opt.init_img}')
if opt.strength and opt.init_img is not None:
switches.append(f'-f{opt.strength or t2i.strength}')
if t2i.full_precision:
switches.append('-F')
return switches
def _write_prompt_to_png(path,prompt):
info = PngImagePlugin.PngInfo()
info.add_text("Dream",prompt)
im = Image.open(path)
im.save(path,"PNG",pnginfo=info)
def create_argv_parser():
parser = argparse.ArgumentParser(description="Parse script's command line args")
parser.add_argument("--laion400m",
@@ -260,10 +207,6 @@ def create_argv_parser():
dest='full_precision',
action='store_true',
help="use slower full precision math for calculations")
parser.add_argument('-b','--batch_size',
type=int,
default=1,
help="number of images to produce per iteration (faster, but doesn't generate individual seeds")
parser.add_argument('--sampler','-m',
dest="sampler_name",
choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'],