add klms sampling

This commit is contained in:
henry 2022-08-20 22:28:29 -05:00
parent 1ca3dc553c
commit 41f0afbcb6
3 changed files with 54 additions and 13 deletions

View File

@ -24,6 +24,8 @@ dependencies:
- transformers==4.19.2 - transformers==4.19.2
- torchmetrics==0.6.0 - torchmetrics==0.6.0
- kornia==0.6 - kornia==0.6
- accelerate==0.12.0
- git+https://github.com/crowsonkb/k-diffusion.git@master
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
- -e git+https://github.com/openai/CLIP.git@main#egg=clip - -e git+https://github.com/openai/CLIP.git@main#egg=clip
- -e . - -e .

View File

@ -169,7 +169,7 @@ def create_argv_parser():
default=1, default=1,
help="number of images to produce per iteration (currently not working properly - producing too many images)") help="number of images to produce per iteration (currently not working properly - producing too many images)")
parser.add_argument('--sampler', parser.add_argument('--sampler',
choices=['plms','ddim'], choices=['plms','ddim', 'klms'],
default='plms', default='plms',
help="which sampler to use") help="which sampler to use")
parser.add_argument('-o', parser.add_argument('-o',

View File

@ -12,6 +12,10 @@ from pytorch_lightning import seed_everything
from torch import autocast from torch import autocast
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
import accelerate
import k_diffusion as K
import torch.nn as nn
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.plms import PLMSSampler
@ -80,6 +84,11 @@ def main():
action='store_true', action='store_true',
help="use plms sampling", help="use plms sampling",
) )
parser.add_argument(
"--klms",
action='store_true',
help="use klms sampling",
)
parser.add_argument( parser.add_argument(
"--laion400m", "--laion400m",
action='store_true', action='store_true',
@ -190,6 +199,22 @@ def main():
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device) model = model.to(device)
#for klms
model_wrap = K.external.CompVisDenoiser(model)
accelerator = accelerate.Accelerator()
device = accelerator.device
class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
if opt.plms: if opt.plms:
sampler = PLMSSampler(model) sampler = PLMSSampler(model)
else: else:
@ -226,8 +251,8 @@ def main():
with model.ema_scope(): with model.ema_scope():
tic = time.time() tic = time.time()
all_samples = list() all_samples = list()
for n in trange(opt.n_iter, desc="Sampling"): for n in trange(opt.n_iter, desc="Sampling", disable =not accelerator.is_main_process):
for prompts in tqdm(data, desc="data"): for prompts in tqdm(data, desc="data", disable =not accelerator.is_main_process):
uc = None uc = None
if opt.scale != 1.0: if opt.scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""]) uc = model.get_learned_conditioning(batch_size * [""])
@ -235,6 +260,8 @@ def main():
prompts = list(prompts) prompts = list(prompts)
c = model.get_learned_conditioning(prompts) c = model.get_learned_conditioning(prompts)
shape = [opt.C, opt.H // opt.f, opt.W // opt.f] shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
if not opt.klms:
samples_ddim, _ = sampler.sample(S=opt.ddim_steps, samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
conditioning=c, conditioning=c,
batch_size=opt.n_samples, batch_size=opt.n_samples,
@ -244,10 +271,22 @@ def main():
unconditional_conditioning=uc, unconditional_conditioning=uc,
eta=opt.ddim_eta, eta=opt.ddim_eta,
x_T=start_code) x_T=start_code)
else:
sigmas = model_wrap.get_sigmas(opt.ddim_steps)
if start_code:
x = start_code
else:
x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0] # for GPU draw
model_wrap_cfg = CFGDenoiser(model_wrap)
extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale}
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process)
x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
if opt.klms:
x_sample = accelerator.gather(x_samples_ddim)
if not opt.skip_save: if not opt.skip_save:
for x_sample in x_samples_ddim: for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')