diff --git a/environment.yaml b/environment.yaml index 7f25da800a..2ac2596575 100644 --- a/environment.yaml +++ b/environment.yaml @@ -24,6 +24,8 @@ dependencies: - transformers==4.19.2 - torchmetrics==0.6.0 - kornia==0.6 + - accelerate==0.12.0 + - git+https://github.com/crowsonkb/k-diffusion.git@master - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers - -e git+https://github.com/openai/CLIP.git@main#egg=clip - -e . diff --git a/scripts/dream.py b/scripts/dream.py index cc7614980f..322c4e3a82 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -169,7 +169,7 @@ def create_argv_parser(): default=1, help="number of images to produce per iteration (currently not working properly - producing too many images)") parser.add_argument('--sampler', - choices=['plms','ddim'], + choices=['plms','ddim', 'klms'], default='plms', help="which sampler to use") parser.add_argument('-o', diff --git a/scripts/txt2img.py b/scripts/txt2img.py index da77e1a03e..42d5e83496 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -12,6 +12,10 @@ from pytorch_lightning import seed_everything from torch import autocast from contextlib import contextmanager, nullcontext +import accelerate +import k_diffusion as K +import torch.nn as nn + from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler @@ -80,6 +84,11 @@ def main(): action='store_true', help="use plms sampling", ) + parser.add_argument( + "--klms", + action='store_true', + help="use klms sampling", + ) parser.add_argument( "--laion400m", action='store_true', @@ -190,6 +199,22 @@ def main(): device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) + #for klms + model_wrap = K.external.CompVisDenoiser(model) + accelerator = accelerate.Accelerator() + device = accelerator.device + class CFGDenoiser(nn.Module): + def __init__(self, model): + super().__init__() + self.inner_model = model + + def forward(self, x, sigma, uncond, cond, cond_scale): + x_in = torch.cat([x] * 2) + sigma_in = torch.cat([sigma] * 2) + cond_in = torch.cat([uncond, cond]) + uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + return uncond + (cond - uncond) * cond_scale + if opt.plms: sampler = PLMSSampler(model) else: @@ -226,8 +251,8 @@ def main(): with model.ema_scope(): tic = time.time() all_samples = list() - for n in trange(opt.n_iter, desc="Sampling"): - for prompts in tqdm(data, desc="data"): + for n in trange(opt.n_iter, desc="Sampling", disable =not accelerator.is_main_process): + for prompts in tqdm(data, desc="data", disable =not accelerator.is_main_process): uc = None if opt.scale != 1.0: uc = model.get_learned_conditioning(batch_size * [""]) @@ -235,18 +260,32 @@ def main(): prompts = list(prompts) c = model.get_learned_conditioning(prompts) shape = [opt.C, opt.H // opt.f, opt.W // opt.f] - samples_ddim, _ = sampler.sample(S=opt.ddim_steps, - conditioning=c, - batch_size=opt.n_samples, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, - eta=opt.ddim_eta, - x_T=start_code) - + + if not opt.klms: + samples_ddim, _ = sampler.sample(S=opt.ddim_steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code) + else: + sigmas = model_wrap.get_sigmas(opt.ddim_steps) + if start_code: + x = start_code + else: + x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0] # for GPU draw + model_wrap_cfg = CFGDenoiser(model_wrap) + extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale} + samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process) + x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + if opt.klms: + x_sample = accelerator.gather(x_samples_ddim) if not opt.skip_save: for x_sample in x_samples_ddim: