mostly back to full functionality; just missing grid generation code

This commit is contained in:
Lincoln Stein 2022-08-25 00:42:37 -04:00
parent b978536385
commit 0b4459b707
3 changed files with 245 additions and 138 deletions

View File

@ -4,6 +4,92 @@ import atexit
import re
from PIL import Image,PngImagePlugin
# -------------------image generation utils-----
class PngWriter:
def __init__(self,outdir,prompt=None,batch_size=1):
self.outdir = outdir
self.batch_size = batch_size
self.prompt = prompt
self.filepath = None
self.files_written = []
os.makedirs(outdir, exist_ok=True)
def write_image(self,image,seed):
self.filepath = self.unique_filename(seed,self.filepath) # will increment name in some sensible way
try:
prompt = f'{self.prompt} -S{seed}'
self.save_image_and_prompt_to_png(image,prompt,self.filepath)
except IOError as e:
print(e)
self.files_written.append([self.filepath,seed])
def unique_filename(self,seed,previouspath):
revision = 1
if previouspath is None:
# sort reverse alphabetically until we find max+1
dirlist = sorted(os.listdir(self.outdir),reverse=True)
# find the first filename that matches our pattern or return 000000.0.png
filename = next((f for f in dirlist if re.match('^(\d+)\..*\.png',f)),'0000000.0.png')
basecount = int(filename.split('.',1)[0])
basecount += 1
if self.batch_size > 1:
filename = f'{basecount:06}.{seed}.01.png'
else:
filename = f'{basecount:06}.{seed}.png'
return os.path.join(self.outdir,filename)
else:
basename = os.path.basename(previouspath)
x = re.match('^(\d+)\..*\.png',basename)
if not x:
return self.unique_filename(seed,previouspath)
basecount = int(x.groups()[0])
series = 0
finished = False
while not finished:
series += 1
filename = f'{basecount:06}.{seed}.png'
if self.batch_size>1 or os.path.exists(os.path.join(self.outdir,filename)):
filename = f'{basecount:06}.{seed}.{series:02}.png'
finished = not os.path.exists(os.path.join(self.outdir,filename))
return os.path.join(self.outdir,filename)
def save_image_and_prompt_to_png(self,image,prompt,path):
info = PngImagePlugin.PngInfo()
info.add_text("Dream",prompt)
image.save(path,"PNG",pnginfo=info)
class PromptFormatter():
def __init__(self,t2i,opt):
self.t2i = t2i
self.opt = opt
def normalize_prompt(self):
'''Normalize the prompt and switches'''
t2i = self.t2i
opt = self.opt
switches = list()
switches.append(f'"{opt.prompt}"')
switches.append(f'-s{opt.steps or t2i.steps}')
switches.append(f'-b{opt.batch_size or t2i.batch_size}')
switches.append(f'-W{opt.width or t2i.width}')
switches.append(f'-H{opt.height or t2i.height}')
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
switches.append(f'-m{t2i.sampler_name}')
if opt.variants:
switches.append(f'-v{opt.variants}')
if opt.init_img:
switches.append(f'-I{opt.init_img}')
if opt.strength and opt.init_img is not None:
switches.append(f'-f{opt.strength or t2i.strength}')
if t2i.full_precision:
switches.append('-F')
return ' '.join(switches)
# ---------------readline utilities---------------------
try:
import readline
@ -92,88 +178,3 @@ if readline_available:
pass
atexit.register(readline.write_history_file,histfile)
# -------------------image generation utils-----
class PngWriter:
def __init__(self,outdir,opt,prompt):
self.outdir = outdir
self.opt = opt
self.prompt = prompt
self.filepath = None
self.files_written = []
def write_image(self,image,seed):
self.filepath = self.unique_filename(seed,self.filepath) # will increment name in some sensible way
try:
prompt = f'{self.prompt} -S{seed}'
self.save_image_and_prompt_to_png(image,prompt,self.filepath)
except IOError as e:
print(e)
self.files_written.append([self.filepath,seed])
def unique_filename(self,seed,previouspath):
revision = 1
if previouspath is None:
# sort reverse alphabetically until we find max+1
dirlist = sorted(os.listdir(self.outdir),reverse=True)
# find the first filename that matches our pattern or return 000000.0.png
filename = next((f for f in dirlist if re.match('^(\d+)\..*\.png',f)),'0000000.0.png')
basecount = int(filename.split('.',1)[0])
basecount += 1
if self.opt.batch_size > 1:
filename = f'{basecount:06}.{seed}.01.png'
else:
filename = f'{basecount:06}.{seed}.png'
return os.path.join(self.outdir,filename)
else:
basename = os.path.basename(previouspath)
x = re.match('^(\d+)\..*\.png',basename)
if not x:
return self.unique_filename(seed,previouspath)
basecount = int(x.groups()[0])
series = 0
finished = False
while not finished:
series += 1
filename = f'{basecount:06}.{seed}.png'
if self.opt.batch_size>1 or os.path.exists(os.path.join(self.outdir,filename)):
filename = f'{basecount:06}.{seed}.{series:02}.png'
finished = not os.path.exists(os.path.join(self.outdir,filename))
return os.path.join(self.outdir,filename)
def save_image_and_prompt_to_png(self,image,prompt,path):
info = PngImagePlugin.PngInfo()
info.add_text("Dream",prompt)
image.save(path,"PNG",pnginfo=info)
class PromptFormatter():
def __init__(self,t2i,opt):
self.t2i = t2i
self.opt = opt
def normalize_prompt(self):
'''Normalize the prompt and switches'''
t2i = self.t2i
opt = self.opt
switches = list()
switches.append(f'"{opt.prompt}"')
switches.append(f'-s{opt.steps or t2i.steps}')
switches.append(f'-b{opt.batch_size or t2i.batch_size}')
switches.append(f'-W{opt.width or t2i.width}')
switches.append(f'-H{opt.height or t2i.height}')
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
switches.append(f'-m{t2i.sampler_name}')
if opt.variants:
switches.append(f'-v{opt.variants}')
if opt.init_img:
switches.append(f'-I{opt.init_img}')
if opt.strength and opt.init_img is not None:
switches.append(f'-f{opt.strength or t2i.strength}')
if t2i.full_precision:
switches.append('-F')
return ' '.join(switches)

View File

@ -4,52 +4,6 @@
# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
"""Simplified text to image API for stable diffusion/latent diffusion
Example Usage:
from ldm.simplet2i import T2I
# Create an object with default values
t2i = T2I(outdir = <path> // outputs/txt2img-samples
model = <path> // models/ldm/stable-diffusion-v1/model.ckpt
config = <path> // default="configs/stable-diffusion/v1-inference.yaml
iterations = <integer> // how many times to run the sampling (1)
batch_size = <integer> // how many images to generate per sampling (1)
steps = <integer> // 50
seed = <integer> // current system time
sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
grid = <boolean> // false
width = <integer> // image width, multiple of 64 (512)
height = <integer> // image height, multiple of 64 (512)
cfg_scale = <float> // unconditional guidance scale (7.5)
)
# do the slow model initialization
t2i.load_model()
# Do the fast inference & image generation. Any options passed here
# override the default values assigned during class initialization
# Will call load_model() if the model was not previously loaded.
# The method returns a list of images. Each row of the list is a sub-list of [filename,seed]
results = t2i.txt2img(prompt = "an astronaut riding a horse"
outdir = "./outputs/txt2img-samples)
)
for row in results:
print(f'filename={row[0]}')
print(f'seed ={row[1]}')
# Same thing, but using an initial image.
results = t2i.img2img(prompt = "an astronaut riding a horse"
outdir = "./outputs/img2img-samples"
init_img = "./sketches/horse+rider.png")
for row in results:
print(f'filename={row[0]}')
print(f'seed ={row[1]}')
"""
import torch
import numpy as np
import random
@ -64,6 +18,7 @@ from torchvision.utils import make_grid
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
import transformers
import time
import math
import re
@ -73,6 +28,69 @@ from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ksampler import KSampler
from ldm.dream_util import PngWriter
"""Simplified text to image API for stable diffusion/latent diffusion
Example Usage:
from ldm.simplet2i import T2I
# Create an object with default values
t2i = T2I(model = <path> // models/ldm/stable-diffusion-v1/model.ckpt
config = <path> // configs/stable-diffusion/v1-inference.yaml
iterations = <integer> // how many times to run the sampling (1)
batch_size = <integer> // how many images to generate per sampling (1)
steps = <integer> // 50
seed = <integer> // current system time
sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
grid = <boolean> // false
width = <integer> // image width, multiple of 64 (512)
height = <integer> // image height, multiple of 64 (512)
cfg_scale = <float> // unconditional guidance scale (7.5)
)
# do the slow model initialization
t2i.load_model()
# Do the fast inference & image generation. Any options passed here
# override the default values assigned during class initialization
# Will call load_model() if the model was not previously loaded and so
# may be slow at first.
# The method returns a list of images. Each row of the list is a sub-list of [filename,seed]
results = t2i.prompt2png(prompt = "an astronaut riding a horse",
outdir = "./outputs/samples",
iterations = 3)
for row in results:
print(f'filename={row[0]}')
print(f'seed ={row[1]}')
# Same thing, but using an initial image.
results = t2i.prompt2png(prompt = "an astronaut riding a horse",
outdir = "./outputs/,
iterations = 3,
init_img = "./sketches/horse+rider.png")
for row in results:
print(f'filename={row[0]}')
print(f'seed ={row[1]}')
# Same thing, but we return a series of Image objects, which lets you manipulate them,
# combine them, and save them under arbitrary names
results = t2i.prompt2image(prompt = "an astronaut riding a horse"
outdir = "./outputs/")
for row in results:
im = row[0]
seed = row[1]
im.save(f'./outputs/samples/an_astronaut_riding_a_horse-{seed}.png')
im.thumbnail(100,100).save('./outputs/samples/astronaut_thumb.jpg')
Note that the old txt2img() and img2img() calls are deprecated but will
still work.
"""
class T2I:
"""T2I class
@ -141,7 +159,30 @@ The vast majority of these arguments default to reasonable values.
self.seed = self._new_seed()
else:
self.seed = seed
transformers.logging.set_verbosity_error()
def prompt2png(self,prompt,outdir,**kwargs):
'''
Takes a prompt and an output directory, writes out the requested number
of PNG files, and returns an array of [[filename,seed],[filename,seed]...]
Optional named arguments are the same as those passed to T2I and prompt2image()
'''
results = self.prompt2image(prompt,**kwargs)
pngwriter = PngWriter(outdir,prompt,kwargs.get('batch_size',self.batch_size))
for r in results:
metadata_str = f'prompt2png("{prompt}" {kwargs} seed={r[1]}' # gets written into the PNG
pngwriter.write_image(r[0],r[1])
return pngwriter.files_written
def txt2img(self,prompt,**kwargs):
outdir = kwargs.get('outdir','outputs/img-samples')
return self.prompt2png(prompt,outdir,**kwargs)
def img2img(self,prompt,**kwargs):
outdir = kwargs.get('outdir','outputs/img-samples')
assert 'init_img' in kwargs,'call to img2img() must include the init_img argument'
return self.prompt2png(prompt,outdir,**kwargs)
def prompt2image(self,
# these are common
prompt,
@ -161,7 +202,34 @@ The vast majority of these arguments default to reasonable values.
strength=None,
variants=None,
**args): # eat up additional cruft
'''ldm.prompt2image() is the common entry point for txt2img() and img2img()'''
'''
ldm.prompt2image() is the common entry point for txt2img() and img2img()
It takes the following arguments:
prompt // prompt string (no default)
iterations // iterations (1); image count=iterations x batch_size
batch_size // images per iteration (1)
steps // refinement steps per iteration
seed // seed for random number generator
width // width of image, in multiples of 64 (512)
height // height of image, in multiples of 64 (512)
cfg_scale // how strongly the prompt influences the image (7.5) (must be >1)
init_img // path to an initial image - its dimensions override width and height
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
callback // a function or method that will be called each time an image is generated
To use the callback, define a function of method that receives two arguments, an Image object
and the seed. You can then do whatever you like with the image, including converting it to
different formats and manipulating it. For example:
def process_image(image,seed):
image.save(f{'images/seed.png'})
The callback used by the prompt2png() can be found in ldm/dream_util.py. It contains code
to create the requested output directory, select a unique informative name for each image, and
write the prompt into the PNG metadata.
'''
steps = steps or self.steps
seed = seed or self.seed
width = width or self.width
@ -175,6 +243,12 @@ The vast majority of these arguments default to reasonable values.
model = self.load_model() # will instantiate the model or return it from cache
assert cfg_scale>1.0, "CFG_Scale (-C) must be >1.0"
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
w = int(width/64) * 64
h = int(height/64) * 64
if h != height or w != width:
print(f'Height and width must be multiples of 64. Resizing to {h}x{w}')
height = h
width = w
data = [batch_size * [prompt]]
scope = autocast if self.precision=="autocast" else nullcontext
@ -303,8 +377,7 @@ The vast majority of these arguments default to reasonable values.
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
t_enc = int(strength * steps)
print(f"target t_enc is {t_enc} steps")
# print(f"target t_enc is {t_enc} steps")
images = list()
try:
@ -408,8 +481,8 @@ The vast majority of these arguments default to reasonable values.
def _load_model_from_config(self, config, ckpt):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
# if "global_step" in pl_sd:
# print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)

View File

@ -153,25 +153,58 @@ def main_loop(t2i,outdir,parser,log,infile):
continue
normalized_prompt = PromptFormatter(t2i,opt).normalize_prompt()
variants = None
try:
file_writer = PngWriter(outdir,opt,normalized_prompt)
file_writer = PngWriter(outdir,normalized_prompt,opt.batch_size)
callback = file_writer.write_image
t2i.prompt2image(image_callback=callback,
**vars(opt))
results = file_writer.files_written
if None not in (opt.variants,opt.init_img):
variants = generate_variants(t2i,outdir,opt,results)
except AssertionError as e:
print(e)
continue
print("Outputs:")
write_log_message(t2i,normalized_prompt,results,log)
if variants is not None:
print('Variants:')
for vr in variants:
write_log_message(t2i,vr[0],vr[1],log)
print("goodbye!")
def generate_variants(t2i,outdir,opt,previous_gens):
variants = []
print(f"Generating {opt.variants} variant(s)...")
newopt = copy.deepcopy(opt)
newopt.iterations = 1
newopt.variants = None
for r in previous_gens:
newopt.init_img = r[0]
prompt = PromptFormatter(t2i,newopt).normalize_prompt()
print(f"] generating variant for {newopt.init_img}")
for j in range(0,opt.variants):
try:
file_writer = PngWriter(outdir,prompt,newopt.batch_size)
callback = file_writer.write_image
t2i.prompt2image(image_callback=callback,**vars(newopt))
results = file_writer.files_written
variants.append([prompt,results])
except AssertionError as e:
print(e)
continue
print(f'{opt.variants} variants generated')
return variants
def write_log_message(t2i,prompt,results,logfile):
''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata '''
''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata'''
last_seed = None
img_num = 1
seenit = {}