added customized patches and updated the README

This commit is contained in:
Lincoln Stein 2022-08-16 21:34:37 -04:00
parent d39f5b51a8
commit d6124c44a3
8 changed files with 505 additions and 10 deletions

View File

@ -1,4 +1,95 @@
# Stable Diffusion
This is a fork of CompVis/stable-diffusion, the wonderful open source
text-to-image generator.
The original has been modified in several minor ways:
## Simplified API for text to image generation
There is now a simplified API for text to image generation, which
lets you create images from a prompt in just three lines of code:
~~~~
from ldm.simplet2i import T2I
model = T2I()
model.text2image("a unicorn in manhattan")
~~~~
Please see ldm/simplet2i.py for more information.
## Interactive command-line interface similar to the Discord bot
There is now a command-line script, located in scripts/dream.py, which
provides an interactive interface to image generation similar to
the "dream mothership" bot that Stable AI provided on its Discord
server. The advantage of this is that the lengthy model
initialization only happens once. After that image generation is
fast.
Note that this has only been tested in the Linux environment!
(ldm) ~/stable-diffusion$ ./scripts/dream.py
* Initializing, be patient...
Loading model from models/ldm/text2img-large/model.ckpt
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 872.30 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
Loading Bert tokenizer from "models/bert"
setting sampler to plms
* Initialization done! Awaiting your command...
dream> ashley judd riding a camel -n2
Outputs:
outputs/txt2img-samples/00009.png: "ashley judd riding a camel" -n2 -S 416354203
outputs/txt2img-samples/00010.png: "ashley judd riding a camel" -n2 -S 1362479620
Command-line arguments ("./scripts/dream.py -h") allow you to change
various defaults, and select between the mature stable-diffusion
weights (512x512) and the older (256x256) latent diffusion weights
(laion400m).
## No need for internet connectivity when loading the model
My development machine is a GPU node in a high-performance compute
cluster which has no connection to the internet. During model
initialization, stable-diffusion tries to download the Bert tokenizer
model from huggingface.co. This obviously didn't work for me.
Rather than set up a hugging face local hub, I found the most
expedient thing to do was to download the Bert tokenizer in advance,
and patch stable-diffusion to read it from the local disk. The steps
to do this are:
(ldm) ~/stable-diffusion$ mkdir ./models/bert
> python3
>>> from transformers import BertTokenizerFast
>>> model = BertTokenizerFast.from_pretrained("bert-base-uncased")
>>> model.save_pretrained("./models/bert")
(Make sure you are in the stable-diffusion directory when you do
this!)
If you don't like this change, just copy over the file
ldm/modules/encoders/modules.py from the CompVis/stable-diffusion
repository.
## Minor fixes
I added the requirement for torchmetrics to environment.yaml.
## Installation and support
Follow the directions from the original README, which starts below, to
configure the environment and install requirements. For support,
please use this repository's GitHub Issues tracking service.
Author: Lincoln D. Stein <lincoln.stein@gmail.com>
# Original README from CompViz/stable-diffusion
*Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.ai/) and [Runway](https://runwayml.com/) and builds upon our previous work:*
[**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)<br/>

View File

@ -135,7 +135,7 @@ class DDIMSampler(object):
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps, dynamic_ncols=True)
for i, step in enumerate(iterator):
index = total_steps - i - 1
@ -238,4 +238,4 @@ class DDIMSampler(object):
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
return x_dec
return x_dec

View File

@ -255,7 +255,7 @@ class DDPM(pl.LightningModule):
b = shape[0]
img = torch.randn(shape, device=device)
intermediates = [img]
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps, dynamic_ncols=True):
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
clip_denoised=self.clip_denoised)
if i % self.log_every_t == 0 or i == self.num_timesteps - 1:

View File

@ -92,7 +92,7 @@ class PLMSSampler(object):
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for PLMS sampling is {size}')
# print(f'Data shape for PLMS sampling is {size}')
samples, intermediates = self.plms_sampling(conditioning, size,
callback=callback,
@ -134,9 +134,9 @@ class PLMSSampler(object):
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running PLMS Sampling with {total_steps} timesteps")
# print(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps, dynamic_ncols=True)
old_eps = []
for i, step in enumerate(iterator):

View File

@ -55,7 +55,10 @@ class BERTTokenizer(AbstractEncoder):
def __init__(self, device="cuda", vq_interface=True, max_length=77):
super().__init__()
from transformers import BertTokenizerFast # TODO: add to reuquirements
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
fn = 'models/bert'
print(f'Loading Bert tokenizer from "{fn}"')
# self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
self.tokenizer = BertTokenizerFast.from_pretrained(fn,local_files_only=True)
self.device = device
self.vq_interface = vq_interface
self.max_length = max_length
@ -231,4 +234,5 @@ class FrozenClipImageEmbedder(nn.Module):
if __name__ == "__main__":
from ldm.util import count_params
model = FrozenCLIPEmbedder()
count_params(model, verbose=True)
count_params(model, verbose=True)

258
ldm/simplet2i.py Normal file
View File

@ -0,0 +1,258 @@
"""Simplified text to image API for stable diffusion/latent diffusion
Example Usage:
from ldm.simplet2i import T2I
# Create an object with default values
t2i = T2I(outdir = <path> // outputs/txt2img-samples
model = <path> // models/ldm/stable-diffusion-v1/model.ckpt
config = <path> // default="configs/stable-diffusion/v1-inference.yaml
batch = <integer> // 1
steps = <integer> // 50
seed = <integer> // current system time
sampler = ['ddim','plms'] // ddim
grid = <boolean> // false
width = <integer> // image width, multiple of 64 (512)
height = <integer> // image height, multiple of 64 (512)
cfg_scale = <float> // unconditional guidance scale (7.5)
fixed_code = <boolean> // False
)
# do the slow model initialization
t2i.load_model()
# Do the fast inference & image generation. Any options passed here
# override the default values assigned during class initialization
# Will call load_model() if the model was not previously loaded.
t2i.txt2img(prompt = <string> // required
// the remaining option arguments override constructur value when present
outdir = <path>
iterations = <integer>
batch = <integer>
steps = <integer>
seed = <integer>
sampler = ['ddim','plms']
grid = <boolean>
width = <integer>
height = <integer>
cfg_scale = <float>
) -> boolean
"""
import torch
import numpy as np
import random
import sys
import os
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
from time import time
from math import sqrt
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
class T2I:
"""T2I class
Attributes
----------
outdir
model
config
iterations
batch
steps
seed
sampler
grid
width
height
cfg_scale
fixed_code
latent_channels
downsampling_factor
precision
"""
def __init__(self,
outdir="outputs/txt2img-samples",
batch=1,
iterations = 1,
width=256, # change to 512 for stable diffusion
height=256, # change to 512 for stable diffusion
grid=False,
steps=50,
seed=None,
cfg_scale=7.5,
weights="models/ldm/stable-diffusion-v1/model.ckpt",
config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml",
sampler="plms",
latent_channels=4,
downsampling_factor=8,
ddim_eta=0.0, # deterministic
fixed_code=False,
precision='autocast'
):
self.outdir = outdir
self.batch = batch
self.iterations = iterations
self.width = width
self.height = height
self.grid = grid
self.steps = steps
self.cfg_scale = cfg_scale
self.weights = weights
self.config = config
self.sampler_name = sampler
self.fixed_code = fixed_code
self.latent_channels = latent_channels
self.downsampling_factor = downsampling_factor
self.ddim_eta = ddim_eta
self.precision = precision
self.model = None # empty for now
self.sampler = None
if seed is None:
self.seed = self._new_seed()
else:
self.seed = seed
def txt2img(self,prompt,outdir=None,batch=None,iterations=None,
steps=None,seed=None,grid=None,width=None,height=None,
cfg_scale=None,ddim_eta=None):
""" generate an image from the prompt, writing iteration images into the outdir """
outdir = outdir or self.outdir
steps = steps or self.steps
seed = seed or self.seed
width = width or self.width
height = height or self.height
cfg_scale = cfg_scale or self.cfg_scale
ddim_eta = ddim_eta or self.ddim_eta
batch = batch or self.batch
iterations = iterations or self.iterations
if batch > 1:
iterations = 1
model = self.load_model() # will instantiate the model or return it from cache
if (grid is None):
grid = self.grid
data = [batch * [prompt]]
# make directories and establish names for the output files
os.makedirs(outdir, exist_ok=True)
base_count = len(os.listdir(outdir))-1
start_code = None
if self.fixed_code:
start_code = torch.randn([batch,
self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
device=self.device)
precision_scope = autocast if self.precision=="autocast" else nullcontext
sampler = self.sampler
images = list()
seeds = list()
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
all_samples = list()
for n in trange(iterations, desc="Sampling"):
seed_everything(seed)
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
uc = None
if cfg_scale != 1.0:
uc = model.get_learned_conditioning(batch * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
samples_ddim, _ = sampler.sample(S=steps,
conditioning=c,
batch_size=batch,
shape=shape,
verbose=False,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
eta=ddim_eta,
x_T=start_code)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
for x_sample in x_samples_ddim:
if grid:
all_samples.append(x_samples_ddim)
seeds.append(seed)
else:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
filename = os.path.join(outdir, f"{base_count:05}.png")
Image.fromarray(x_sample.astype(np.uint8)).save(filename)
images.append([filename,seed])
base_count += 1
seed = self._new_seed()
if grid:
n_rows = int(sqrt(batch * iterations))
# save as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = make_grid(grid, nrow=n_rows)
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
filename = os.path.join(outdir, f"{base_count:05}.png")
Image.fromarray(grid.astype(np.uint8)).save(filename)
for s in seeds:
images.append([filename,s])
return images
def _new_seed(self):
self.seed = random.randrange(0,np.iinfo(np.uint32).max)
return self.seed
def load_model(self):
""" Load and initialize the model from configuration variables passed at object creation time """
if self.model is None:
seed_everything(self.seed)
try:
config = OmegaConf.load(self.config)
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = self._load_model_from_config(config,self.weights)
self.model = model.to(self.device)
except AttributeError:
raise SystemExit
if self.sampler_name=='plms':
print("setting sampler to plms")
self.sampler = PLMSSampler(self.model)
elif self.sampler_name == 'ddim':
print("setting sampler to ddim")
self.sampler = DDIMSampler(self.model)
else:
print(f"unsupported sampler {self.sampler_name}, defaulting to plms")
self.sampler = PLMSSampler(self.model)
return self.model
def _load_model_from_config(self, config, ckpt):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
model.cuda()
model.eval()
return model

View File

@ -12,8 +12,6 @@ from queue import Queue
from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
# xc a list of captions to plot

144
scripts/dream.py Executable file
View File

@ -0,0 +1,144 @@
#!/usr/bin/env python
import readline
import argparse
import shlex
import atexit
from os import path
def main():
arg_parser = create_argv_parser()
opt = arg_parser.parse_args()
if opt.laion400m:
# defaults suitable to the older latent diffusion weights
width = 256
height = 256
config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
weights = "models/ldm/text2img-large/model.ckpt"
else:
# some defaults suitable for stable diffusion weights
width = 512
height = 512
config = "configs/stable-diffusion/v1-inference.yaml"
weights = "models/ldm/stable-diffusion-v1/model.ckpt"
# command line history will be stored in a file called "~/.dream_history"
load_history()
print("* Initializing, be patient...\n")
from pytorch_lightning import logging
from ldm.simplet2i import T2I
# creating a simple text2image object with a handful of
# defaults passed on the command line.
# additional parameters will be added (or overriden) during
# the user input loop
t2i = T2I(width=width,
height=height,
batch=opt.batch,
outdir=opt.outdir,
sampler=opt.sampler,
weights=weights,
config=config)
# gets rid of annoying messages about random seed
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
# preload the model
t2i.load_model()
print("\n* Initialization done! Awaiting your command...")
log_path = path.join(opt.outdir,"dream_log.txt")
with open(log_path,'a') as log:
cmd_parser = create_cmd_parser()
main_loop(t2i,cmd_parser,log)
log.close()
def main_loop(t2i,parser,log):
while True:
try:
command = input("dream> ")
except EOFError:
print("goodbye!")
break
elements = shlex.split(command)
switches = ['']
switches_started = False
for el in elements:
if el[0]=='-' and not switches_started:
switches_started = True
if switches_started:
switches.append(el)
else:
switches[0] += el
switches[0] += ' '
switches[0] = switches[0][:len(switches[0])-1]
try:
opt = parser.parse_args(switches)
except SystemExit:
parser.print_help()
pass
results = t2i.txt2img(**vars(opt))
print("Outputs:")
for r in results:
log_message = " ".join([' ',str(r[0])+':',
f'"{switches[0]}"',
*switches[1:],f'-S {r[1]}'])
print(log_message)
log.write(log_message+"\n")
log.flush()
def create_argv_parser():
parser = argparse.ArgumentParser(description="Parse script's command line args")
parser.add_argument("--laion400m",
"--latent_diffusion",
"-l",
dest='laion400m',
action='store_true',
help="fallback to the latent diffusion (LAION4400M) weights and config")
parser.add_argument('-n','--iterations',
type=int,
default=1,
help="number of images to produce per sampling (overrides -n<iterations>, faster but doesn't produce individual seeds)")
parser.add_argument('-b','--batch',
type=int,
default=1,
help="number of images to produce per sampling (currently broken")
parser.add_argument('--sampler',
choices=['plms','ddim'],
default='plms',
help="which sampler to use")
parser.add_argument('-o',
'--outdir',
type=str,
default="outputs/txt2img-samples",
help="directory in which to place generated images and a log of prompts and seeds")
return parser
def create_cmd_parser():
parser = argparse.ArgumentParser(description="Parse terminal input in a discord 'dreambot' fashion")
parser.add_argument('prompt')
parser.add_argument('-s','--steps',type=int,help="number of steps")
parser.add_argument('-S','--seed',type=int,help="image seed")
parser.add_argument('-n','--iterations',type=int,default=1,help="number of samplings to perform")
parser.add_argument('-b','--batch',type=int,default=1,help="number of images to produce per sampling (currently broken)")
parser.add_argument('-W','--width',type=int,help="image width, multiple of 64")
parser.add_argument('-H','--height',type=int,help="image height, multiple of 64")
parser.add_argument('-C','--cfg_scale',type=float,help="prompt configuration scale (7.5)")
parser.add_argument('-g','--grid',action='store_true',help="generate a grid")
return parser
def load_history():
histfile = path.join(path.expanduser('~'),".dream_history")
try:
readline.read_history_file(histfile)
readline.set_history_length(1000)
except FileNotFoundError:
pass
atexit.register(readline.write_history_file,histfile)
if __name__ == "__main__":
main()