Merge branch 'hwharrison-main' into main

This enables k_lms sampling (now the default).
Lincoln Stein 2022-08-21 20:17:22 -04:00
commit 3c74dd41c4
7 changed files with 161 additions and 21 deletions

View File

@@ -73,6 +73,14 @@ The --init_img (-I) option gives the path to the seed picture. --strength (-f) c
 the original will be modified, ranging from 0.0 (keep the original intact), to 1.0 (ignore the original
 completely). The default is 0.75, and ranges from 0.25-0.75 give interesting results.
+
+## Changes
+
+- v1.01 (21 August 2022)
+  * added k_lms sampling. **Please run "conda env update -f environment.yaml" to load the k_lms dependencies.**
+  * use half-precision arithmetic by default, resulting in faster execution and lower memory requirements.
+    Pass the argument --full_precision to dream.py to get slower but more accurate image generation.
+
 ## Installation
 ### Linux/Mac
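The half-precision default announced above trades a little accuracy for speed and memory. A minimal, self-contained sketch of the memory half of that trade-off (the tensor shape here is an arbitrary example, not taken from the code):

```python
import torch

# a 4-channel 64x64 latent, the kind of tensor the sampler works on
x = torch.randn(1, 4, 64, 64)            # float32 by default
print(x.element_size() * x.nelement())   # 65536 bytes (4 bytes/element)

x = x.half()                             # what model.half() does to the weights
print(x.element_size() * x.nelement())   # 32768 bytes (2 bytes/element)
```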

View File

@@ -24,6 +24,8 @@ dependencies:
 - transformers==4.19.2
 - torchmetrics==0.6.0
 - kornia==0.6
-- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+- accelerate==0.12.0
 - -e git+https://github.com/openai/CLIP.git@main#egg=clip
+- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+- -e git+https://github.com/lstein/k-diffusion.git@master#egg=k-diffusion
 - -e .
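Since the k_lms path pulls in new dependencies, a quick smoke test after running `conda env update -f environment.yaml` can confirm they resolved. The import names below follow the packages listed above:

```python
# sanity check: these should all import cleanly once the env update finishes
import accelerate
import k_diffusion as K   # from lstein/k-diffusion
import clip               # from openai/CLIP

print(accelerate.__version__)   # expect 0.12.0, per the pin above
```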

View File

@@ -0,0 +1,74 @@
+'''wrapper around part of Katherine Crowson's k-diffusion library, making it call-compatible with other samplers'''
+import k_diffusion as K
+import torch
+import torch.nn as nn
+import accelerate
+
+
+class CFGDenoiser(nn.Module):
+    '''combines a conditioned and an unconditioned prediction for classifier-free guidance'''
+    def __init__(self, model):
+        super().__init__()
+        self.inner_model = model
+
+    def forward(self, x, sigma, uncond, cond, cond_scale):
+        # run the unconditioned and conditioned batches through the model in one pass
+        x_in = torch.cat([x] * 2)
+        sigma_in = torch.cat([sigma] * 2)
+        cond_in = torch.cat([uncond, cond])
+        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+        # classifier-free guidance: push the prediction away from the unconditioned one
+        return uncond + (cond - uncond) * cond_scale
+
+
+class KSampler(object):
+    def __init__(self, model, schedule='lms', **kwargs):
+        super().__init__()
+        self.model = K.external.CompVisDenoiser(model)
+        self.accelerator = accelerate.Accelerator()
+        self.device = self.accelerator.device
+        self.schedule = schedule
+
+    def forward(self, x, sigma, uncond, cond, cond_scale):
+        # same guidance arithmetic as CFGDenoiser, applied to the wrapped model
+        x_in = torch.cat([x] * 2)
+        sigma_in = torch.cat([sigma] * 2)
+        cond_in = torch.cat([uncond, cond])
+        uncond, cond = self.model(x_in, sigma_in, cond=cond_in).chunk(2)
+        return uncond + (cond - uncond) * cond_scale
+
+    # most of these arguments are ignored and are only present for compatibility with
+    # other samplers
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None,
+               # this has to come in the same format as the conditioning,
+               # e.g. as encoded tokens
+               **kwargs
+               ):
+        sigmas = self.model.get_sigmas(S)
+        if x_T is not None:
+            x = x_T
+        else:
+            # start from pure noise scaled to the first (largest) sigma
+            x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
+        model_wrap_cfg = CFGDenoiser(self.model)
+        extra_args = {'cond': conditioning,
+                      'uncond': unconditional_conditioning,
+                      'cond_scale': unconditional_guidance_scale}
+        return (K.sampling.sample_lms(model_wrap_cfg, x, sigmas,
+                                      extra_args=extra_args,
+                                      disable=not self.accelerator.is_main_process),
+                None)
+
+    def gather(self, samples_ddim):
+        return self.accelerator.gather(samples_ddim)
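The guidance arithmetic in CFGDenoiser is easy to sanity-check in isolation. A minimal sketch, assuming a dummy stand-in for the wrapped denoiser (paste it alongside the CFGDenoiser class above):

```python
import torch
import torch.nn as nn

class DummyInner(nn.Module):
    '''stand-in for K.external.CompVisDenoiser: "prediction" = input + conditioning'''
    def forward(self, x, sigma, cond=None):
        return x + cond

x = torch.zeros(2, 4)          # toy latent batch
sigma = torch.ones(2)
uncond = torch.zeros(2, 4)     # dummy unconditioned embedding
cond = torch.ones(2, 4)        # dummy conditioned embedding

d = CFGDenoiser(DummyInner())
out = d(x, sigma, uncond, cond, cond_scale=7.5)
# unconditioned prediction is 0, conditioned is 1, so output = 0 + (1 - 0) * 7.5
print(out)                     # a tensor filled with 7.5
```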

View File

@@ -11,7 +11,7 @@ t2i = T2I(outdir = <path> // outputs/txt2img-samples
 batch_size = <integer> // how many images to generate per sampling (1)
 steps = <integer> // 50
 seed = <integer> // current system time
-sampler = ['ddim','plms'] // ddim
+sampler = ['ddim','plms','klms'] // klms
 grid = <boolean> // false
 width = <integer> // image width, multiple of 64 (512)
 height = <integer> // image height, multiple of 64 (512)
@@ -64,6 +64,7 @@ import math
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
+from ldm.models.diffusion.ksampler import KSampler

 class T2I:
     """T2I class
@@ -101,12 +102,13 @@ class T2I:
         cfg_scale=7.5,
         weights="models/ldm/stable-diffusion-v1/model.ckpt",
         config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml",
-        sampler="plms",
+        sampler="klms",
         latent_channels=4,
         downsampling_factor=8,
         ddim_eta=0.0,  # deterministic
         fixed_code=False,
         precision='autocast',
+        full_precision=False,
         strength=0.75  # default in scripts/img2img.py
     ):
         self.outdir = outdir
@@ -125,6 +127,7 @@ class T2I:
         self.downsampling_factor = downsampling_factor
         self.ddim_eta = ddim_eta
         self.precision = precision
+        self.full_precision = full_precision
         self.strength = strength
         self.model = None  # empty for now
         self.sampler = None
@@ -387,6 +390,9 @@ class T2I:
         elif self.sampler_name == 'ddim':
             print("setting sampler to ddim")
             self.sampler = DDIMSampler(self.model)
+        elif self.sampler_name == 'klms':
+            print("setting sampler to klms")
+            self.sampler = KSampler(self.model, 'lms')
         else:
             print(f"unsupported sampler {self.sampler_name}, defaulting to plms")
             self.sampler = PLMSSampler(self.model)
@@ -403,6 +409,10 @@
         m, u = model.load_state_dict(sd, strict=False)
         model.cuda()
         model.eval()
+        if self.full_precision:
+            print('Using slower but more accurate full-precision math (--full_precision)')
+        else:
+            print('Using half precision math. Call with --full_precision to use full precision')
             model.half()
         return model
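Put together, the new options surface in the T2I constructor. A hedged sketch of instantiating it with the new defaults, using only parameters shown in this diff; the module path ldm.simplet2i is an assumption, since the file name is not visible here:

```python
# assumes the repo is on PYTHONPATH; the module path below is inferred, not shown in the diff
from ldm.simplet2i import T2I

t2i = T2I(
    sampler='klms',         # the new default, handled by the klms branch above
    full_precision=False,   # half precision unless -F/--full_precision is passed
    steps=50,               # documented default, per the header comment above
)
```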

View File

@@ -1,3 +1,4 @@
+#!/usr/bin/env python3
 import argparse
 import shlex
 import atexit
@@ -49,6 +50,7 @@ def main():
         outdir=opt.outdir,
         sampler=opt.sampler,
         weights=weights,
+        full_precision=opt.full_precision,
         config=config)

     # make sure the output directory exists
@@ -165,14 +167,18 @@ def create_argv_parser():
                         type=int,
                         default=1,
                         help="number of images to generate")
+    parser.add_argument('-F', '--full_precision',
+                        dest='full_precision',
+                        action='store_true',
+                        help="use slower full-precision math for calculations")
     parser.add_argument('-b', '--batch_size',
                         type=int,
                         default=1,
                         help="number of images to produce per iteration (currently not working properly - producing too many images)")
     parser.add_argument('--sampler',
-                        choices=['plms', 'ddim'],
-                        default='plms',
-                        help="which sampler to use")
+                        choices=['plms', 'ddim', 'klms'],
+                        default='klms',
+                        help="which sampler to use (default: klms)")
     parser.add_argument('-o',
                         '--outdir',
                         type=str,
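A minimal, self-contained check of how the two new options parse, mirroring the argparse setup above with Python's standard library:

```python
import argparse

# replicates only the two options added/changed in this commit
parser = argparse.ArgumentParser()
parser.add_argument('-F', '--full_precision', dest='full_precision', action='store_true')
parser.add_argument('--sampler', choices=['plms', 'ddim', 'klms'], default='klms')

print(parser.parse_args([]))
# Namespace(full_precision=False, sampler='klms')  <- klms is now the default
print(parser.parse_args(['-F', '--sampler', 'ddim']))
# Namespace(full_precision=True, sampler='ddim')
```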

View File

@@ -1,3 +1,4 @@
+#!/usr/bin/env python3
 # Before running stable-diffusion on an internet-isolated machine,
 # run this script from one with internet connectivity. The
 # two machines must share a common .cache directory.

View File

@@ -12,6 +12,10 @@ from pytorch_lightning import seed_everything
 from torch import autocast
 from contextlib import contextmanager, nullcontext

+import accelerate
+import k_diffusion as K
+import torch.nn as nn
+
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
@@ -80,6 +84,11 @@ def main():
         action='store_true',
         help="use plms sampling",
     )
+    parser.add_argument(
+        "--klms",
+        action='store_true',
+        help="use klms sampling",
+    )
     parser.add_argument(
         "--laion400m",
         action='store_true',
@@ -190,6 +199,22 @@ def main():
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     model = model.to(device)

+    # for klms
+    model_wrap = K.external.CompVisDenoiser(model)
+    accelerator = accelerate.Accelerator()
+    device = accelerator.device
+
+    class CFGDenoiser(nn.Module):
+        def __init__(self, model):
+            super().__init__()
+            self.inner_model = model
+
+        def forward(self, x, sigma, uncond, cond, cond_scale):
+            x_in = torch.cat([x] * 2)
+            sigma_in = torch.cat([sigma] * 2)
+            cond_in = torch.cat([uncond, cond])
+            uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+            return uncond + (cond - uncond) * cond_scale
+
     if opt.plms:
         sampler = PLMSSampler(model)
     else:
@@ -226,8 +251,8 @@ def main():
             with model.ema_scope():
                 tic = time.time()
                 all_samples = list()
-                for n in trange(opt.n_iter, desc="Sampling"):
-                    for prompts in tqdm(data, desc="data"):
+                for n in trange(opt.n_iter, desc="Sampling", disable=not accelerator.is_main_process):
+                    for prompts in tqdm(data, desc="data", disable=not accelerator.is_main_process):
                         uc = None
                         if opt.scale != 1.0:
                             uc = model.get_learned_conditioning(batch_size * [""])
@@ -235,6 +260,8 @@ def main():
                         prompts = list(prompts)
                         c = model.get_learned_conditioning(prompts)
                         shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
+                        if not opt.klms:
                             samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                              conditioning=c,
                                                              batch_size=opt.n_samples,
@@ -244,10 +271,22 @@ def main():
                                                              unconditional_conditioning=uc,
                                                              eta=opt.ddim_eta,
                                                              x_T=start_code)
+                        else:
+                            sigmas = model_wrap.get_sigmas(opt.ddim_steps)
+                            if start_code is not None:
+                                x = start_code
+                            else:
+                                # start from noise scaled to the first (largest) sigma
+                                x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0]
+                            model_wrap_cfg = CFGDenoiser(model_wrap)
+                            extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale}
+                            samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process)

                         x_samples_ddim = model.decode_first_stage(samples_ddim)
                         x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+                        if opt.klms:
+                            # gather results from all accelerate processes
+                            x_samples_ddim = accelerator.gather(x_samples_ddim)

                         if not opt.skip_save:
                             for x_sample in x_samples_ddim:
                                 x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
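The klms start latent above differs from the DDIM path: instead of unit-variance noise, it draws noise scaled to the first (largest) sigma of the schedule. A self-contained illustration with made-up schedule values (the real sigmas come from model_wrap.get_sigmas):

```python
import torch

# illustrative, decreasing noise schedule; real values come from get_sigmas(steps)
sigmas = torch.tensor([14.6, 9.7, 6.1, 3.6, 1.9, 0.8, 0.1, 0.0])

shape = [4, 64, 64]                       # [C, H//f, W//f] latent shape
x = torch.randn([1, *shape]) * sigmas[0]  # start latent handed to sample_lms
print(x.std())                            # close to sigmas[0], i.e. ~14.6
```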