Merge branch 'hwharrison-main' into main

This enables k_lms sampling (now the default)`:wq
This commit is contained in:
Lincoln Stein 2022-08-21 20:17:22 -04:00
commit 3c74dd41c4
7 changed files with 161 additions and 21 deletions

View File

@ -73,6 +73,14 @@ The --init_img (-I) option gives the path to the seed picture. --strength (-f) c
the original will be modified, ranging from 0.0 (keep the original intact), to 1.0 (ignore the original
completely). The default is 0.75, and ranges from 0.25-0.75 give interesting results.
## Changes
- v1.01 (21 August 2022)
* added k_lms sampling **Please run "conda update -f environment.yaml" to load the k_lms dependencies**
* use half precision arithmetic by default, resulting in faster execution and lower memory requirements
Pass argument --full_precision to dream.py to get slower but more accurate image generation
## Installation
### Linux/Mac

View File

@ -24,6 +24,8 @@ dependencies:
- transformers==4.19.2
- torchmetrics==0.6.0
- kornia==0.6
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
- accelerate==0.12.0
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
- -e git+https://github.com/lstein/k-diffusion.git@master#egg=k-diffusion
- -e .

View File

@ -0,0 +1,74 @@
'''wrapper around part of Karen Crownson's k-duffsion library, making it call compatible with other Samplers'''
import k_diffusion as K
import torch
import torch.nn as nn
import accelerate
class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
class KSampler(object):
def __init__(self,model,schedule="lms", **kwargs):
super().__init__()
self.model = K.external.CompVisDenoiser(model)
self.accelerator = accelerate.Accelerator()
self.device = self.accelerator.device
self.schedule = schedule
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
# most of these arguments are ignored and are only present for compatibility with
# other samples
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
sigmas = self.model.get_sigmas(S)
if x_T:
x = x_T
else:
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw
model_wrap_cfg = CFGDenoiser(self.model)
extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}
return (K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
None)
def gather(samples_ddim):
return self.accelerator.gather(samples_ddim)

View File

@ -11,7 +11,7 @@ t2i = T2I(outdir = <path> // outputs/txt2img-samples
batch_size = <integer> // how many images to generate per sampling (1)
steps = <integer> // 50
seed = <integer> // current system time
sampler = ['ddim','plms'] // ddim
sampler = ['ddim','plms','klms'] // klms
grid = <boolean> // false
width = <integer> // image width, multiple of 64 (512)
height = <integer> // image height, multiple of 64 (512)
@ -62,8 +62,9 @@ import time
import math
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ksampler import KSampler
class T2I:
"""T2I class
@ -101,12 +102,13 @@ class T2I:
cfg_scale=7.5,
weights="models/ldm/stable-diffusion-v1/model.ckpt",
config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml",
sampler="plms",
sampler="klms",
latent_channels=4,
downsampling_factor=8,
ddim_eta=0.0, # deterministic
fixed_code=False,
precision='autocast',
full_precision=False,
strength=0.75 # default in scripts/img2img.py
):
self.outdir = outdir
@ -125,6 +127,7 @@ class T2I:
self.downsampling_factor = downsampling_factor
self.ddim_eta = ddim_eta
self.precision = precision
self.full_precision = full_precision
self.strength = strength
self.model = None # empty for now
self.sampler = None
@ -387,6 +390,9 @@ class T2I:
elif self.sampler_name == 'ddim':
print("setting sampler to ddim")
self.sampler = DDIMSampler(self.model)
elif self.sampler_name == 'klms':
print("setting sampler to klms")
self.sampler = KSampler(self.model,'lms')
else:
print(f"unsupported sampler {self.sampler_name}, defaulting to plms")
self.sampler = PLMSSampler(self.model)
@ -403,7 +409,11 @@ class T2I:
m, u = model.load_state_dict(sd, strict=False)
model.cuda()
model.eval()
model.half()
if self.full_precision:
print('Using slower but more accurate full-precision math (--full_precision)')
else:
print('Using half precision math. Call with --full_precision to use full precision')
model.half()
return model
def _load_img(self,path):

View File

@ -1,3 +1,4 @@
#!/usr/bin/env python3
import argparse
import shlex
import atexit
@ -49,6 +50,7 @@ def main():
outdir=opt.outdir,
sampler=opt.sampler,
weights=weights,
full_precision=opt.full_precision,
config=config)
# make sure the output directory exists
@ -165,14 +167,18 @@ def create_argv_parser():
type=int,
default=1,
help="number of images to generate")
parser.add_argument('-F','--full_precision',
dest='full_precision',
action='store_true',
help="use slower full precision math for calculations")
parser.add_argument('-b','--batch_size',
type=int,
default=1,
help="number of images to produce per iteration (currently not working properly - producing too many images)")
parser.add_argument('--sampler',
choices=['plms','ddim'],
default='plms',
help="which sampler to use")
choices=['plms','ddim', 'klms'],
default='klms',
help="which sampler to use (klms)")
parser.add_argument('-o',
'--outdir',
type=str,

View File

@ -1,3 +1,4 @@
#!/usr/bin/env python3
# Before running stable-diffusion on an internet-isolated machine,
# run this script from one with internet connectivity. The
# two machines must share a common .cache directory.

View File

@ -12,6 +12,10 @@ from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
import accelerate
import k_diffusion as K
import torch.nn as nn
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
@ -80,6 +84,11 @@ def main():
action='store_true',
help="use plms sampling",
)
parser.add_argument(
"--klms",
action='store_true',
help="use klms sampling",
)
parser.add_argument(
"--laion400m",
action='store_true',
@ -190,6 +199,22 @@ def main():
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
#for klms
model_wrap = K.external.CompVisDenoiser(model)
accelerator = accelerate.Accelerator()
device = accelerator.device
class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
if opt.plms:
sampler = PLMSSampler(model)
else:
@ -226,8 +251,8 @@ def main():
with model.ema_scope():
tic = time.time()
all_samples = list()
for n in trange(opt.n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):
for n in trange(opt.n_iter, desc="Sampling", disable =not accelerator.is_main_process):
for prompts in tqdm(data, desc="data", disable =not accelerator.is_main_process):
uc = None
if opt.scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
@ -235,18 +260,32 @@ def main():
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
conditioning=c,
batch_size=opt.n_samples,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
eta=opt.ddim_eta,
x_T=start_code)
if not opt.klms:
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
conditioning=c,
batch_size=opt.n_samples,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
eta=opt.ddim_eta,
x_T=start_code)
else:
sigmas = model_wrap.get_sigmas(opt.ddim_steps)
if start_code:
x = start_code
else:
x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0] # for GPU draw
model_wrap_cfg = CFGDenoiser(model_wrap)
extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale}
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
if opt.klms:
x_sample = accelerator.gather(x_samples_ddim)
if not opt.skip_save:
for x_sample in x_samples_ddim: