refactoring complete; please test carefully!

This commit is contained in:
Lincoln Stein 2022-08-25 17:30:08 -04:00
commit 19fa222810
5 changed files with 495 additions and 463 deletions

View File

@ -97,6 +97,27 @@ contributing this code.
![Dream Web Server](static/dream_web_server.png)
## Reading Prompts from a File
You can automate dream.py by providing a text file with the prompts
you want to run, one line per prompt. The text file must be composed
with a text editor (e.g. Notepad) and not a word processor. Each line
should look like what you would type at the dream> prompt:
~~~~
a beautiful sunny day in the park, children playing -n4 -C10
stormy weather on a mountain top, goats grazing -s100
innovative packaging for a squid's dinner -S137038382
~~~~
Then pass this file's name to dream.py when you invoke it:
~~~~
(ldm) ~/stable-diffusion$ python3 scripts/dream.py --from_file="path/to/prompts.txt"
~~~~
>>>>>>> big-refactoring
## Weighted Prompts
You may weight different sections of the prompt to tell the sampler to attach different levels of

View File

@ -2,6 +2,7 @@ Feature requests:
1. "gobig" mode - split image into strips, scale up, add detail using
img2img and reassemble with feathering. Issue #66.
See https://github.com/jquesnelle/txt2imghd
2. Port basujindal low VRAM optimizations. Issue #62

195
ldm/dream_util.py Normal file
View File

@ -0,0 +1,195 @@
'''Utilities for dealing with PNG images and their path names'''
import os
import atexit
import re
from math import sqrt,floor,ceil
from PIL import Image,PngImagePlugin
# -------------------image generation utils-----
class PngWriter:
def __init__(self,outdir,prompt=None,batch_size=1):
self.outdir = outdir
self.batch_size = batch_size
self.prompt = prompt
self.filepath = None
self.files_written = []
os.makedirs(outdir, exist_ok=True)
def write_image(self,image,seed):
self.filepath = self.unique_filename(seed,self.filepath) # will increment name in some sensible way
try:
prompt = f'{self.prompt} -S{seed}'
self.save_image_and_prompt_to_png(image,prompt,self.filepath)
except IOError as e:
print(e)
self.files_written.append([self.filepath,seed])
def unique_filename(self,seed,previouspath=None):
revision = 1
if previouspath is None:
# sort reverse alphabetically until we find max+1
dirlist = sorted(os.listdir(self.outdir),reverse=True)
# find the first filename that matches our pattern or return 000000.0.png
filename = next((f for f in dirlist if re.match('^(\d+)\..*\.png',f)),'0000000.0.png')
basecount = int(filename.split('.',1)[0])
basecount += 1
if self.batch_size > 1:
filename = f'{basecount:06}.{seed}.01.png'
else:
filename = f'{basecount:06}.{seed}.png'
return os.path.join(self.outdir,filename)
else:
basename = os.path.basename(previouspath)
x = re.match('^(\d+)\..*\.png',basename)
if not x:
return self.unique_filename(seed,previouspath)
basecount = int(x.groups()[0])
series = 0
finished = False
while not finished:
series += 1
filename = f'{basecount:06}.{seed}.png'
if self.batch_size>1 or os.path.exists(os.path.join(self.outdir,filename)):
filename = f'{basecount:06}.{seed}.{series:02}.png'
finished = not os.path.exists(os.path.join(self.outdir,filename))
return os.path.join(self.outdir,filename)
def save_image_and_prompt_to_png(self,image,prompt,path):
info = PngImagePlugin.PngInfo()
info.add_text("Dream",prompt)
image.save(path,"PNG",pnginfo=info)
def make_grid(self,image_list,rows=None,cols=None):
image_cnt = len(image_list)
if None in (rows,cols):
rows = floor(sqrt(image_cnt)) # try to make it square
cols = ceil(image_cnt/rows)
width = image_list[0].width
height = image_list[0].height
grid_img = Image.new('RGB',(width*cols,height*rows))
for r in range(0,rows):
for c in range (0,cols):
i = r*rows + c
grid_img.paste(image_list[i],(c*width,r*height))
return grid_img
class PromptFormatter():
def __init__(self,t2i,opt):
self.t2i = t2i
self.opt = opt
def normalize_prompt(self):
'''Normalize the prompt and switches'''
t2i = self.t2i
opt = self.opt
switches = list()
switches.append(f'"{opt.prompt}"')
switches.append(f'-s{opt.steps or t2i.steps}')
switches.append(f'-b{opt.batch_size or t2i.batch_size}')
switches.append(f'-W{opt.width or t2i.width}')
switches.append(f'-H{opt.height or t2i.height}')
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
switches.append(f'-m{t2i.sampler_name}')
if opt.init_img:
switches.append(f'-I{opt.init_img}')
if opt.strength and opt.init_img is not None:
switches.append(f'-f{opt.strength or t2i.strength}')
if t2i.full_precision:
switches.append('-F')
return ' '.join(switches)
# ---------------readline utilities---------------------
try:
import readline
readline_available = True
except:
readline_available = False
class Completer():
def __init__(self,options):
self.options = sorted(options)
return
def complete(self,text,state):
buffer = readline.get_line_buffer()
if text.startswith(('-I','--init_img')):
return self._path_completions(text,state,('.png'))
if buffer.strip().endswith('cd') or text.startswith(('.','/')):
return self._path_completions(text,state,())
response = None
if state == 0:
# This is the first time for this text, so build a match list.
if text:
self.matches = [s
for s in self.options
if s and s.startswith(text)]
else:
self.matches = self.options[:]
# Return the state'th item from the match list,
# if we have that many.
try:
response = self.matches[state]
except IndexError:
response = None
return response
def _path_completions(self,text,state,extensions):
# get the path so far
if text.startswith('-I'):
path = text.replace('-I','',1).lstrip()
elif text.startswith('--init_img='):
path = text.replace('--init_img=','',1).lstrip()
else:
path = text
matches = list()
path = os.path.expanduser(path)
if len(path)==0:
matches.append(text+'./')
else:
dir = os.path.dirname(path)
dir_list = os.listdir(dir)
for n in dir_list:
if n.startswith('.') and len(n)>1:
continue
full_path = os.path.join(dir,n)
if full_path.startswith(path):
if os.path.isdir(full_path):
matches.append(os.path.join(os.path.dirname(text),n)+'/')
elif n.endswith(extensions):
matches.append(os.path.join(os.path.dirname(text),n))
try:
response = matches[state]
except IndexError:
response = None
return response
if readline_available:
readline.set_completer(Completer(['cd','pwd',
'--steps','-s','--seed','-S','--iterations','-n','--batch_size','-b',
'--width','-W','--height','-H','--cfg_scale','-C','--grid','-g',
'--individual','-i','--init_img','-I','--strength','-f','-v','--variants']).complete)
readline.set_completer_delims(" ")
readline.parse_and_bind('tab: complete')
histfile = os.path.join(os.path.expanduser('~'),".dream_history")
try:
readline.read_history_file(histfile)
readline.set_history_length(1000)
except FileNotFoundError:
pass
atexit.register(readline.write_history_file,histfile)

View File

@ -4,53 +4,6 @@
# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
"""Simplified text to image API for stable diffusion/latent diffusion
Example Usage:
from ldm.simplet2i import T2I
# Create an object with default values
t2i = T2I(outdir = <path> // outputs/txt2img-samples
model = <path> // models/ldm/stable-diffusion-v1/model.ckpt
config = <path> // default="configs/stable-diffusion/v1-inference.yaml
iterations = <integer> // how many times to run the sampling (1)
batch_size = <integer> // how many images to generate per sampling (1)
steps = <integer> // 50
seed = <integer> // current system time
sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
grid = <boolean> // false
width = <integer> // image width, multiple of 64 (512)
height = <integer> // image height, multiple of 64 (512)
cfg_scale = <float> // unconditional guidance scale (7.5)
fixed_code = <boolean> // False
)
# do the slow model initialization
t2i.load_model()
# Do the fast inference & image generation. Any options passed here
# override the default values assigned during class initialization
# Will call load_model() if the model was not previously loaded.
# The method returns a list of images. Each row of the list is a sub-list of [filename,seed]
results = t2i.txt2img(prompt = "an astronaut riding a horse"
outdir = "./outputs/txt2img-samples)
)
for row in results:
print(f'filename={row[0]}')
print(f'seed ={row[1]}')
# Same thing, but using an initial image.
results = t2i.img2img(prompt = "an astronaut riding a horse"
outdir = "./outputs/img2img-samples"
init_img = "./sketches/horse+rider.png")
for row in results:
print(f'filename={row[0]}')
print(f'seed ={row[1]}')
"""
import torch
import numpy as np
import random
@ -65,8 +18,8 @@ from torchvision.utils import make_grid
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
import transformers
import time
import math
import re
import traceback
@ -74,12 +27,74 @@ from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ksampler import KSampler
from ldm.dream_util import PngWriter
"""Simplified text to image API for stable diffusion/latent diffusion
Example Usage:
from ldm.simplet2i import T2I
# Create an object with default values
t2i = T2I(model = <path> // models/ldm/stable-diffusion-v1/model.ckpt
config = <path> // configs/stable-diffusion/v1-inference.yaml
iterations = <integer> // how many times to run the sampling (1)
batch_size = <integer> // how many images to generate per sampling (1)
steps = <integer> // 50
seed = <integer> // current system time
sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
grid = <boolean> // false
width = <integer> // image width, multiple of 64 (512)
height = <integer> // image height, multiple of 64 (512)
cfg_scale = <float> // unconditional guidance scale (7.5)
)
# do the slow model initialization
t2i.load_model()
# Do the fast inference & image generation. Any options passed here
# override the default values assigned during class initialization
# Will call load_model() if the model was not previously loaded and so
# may be slow at first.
# The method returns a list of images. Each row of the list is a sub-list of [filename,seed]
results = t2i.prompt2png(prompt = "an astronaut riding a horse",
outdir = "./outputs/samples",
iterations = 3)
for row in results:
print(f'filename={row[0]}')
print(f'seed ={row[1]}')
# Same thing, but using an initial image.
results = t2i.prompt2png(prompt = "an astronaut riding a horse",
outdir = "./outputs/,
iterations = 3,
init_img = "./sketches/horse+rider.png")
for row in results:
print(f'filename={row[0]}')
print(f'seed ={row[1]}')
# Same thing, but we return a series of Image objects, which lets you manipulate them,
# combine them, and save them under arbitrary names
results = t2i.prompt2image(prompt = "an astronaut riding a horse"
outdir = "./outputs/")
for row in results:
im = row[0]
seed = row[1]
im.save(f'./outputs/samples/an_astronaut_riding_a_horse-{seed}.png')
im.thumbnail(100,100).save('./outputs/samples/astronaut_thumb.jpg')
Note that the old txt2img() and img2img() calls are deprecated but will
still work.
"""
class T2I:
"""T2I class
Attributes
----------
outdir
model
config
iterations
@ -87,12 +102,9 @@ class T2I:
steps
seed
sampler_name
grid
individual
width
height
cfg_scale
fixed_code
latent_channels
downsampling_factor
precision
@ -102,23 +114,19 @@ class T2I:
The vast majority of these arguments default to reasonable values.
"""
def __init__(self,
outdir="outputs/txt2img-samples",
batch_size=1,
iterations = 1,
width=512,
height=512,
grid=False,
individual=None, # redundant
steps=50,
seed=None,
cfg_scale=7.5,
weights="models/ldm/stable-diffusion-v1/model.ckpt",
config = "configs/stable-diffusion/v1-inference.yaml",
width=512,
height=512,
sampler_name="klms",
latent_channels=4,
downsampling_factor=8,
ddim_eta=0.0, # deterministic
fixed_code=False,
precision='autocast',
full_precision=False,
strength=0.75, # default in scripts/img2img.py
@ -126,18 +134,15 @@ The vast majority of these arguments default to reasonable values.
latent_diffusion_weights=False, # just to keep track of this parameter when regenerating prompt
device='cuda'
):
self.outdir = outdir
self.batch_size = batch_size
self.iterations = iterations
self.width = width
self.height = height
self.grid = grid
self.steps = steps
self.cfg_scale = cfg_scale
self.weights = weights
self.config = config
self.sampler_name = sampler_name
self.fixed_code = fixed_code
self.latent_channels = latent_channels
self.downsampling_factor = downsampling_factor
self.ddim_eta = ddim_eta
@ -153,17 +158,77 @@ The vast majority of these arguments default to reasonable values.
self.seed = self._new_seed()
else:
self.seed = seed
transformers.logging.set_verbosity_error()
@torch.no_grad()
def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None,
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,init_img=None,
skip_normalize=False,variants=None): # note the "variants" option is an unused hack caused by how options are passed
"""
Generate an image from the prompt, writing iteration images into the outdir
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
"""
outdir = outdir or self.outdir
def prompt2png(self,prompt,outdir,**kwargs):
'''
Takes a prompt and an output directory, writes out the requested number
of PNG files, and returns an array of [[filename,seed],[filename,seed]...]
Optional named arguments are the same as those passed to T2I and prompt2image()
'''
results = self.prompt2image(prompt,**kwargs)
pngwriter = PngWriter(outdir,prompt,kwargs.get('batch_size',self.batch_size))
for r in results:
metadata_str = f'prompt2png("{prompt}" {kwargs} seed={r[1]}' # gets written into the PNG
pngwriter.write_image(r[0],r[1])
return pngwriter.files_written
def txt2img(self,prompt,**kwargs):
outdir = kwargs.get('outdir','outputs/img-samples')
return self.prompt2png(prompt,outdir,**kwargs)
def img2img(self,prompt,**kwargs):
outdir = kwargs.get('outdir','outputs/img-samples')
assert 'init_img' in kwargs,'call to img2img() must include the init_img argument'
return self.prompt2png(prompt,outdir,**kwargs)
def prompt2image(self,
# these are common
prompt,
batch_size=None,
iterations=None,
steps=None,
seed=None,
cfg_scale=None,
ddim_eta=None,
skip_normalize=False,
image_callback=None,
# these are specific to txt2img
width=None,
height=None,
# these are specific to img2img
init_img=None,
strength=None,
variants=None,
**args): # eat up additional cruft
'''
ldm.prompt2image() is the common entry point for txt2img() and img2img()
It takes the following arguments:
prompt // prompt string (no default)
iterations // iterations (1); image count=iterations x batch_size
batch_size // images per iteration (1)
steps // refinement steps per iteration
seed // seed for random number generator
width // width of image, in multiples of 64 (512)
height // height of image, in multiples of 64 (512)
cfg_scale // how strongly the prompt influences the image (7.5) (must be >1)
init_img // path to an initial image - its dimensions override width and height
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
callback // a function or method that will be called each time an image is generated
To use the callback, define a function of method that receives two arguments, an Image object
and the seed. You can then do whatever you like with the image, including converting it to
different formats and manipulating it. For example:
def process_image(image,seed):
image.save(f{'images/seed.png'})
The callback used by the prompt2png() can be found in ldm/dream_util.py. It contains code
to create the requested output directory, select a unique informative name for each image, and
write the prompt into the PNG metadata.
'''
steps = steps or self.steps
seed = seed or self.seed
width = width or self.width
@ -172,52 +237,70 @@ The vast majority of these arguments default to reasonable values.
ddim_eta = ddim_eta or self.ddim_eta
batch_size = batch_size or self.batch_size
iterations = iterations or self.iterations
strength = strength or self.strength # not actually used here, but preserved for code refactoring
embedding_path = embedding_path or self.embedding_path
strength = strength or self.strength
model = self.load_model() # will instantiate the model or return it from cache
assert strength<1.0 and strength>=0.0, "strength (-f) must be >=0.0 and <1.0"
assert cfg_scale>1.0, "CFG_Scale (-C) must be >1.0"
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
w = int(width/64) * 64
h = int(height/64) * 64
if h != height or w != width:
print(f'Height and width must be multiples of 64. Resizing to {h}x{w}')
height = h
width = w
# grid and individual are mutually exclusive, with individual taking priority.
# not necessary, but needed for compatability with dream bot
if (grid is None):
grid = self.grid
if individual:
grid = False
data = [batch_size * [prompt]]
scope = autocast if self.precision=="autocast" else nullcontext
# make directories and establish names for the output files
os.makedirs(outdir, exist_ok=True)
tic = time.time()
if init_img:
assert os.path.exists(init_img),f'{init_img}: File not found'
results = self._img2img(prompt,
data=data,precision_scope=scope,
batch_size=batch_size,iterations=iterations,
steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
skip_normalize=skip_normalize,
init_img=init_img,strength=strength,variants=variants,
callback=image_callback)
else:
results = self._txt2img(prompt,
data=data,precision_scope=scope,
batch_size=batch_size,iterations=iterations,
steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
skip_normalize=skip_normalize,
width=width,height=height,
callback=image_callback)
toc = time.time()
print(f'{len(results)} images generated in',"%4.2fs"% (toc-tic))
return results
@torch.no_grad()
def _txt2img(self,prompt,
data,precision_scope,
batch_size,iterations,
steps,seed,cfg_scale,ddim_eta,
skip_normalize,
width,height,
callback): # the callback is called each time a new Image is generated
"""
Generate an image from the prompt, writing iteration images into the outdir
The output is a list of lists in the format: [[image1,seed1], [image2,seed2],...]
"""
start_code = None
if self.fixed_code:
start_code = torch.randn([batch_size,
self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
device=self.device)
precision_scope = autocast if self.precision=="autocast" else nullcontext
sampler = self.sampler
images = list()
seeds = list()
filename = None
image_count = 0
tic = time.time()
# Gawd. Too many levels of indent here. Need to refactor into smaller routines!
try:
with precision_scope(self.device.type), model.ema_scope():
with precision_scope(self.device.type), self.model.ema_scope():
all_samples = list()
for n in trange(iterations, desc="Sampling"):
seed_everything(seed)
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
uc = None
if cfg_scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
uc = self.model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
@ -233,138 +316,78 @@ The vast majority of these arguments default to reasonable values.
weight = weights[i]
if not skip_normalize:
weight = weight / totalWeight
c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
else: # just standard 1 prompt
c = model.get_learned_conditioning(prompts)
c = self.model.get_learned_conditioning(prompts)
shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
samples_ddim, _ = sampler.sample(S=steps,
conditioning=c,
batch_size=batch_size,
shape=shape,
verbose=False,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
eta=ddim_eta,
x_T=start_code)
conditioning=c,
batch_size=batch_size,
shape=shape,
verbose=False,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
eta=ddim_eta)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = self.model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
if not grid:
for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
filename = self._unique_filename(outdir,previousname=filename,
seed=seed,isbatch=(batch_size>1))
assert not os.path.exists(filename)
Image.fromarray(x_sample.astype(np.uint8)).save(filename)
images.append([filename,seed])
else:
all_samples.append(x_samples_ddim)
seeds.append(seed)
image_count += 1
for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
image = Image.fromarray(x_sample.astype(np.uint8))
images.append([image,seed])
if callback is not None:
callback(image,seed)
seed = self._new_seed()
if grid:
images = self._make_grid(samples=all_samples,
seeds=seeds,
batch_size=batch_size,
iterations=iterations,
outdir=outdir)
except KeyboardInterrupt:
print('*interrupted*')
print('Partial results will be returned; if --grid was requested, nothing will be returned.')
except RuntimeError as e:
print(str(e))
toc = time.time()
print(f'{image_count} images generated in',"%4.2fs"% (toc-tic))
return images
# There is lots of shared code between this and txt2img and should be refactored.
@torch.no_grad()
def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None,
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,
skip_normalize=False,variants=None): # note the "variants" option is an unused hack caused by how options are passed
def _img2img(self,prompt,
data,precision_scope,
batch_size,iterations,
steps,seed,cfg_scale,ddim_eta,
skip_normalize,
init_img,strength,variants,
callback):
"""
Generate an image from the prompt and the initial image, writing iteration images into the outdir
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
The output is a list of lists in the format: [[image,seed1], [image,seed2],...]
"""
outdir = outdir or self.outdir
steps = steps or self.steps
seed = seed or self.seed
cfg_scale = cfg_scale or self.cfg_scale
ddim_eta = ddim_eta or self.ddim_eta
batch_size = batch_size or self.batch_size
iterations = iterations or self.iterations
strength = strength or self.strength
embedding_path = embedding_path or self.embedding_path
assert strength<1.0 and strength>=0.0, "strength (-f) must be >=0.0 and <1.0"
assert cfg_scale>1.0, "CFG_Scale (-C) must be >1.0"
if init_img is None:
print("no init_img provided!")
return []
model = self.load_model() # will instantiate the model or return it from cache
precision_scope = autocast if self.precision=="autocast" else nullcontext
# grid and individual are mutually exclusive, with individual taking priority.
# not necessary, but needed for compatability with dream bot
if (grid is None):
grid = self.grid
if individual:
grid = False
data = [batch_size * [prompt]]
# PLMS sampler not supported yet, so ignore previous sampler
if self.sampler_name!='ddim':
print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler")
sampler = DDIMSampler(model, device=self.device)
sampler = DDIMSampler(self.model, device=self.device)
else:
sampler = self.sampler
# make directories and establish names for the output files
os.makedirs(outdir, exist_ok=True)
assert os.path.isfile(init_img)
init_image = self._load_img(init_img).to(self.device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
with precision_scope(self.device.type):
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image)) # move to latent space
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
try:
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
except AssertionError:
print(f"strength must be between 0.0 and 1.0, but received value {strength}")
return []
t_enc = int(strength * steps)
print(f"target t_enc is {t_enc} steps")
# print(f"target t_enc is {t_enc} steps")
images = list()
seeds = list()
filename = None
image_count = 0 # actual number of iterations performed
tic = time.time()
# Gawd. Too many levels of indent here. Need to refactor into smaller routines!
try:
with precision_scope(self.device.type), model.ema_scope():
with precision_scope(self.device.type), self.model.ema_scope():
all_samples = list()
for n in trange(iterations, desc="Sampling"):
seed_everything(seed)
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
uc = None
if cfg_scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
uc = self.model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
@ -380,9 +403,9 @@ The vast majority of these arguments default to reasonable values.
weight = weights[i]
if not skip_normalize:
weight = weight / totalWeight
c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
else: # just standard 1 prompt
c = model.get_learned_conditioning(prompts)
c = self.model.get_learned_conditioning(prompts)
# encode (scaled latent)
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
@ -390,28 +413,16 @@ The vast majority of these arguments default to reasonable values.
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,)
x_samples = model.decode_first_stage(samples)
x_samples = self.model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
if not grid:
for x_sample in x_samples:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
filename = self._unique_filename(outdir,previousname=filename,
seed=seed,isbatch=(batch_size>1))
assert not os.path.exists(filename)
Image.fromarray(x_sample.astype(np.uint8)).save(filename)
images.append([filename,seed])
else:
all_samples.append(x_samples)
seeds.append(seed)
image_count +=1
for x_sample in x_samples:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
image = Image.fromarray(x_sample.astype(np.uint8))
images.append([image,seed])
if callback is not None:
callback(image,seed)
seed = self._new_seed()
if grid:
images = self._make_grid(samples=all_samples,
seeds=seeds,
batch_size=batch_size,
iterations=iterations,
outdir=outdir)
except KeyboardInterrupt:
print('*interrupted*')
@ -419,26 +430,6 @@ The vast majority of these arguments default to reasonable values.
except RuntimeError as e:
print("Oops! A runtime error has occurred. If this is unexpected, please copy-and-paste this stack trace and post it as an Issue to http://github.com/lstein/stable-diffusion")
traceback.print_exc()
toc = time.time()
print(f'{image_count} images generated in',"%4.2fs"% (toc-tic))
return images
def _make_grid(self,samples,seeds,batch_size,iterations,outdir):
images = list()
n_rows = batch_size if batch_size>1 else int(math.sqrt(batch_size * iterations))
# save as grid
grid = torch.stack(samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = make_grid(grid, nrow=n_rows)
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
filename = self._unique_filename(outdir,seed=seeds[0],grid_count=batch_size*iterations)
Image.fromarray(grid.astype(np.uint8)).save(filename)
for s in seeds:
images.append([filename,s])
return images
def _new_seed(self):
@ -489,8 +480,8 @@ The vast majority of these arguments default to reasonable values.
def _load_model_from_config(self, config, ckpt):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
# if "global_step" in pl_sd:
# print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
@ -514,43 +505,6 @@ The vast majority of these arguments default to reasonable values.
image = torch.from_numpy(image)
return 2.*image - 1.
def _unique_filename(self,outdir,previousname=None,seed=0,isbatch=False,grid_count=None):
revision = 1
if previousname is None:
# sort reverse alphabetically until we find max+1
dirlist = sorted(os.listdir(outdir),reverse=True)
# find the first filename that matches our pattern or return 000000.0.png
filename = next((f for f in dirlist if re.match('^(\d+)\..*\.png',f)),'0000000.0.png')
basecount = int(filename.split('.',1)[0])
basecount += 1
if grid_count is not None:
grid_label = f'grid#1-{grid_count}'
filename = f'{basecount:06}.{seed}.{grid_label}.png'
elif isbatch:
filename = f'{basecount:06}.{seed}.01.png'
else:
filename = f'{basecount:06}.{seed}.png'
return os.path.join(outdir,filename)
else:
previousname = os.path.basename(previousname)
x = re.match('^(\d+)\..*\.png',previousname)
if not x:
return self._unique_filename(outdir,previousname,seed)
basecount = int(x.groups()[0])
series = 0
finished = False
while not finished:
series += 1
filename = f'{basecount:06}.{seed}.png'
if isbatch or os.path.exists(os.path.join(outdir,filename)):
filename = f'{basecount:06}.{seed}.{series:02}.png'
finished = not os.path.exists(os.path.join(outdir,filename))
return os.path.join(outdir,filename)
def _split_weighted_subprompts(text):
"""
grabs all text up to the first occurrence of ':'

View File

@ -3,18 +3,10 @@
import argparse
import shlex
import atexit
import os
import sys
import copy
from PIL import Image,PngImagePlugin
# readline unavailable on windows systems
try:
import readline
readline_available = True
except:
readline_available = False
from ldm.dream_util import Completer,PngWriter,PromptFormatter
debugging = False
@ -35,10 +27,6 @@ def main():
config = "configs/stable-diffusion/v1-inference.yaml"
weights = "models/ldm/stable-diffusion-v1/model.ckpt"
# command line history will be stored in a file called "~/.dream_history"
if readline_available:
setup_readline()
print("* Initializing, be patient...\n")
sys.path.append('.')
from pytorch_lightning import logging
@ -54,8 +42,6 @@ def main():
# the user input loop
t2i = T2I(width=width,
height=height,
batch_size=opt.batch_size,
outdir=opt.outdir,
sampler_name=opt.sampler_name,
weights=weights,
full_precision=opt.full_precision,
@ -87,13 +73,13 @@ def main():
log_path = os.path.join(opt.outdir,'dream_log.txt')
with open(log_path,'a') as log:
cmd_parser = create_cmd_parser()
main_loop(t2i,cmd_parser,log,infile)
main_loop(t2i,opt.outdir,cmd_parser,log,infile)
log.close()
if infile:
infile.close()
def main_loop(t2i,parser,log,infile):
def main_loop(t2i,outdir,parser,log,infile):
''' prompt/read/execute loop '''
done = False
@ -131,13 +117,13 @@ def main_loop(t2i,parser,log,infile):
if elements[0]=='cd' and len(elements)>1:
if os.path.exists(elements[1]):
print(f"setting image output directory to {elements[1]}")
t2i.outdir=elements[1]
outdir=elements[1]
else:
print(f"directory {elements[1]} does not exist")
continue
if elements[0]=='pwd':
print(f"current output directory is {t2i.outdir}")
print(f"current output directory is {outdir}")
continue
if elements[0].startswith('!dream'): # in case a stored prompt still contains the !dream command
@ -166,117 +152,77 @@ def main_loop(t2i,parser,log,infile):
print("Try again with a prompt!")
continue
normalized_prompt = PromptFormatter(t2i,opt).normalize_prompt()
individual_images = not opt.grid
try:
if opt.init_img is None:
results = t2i.txt2img(**vars(opt))
else:
assert os.path.exists(opt.init_img),f"No file found at {opt.init_img}. On Linux systems, pressing <tab> after -I will autocomplete a list of possible image files."
if None not in (opt.width,opt.height):
print('Warning: width and height options are ignored when modifying an init image')
results = t2i.img2img(**vars(opt))
file_writer = PngWriter(outdir,normalized_prompt,opt.batch_size)
callback = file_writer.write_image if individual_images else None
image_list = t2i.prompt2image(image_callback=callback,**vars(opt))
results = file_writer.files_written if individual_images else image_list
if opt.grid and len(results) > 0:
grid_img = file_writer.make_grid([r[0] for r in results])
filename = file_writer.unique_filename(results[0][1])
seeds = [a[1] for a in results]
results = [[filename,seeds]]
metadata_prompt = f'{normalized_prompt} -S{results[0][1]}'
file_writer.save_image_and_prompt_to_png(grid_img,metadata_prompt,filename)
except AssertionError as e:
print(e)
continue
allVariantResults = []
if opt.variants is not None:
print(f"Generating {opt.variants} variant(s)...")
newopt = copy.deepcopy(opt)
newopt.iterations = 1
newopt.variants = None
for r in results:
newopt.init_img = r[0]
print(f"\t generating variant for {newopt.init_img}")
for j in range(0, opt.variants):
try:
variantResults = t2i.img2img(**vars(newopt))
allVariantResults.append([newopt,variantResults])
except AssertionError as e:
print(e)
continue
print(f"{opt.variants} Variants generated!")
except OSError as e:
print(e)
continue
print("Outputs:")
write_log_message(t2i,opt,results,log)
if allVariantResults:
print("Variant outputs:")
for vr in allVariantResults:
write_log_message(t2i,vr[0],vr[1],log)
write_log_message(t2i,normalized_prompt,results,log)
print("goodbye!")
# variant generation is going to be superseded by a generalized
# "prompt-morph" functionality
# def generate_variants(t2i,outdir,opt,previous_gens):
# variants = []
# print(f"Generating {opt.variants} variant(s)...")
# newopt = copy.deepcopy(opt)
# newopt.iterations = 1
# newopt.variants = None
# for r in previous_gens:
# newopt.init_img = r[0]
# prompt = PromptFormatter(t2i,newopt).normalize_prompt()
# print(f"] generating variant for {newopt.init_img}")
# for j in range(0,opt.variants):
# try:
# file_writer = PngWriter(outdir,prompt,newopt.batch_size)
# callback = file_writer.write_image
# t2i.prompt2image(image_callback=callback,**vars(newopt))
# results = file_writer.files_written
# variants.append([prompt,results])
# except AssertionError as e:
# print(e)
# continue
# print(f'{opt.variants} variants generated')
# return variants
def write_log_message(t2i,opt,results,logfile):
''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata '''
switches = _reconstruct_switches(t2i,opt)
prompt_str = ' '.join(switches)
# when multiple images are produced in batch, then we keep track of where each starts
def write_log_message(t2i,prompt,results,logfile):
''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata'''
last_seed = None
img_num = 1
batch_size = opt.batch_size or t2i.batch_size
seenit = {}
seeds = [a[1] for a in results]
if batch_size > 1:
seeds = f"(seeds for each batch row: {seeds})"
else:
seeds = f"(seeds for individual images: {seeds})"
for r in results:
seed = r[1]
log_message = (f'{r[0]}: {prompt_str} -S{seed}')
log_message = (f'{r[0]}: {prompt} -S{seed}')
if batch_size > 1:
if seed != last_seed:
img_num = 1
log_message += f' # (batch image {img_num} of {batch_size})'
else:
img_num += 1
log_message += f' # (batch image {img_num} of {batch_size})'
last_seed = seed
print(log_message)
logfile.write(log_message+"\n")
logfile.flush()
if r[0] not in seenit:
seenit[r[0]] = True
try:
if opt.grid:
_write_prompt_to_png(r[0],f'{prompt_str} -g -S{seed} {seeds}')
else:
_write_prompt_to_png(r[0],f'{prompt_str} -S{seed}')
except FileNotFoundError:
print(f"Could not open file '{r[0]}' for reading")
def _reconstruct_switches(t2i,opt):
'''Normalize the prompt and switches'''
switches = list()
switches.append(f'"{opt.prompt}"')
switches.append(f'-s{opt.steps or t2i.steps}')
switches.append(f'-b{opt.batch_size or t2i.batch_size}')
switches.append(f'-W{opt.width or t2i.width}')
switches.append(f'-H{opt.height or t2i.height}')
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
switches.append(f'-m{t2i.sampler_name}')
if opt.variants:
switches.append(f'-v{opt.variants}')
if opt.init_img:
switches.append(f'-I{opt.init_img}')
if opt.strength and opt.init_img is not None:
switches.append(f'-f{opt.strength or t2i.strength}')
if t2i.full_precision:
switches.append('-F')
return switches
def _write_prompt_to_png(path,prompt):
info = PngImagePlugin.PngInfo()
info.add_text("Dream",prompt)
im = Image.open(path)
im.save(path,"PNG",pnginfo=info)
def create_argv_parser():
parser = argparse.ArgumentParser(description="Parse script's command line args")
parser.add_argument("--laion400m",
@ -297,10 +243,6 @@ def create_argv_parser():
dest='full_precision',
action='store_true',
help="use slower full precision math for calculations")
parser.add_argument('-b','--batch_size',
type=int,
default=1,
help="number of images to produce per iteration (faster, but doesn't generate individual seeds")
parser.add_argument('--sampler','-m',
dest="sampler_name",
choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'],
@ -336,93 +278,12 @@ def create_cmd_parser():
parser.add_argument('-i','--individual',action='store_true',help="generate individual files (default)")
parser.add_argument('-I','--init_img',type=str,help="path to input image for img2img mode (supersedes width and height)")
parser.add_argument('-f','--strength',default=0.75,type=float,help="strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely")
parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants")
# variants is going to be superseded by a generalized "prompt-morph" function
# parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants")
parser.add_argument('-x','--skip_normalize',action='store_true',help="skip subprompt weight normalization")
return parser
if readline_available:
def setup_readline():
readline.set_completer(Completer(['cd','pwd',
'--steps','-s','--seed','-S','--iterations','-n','--batch_size','-b',
'--width','-W','--height','-H','--cfg_scale','-C','--grid','-g',
'--individual','-i','--init_img','-I','--strength','-f','-v','--variants']).complete)
readline.set_completer_delims(" ")
readline.parse_and_bind('tab: complete')
load_history()
def load_history():
histfile = os.path.join(os.path.expanduser('~'),".dream_history")
try:
readline.read_history_file(histfile)
readline.set_history_length(1000)
except FileNotFoundError:
pass
atexit.register(readline.write_history_file,histfile)
class Completer():
def __init__(self,options):
self.options = sorted(options)
return
def complete(self,text,state):
buffer = readline.get_line_buffer()
if text.startswith(('-I','--init_img')):
return self._path_completions(text,state,('.png'))
if buffer.strip().endswith('cd') or text.startswith(('.','/')):
return self._path_completions(text,state,())
response = None
if state == 0:
# This is the first time for this text, so build a match list.
if text:
self.matches = [s
for s in self.options
if s and s.startswith(text)]
else:
self.matches = self.options[:]
# Return the state'th item from the match list,
# if we have that many.
try:
response = self.matches[state]
except IndexError:
response = None
return response
def _path_completions(self,text,state,extensions):
# get the path so far
if text.startswith('-I'):
path = text.replace('-I','',1).lstrip()
elif text.startswith('--init_img='):
path = text.replace('--init_img=','',1).lstrip()
else:
path = text
matches = list()
path = os.path.expanduser(path)
if len(path)==0:
matches.append(text+'./')
else:
dir = os.path.dirname(path)
dir_list = os.listdir(dir)
for n in dir_list:
if n.startswith('.') and len(n)>1:
continue
full_path = os.path.join(dir,n)
if full_path.startswith(path):
if os.path.isdir(full_path):
matches.append(os.path.join(os.path.dirname(text),n)+'/')
elif n.endswith(extensions):
matches.append(os.path.join(os.path.dirname(text),n))
try:
response = matches[state]
except IndexError:
response = None
return response
if __name__ == "__main__":
main()