diff --git a/README.md b/README.md
index 5215530a9e..131289bab1 100644
--- a/README.md
+++ b/README.md
@@ -73,6 +73,14 @@ The --init_img (-I) option gives the path to the seed picture. --strength (-f) c
 the original will be modified, ranging from 0.0 (keep the original intact), to
 1.0 (ignore the original completely). The default is 0.75, and ranges from
 0.25-0.75 give interesting results.
+## Changes
+
+- v1.01 (21 August 2022)
+  - added k_lms sampling. **Please run `conda env update -f environment.yaml` to load the k_lms dependencies.**
+  - use half-precision arithmetic by default, resulting in faster execution and lower memory requirements.
+    Pass the --full_precision argument to dream.py for slower but more accurate image generation.
+
+
 ## Installation
 
 ### Linux/Mac
diff --git a/environment.yaml b/environment.yaml
index 7f25da800a..0de05e815a 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -24,6 +24,8 @@ dependencies:
   - transformers==4.19.2
   - torchmetrics==0.6.0
   - kornia==0.6
-  - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+  - accelerate==0.12.0
   - -e git+https://github.com/openai/CLIP.git@main#egg=clip
+  - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+  - -e git+https://github.com/lstein/k-diffusion.git@master#egg=k-diffusion
   - -e .
diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py
new file mode 100644
index 0000000000..cc4677f47e
--- /dev/null
+++ b/ldm/models/diffusion/ksampler.py
@@ -0,0 +1,74 @@
+'''Wrapper around part of Katherine Crowson's k-diffusion library, making it call-compatible with the other samplers.'''
+import k_diffusion as K
+import torch
+import torch.nn as nn
+import accelerate
+
+class CFGDenoiser(nn.Module):
+    '''Applies classifier-free guidance: runs the conditional and unconditional passes in one batch and blends them.'''
+    def __init__(self, model):
+        super().__init__()
+        self.inner_model = model
+
+    def forward(self, x, sigma, uncond, cond, cond_scale):
+        # batch the unconditional and conditional passes together
+        x_in = torch.cat([x] * 2)
+        sigma_in = torch.cat([sigma] * 2)
+        cond_in = torch.cat([uncond, cond])
+        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+        # guided prediction: push the result away from the unconditional one by cond_scale
+        return uncond + (cond - uncond) * cond_scale
+
+class KSampler(object):
+    def __init__(self, model, schedule='lms', **kwargs):
+        super().__init__()
+        self.model = K.external.CompVisDenoiser(model)
+        self.accelerator = accelerate.Accelerator()
+        self.device = self.accelerator.device
+        self.schedule = schedule
+
+    def forward(self, x, sigma, uncond, cond, cond_scale):
+        # same guided denoising step as CFGDenoiser, applied to the wrapped model
+        # (was self.inner_model, which does not exist on this class)
+        x_in = torch.cat([x] * 2)
+        sigma_in = torch.cat([sigma] * 2)
+        cond_in = torch.cat([uncond, cond])
+        uncond, cond = self.model(x_in, sigma_in, cond=cond_in).chunk(2)
+        return uncond + (cond - uncond) * cond_scale
+
+    # most of these arguments are ignored and are only present for compatibility with
+    # the other samplers
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None,
+               # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
+               **kwargs
+               ):
+        sigmas = self.model.get_sigmas(S)
+        if x_T is not None:    # 'if x_T:' would raise on a tensor; test against None instead
+            x = x_T
+        else:
+            x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]   # for GPU draw
+        model_wrap_cfg = CFGDenoiser(self.model)
+        extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}
+        return (K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
+                None)
+
+    def gather(self, samples_ddim):    # 'self' was missing; it is required for self.accelerator
+        return self.accelerator.gather(samples_ddim)
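For orientation (not part of the diff): a minimal sketch of how the new KSampler is meant to be driven. `model`, `c`, and `uc` are assumptions here — a loaded Stable Diffusion model and conditioning tensors from `model.get_learned_conditioning()`:

```python
# Hypothetical usage sketch -- not part of the diff. Assumes `model` is a
# loaded Stable Diffusion model and c/uc come from model.get_learned_conditioning().
from ldm.models.diffusion.ksampler import KSampler

sampler = KSampler(model, 'lms')
samples, _ = sampler.sample(
    S=50,                               # number of LMS steps
    batch_size=1,
    shape=[4, 64, 64],                  # latent channels, H//8, W//8 for a 512x512 image
    conditioning=c,                     # prompt embedding
    unconditional_guidance_scale=7.5,   # cfg_scale
    unconditional_conditioning=uc,      # empty-prompt embedding
)
images = model.decode_first_stage(samples)  # latents -> image space
```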
diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py
index 796a99396b..e99660a8ab 100644
--- a/ldm/simplet2i.py
+++ b/ldm/simplet2i.py
@@ -11,7 +11,7 @@ t2i = T2I(outdir = <path> // outputs/txt2img-samples
     batch_size = <integer> // how many images to generate per sampling (1)
     steps = <integer> // 50
     seed = <integer> // current system time
-    sampler = ['ddim','plms'] // ddim
+    sampler = ['ddim','plms','klms'] // klms
     grid = <boolean> // false
     width = <integer> // image width, multiple of 64 (512)
     height = <integer> // image height, multiple of 64 (512)
@@ -62,8 +62,9 @@ import time
 import math
 
 from ldm.util import instantiate_from_config
-from ldm.models.diffusion.ddim import DDIMSampler
-from ldm.models.diffusion.plms import PLMSSampler
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.plms import PLMSSampler
+from ldm.models.diffusion.ksampler import KSampler
 
 class T2I:
     """T2I class
@@ -101,12 +102,13 @@
             cfg_scale=7.5,
             weights="models/ldm/stable-diffusion-v1/model.ckpt",
             config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml",
-            sampler="plms",
+            sampler="klms",
             latent_channels=4,
             downsampling_factor=8,
             ddim_eta=0.0,  # deterministic
             fixed_code=False,
             precision='autocast',
+            full_precision=False,
             strength=0.75 # default in scripts/img2img.py
     ):
         self.outdir = outdir
@@ -125,6 +127,7 @@
         self.downsampling_factor = downsampling_factor
         self.ddim_eta = ddim_eta
         self.precision = precision
+        self.full_precision = full_precision
         self.strength = strength
         self.model = None # empty for now
         self.sampler = None
@@ -387,6 +390,9 @@
         elif self.sampler_name == 'ddim':
             print("setting sampler to ddim")
             self.sampler = DDIMSampler(self.model)
+        elif self.sampler_name == 'klms':
+            print("setting sampler to klms")
+            self.sampler = KSampler(self.model, 'lms')
         else:
             print(f"unsupported sampler {self.sampler_name}, defaulting to plms")
             self.sampler = PLMSSampler(self.model)
@@ -403,7 +409,11 @@
         m, u = model.load_state_dict(sd, strict=False)
         model.cuda()
         model.eval()
-        model.half()
+        if self.full_precision:
+            print('Using slower but more accurate full-precision math (--full_precision)')
+        else:
+            print('Using half precision math. Call with --full_precision to use full precision')
+            model.half()
         return model
 
     def _load_img(self,path):
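Outside the diff, a sketch of how the new T2I options combine. The option names follow the constructor in the hunk above; `txt2img()` is assumed here as the generation entry point, based on the class docstring:

```python
# Hypothetical usage sketch -- not part of the diff. Option names follow the
# T2I constructor above; txt2img() and its prompt argument are assumptions.
from ldm.simplet2i import T2I

t2i = T2I(
    sampler='klms',         # new default; 'ddim' and 'plms' remain available
    full_precision=False,   # default: model.half() for speed and lower VRAM
)
results = t2i.txt2img(prompt='a fantasy landscape')
```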
diff --git a/scripts/dream.py b/scripts/dream.py
index 44b7d9978a..b8abb780fd 100755
--- a/scripts/dream.py
+++ b/scripts/dream.py
@@ -1,3 +1,4 @@
+#!/usr/bin/env python3
 import argparse
 import shlex
 import atexit
@@ -49,6 +50,7 @@ def main():
                   outdir=opt.outdir,
                   sampler=opt.sampler,
                   weights=weights,
+                  full_precision=opt.full_precision,
                   config=config)
 
     # make sure the output directory exists
@@ -165,14 +167,18 @@ def create_argv_parser():
                         type=int,
                         default=1,
                         help="number of images to generate")
+    parser.add_argument('-F','--full_precision',
+                        dest='full_precision',
+                        action='store_true',
+                        help="use slower full-precision math for calculations")
     parser.add_argument('-b','--batch_size',
                         type=int,
                         default=1,
                         help="number of images to produce per iteration (currently not working properly - producing too many images)")
     parser.add_argument('--sampler',
-                        choices=['plms','ddim'],
-                        default='plms',
-                        help="which sampler to use")
+                        choices=['plms','ddim','klms'],
+                        default='klms',
+                        help="which sampler to use (default: klms)")
     parser.add_argument('-o',
                         '--outdir',
                         type=str,
diff --git a/scripts/preload_models.py b/scripts/preload_models.py
index 7db461bec2..ad1a1eecc5 100644
--- a/scripts/preload_models.py
+++ b/scripts/preload_models.py
@@ -1,3 +1,4 @@
+#!/usr/bin/env python3
 # Before running stable-diffusion on an internet-isolated machine,
 # run this script from one with internet connectivity. The
 # two machines must share a common .cache directory.
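The --full_precision flag exists because `model.half()` stores weights and activations as float16. A tiny illustration (not from the diff) of the memory difference:

```python
# Illustrative only: float16 halves per-element storage relative to float32.
import torch

w = torch.randn(1024, 1024)              # float32: 4 bytes per element
print(w.element_size() * w.nelement())   # 4194304 bytes (~4 MiB)
h = w.half()                             # float16: 2 bytes per element
print(h.element_size() * h.nelement())   # 2097152 bytes (~2 MiB)
```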
diff --git a/scripts/txt2img.py b/scripts/txt2img.py
index da77e1a03e..42d5e83496 100644
--- a/scripts/txt2img.py
+++ b/scripts/txt2img.py
@@ -12,6 +12,10 @@ from pytorch_lightning import seed_everything
 from torch import autocast
 from contextlib import contextmanager, nullcontext
 
+import accelerate
+import k_diffusion as K
+import torch.nn as nn
+
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
@@ -80,6 +84,11 @@
         action='store_true',
         help="use plms sampling",
     )
+    parser.add_argument(
+        "--klms",
+        action='store_true',
+        help="use klms sampling",
+    )
     parser.add_argument(
         "--laion400m",
         action='store_true',
@@ -190,6 +199,22 @@
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     model = model.to(device)
 
+    # for klms
+    model_wrap = K.external.CompVisDenoiser(model)
+    accelerator = accelerate.Accelerator()
+    device = accelerator.device
+    class CFGDenoiser(nn.Module):
+        def __init__(self, model):
+            super().__init__()
+            self.inner_model = model
+
+        def forward(self, x, sigma, uncond, cond, cond_scale):
+            x_in = torch.cat([x] * 2)
+            sigma_in = torch.cat([sigma] * 2)
+            cond_in = torch.cat([uncond, cond])
+            uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+            return uncond + (cond - uncond) * cond_scale
+
     if opt.plms:
         sampler = PLMSSampler(model)
     else:
@@ -226,8 +251,8 @@
             with model.ema_scope():
                 tic = time.time()
                 all_samples = list()
-                for n in trange(opt.n_iter, desc="Sampling"):
-                    for prompts in tqdm(data, desc="data"):
+                for n in trange(opt.n_iter, desc="Sampling", disable=not accelerator.is_main_process):
+                    for prompts in tqdm(data, desc="data", disable=not accelerator.is_main_process):
                         uc = None
                         if opt.scale != 1.0:
                             uc = model.get_learned_conditioning(batch_size * [""])
@@ -235,18 +260,32 @@
                             prompts = list(prompts)
                         c = model.get_learned_conditioning(prompts)
                         shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
-                        samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
-                                                         conditioning=c,
-                                                         batch_size=opt.n_samples,
-                                                         shape=shape,
-                                                         verbose=False,
-                                                         unconditional_guidance_scale=opt.scale,
-                                                         unconditional_conditioning=uc,
-                                                         eta=opt.ddim_eta,
-                                                         x_T=start_code)
-
+
+                        if not opt.klms:
+                            samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
+                                                             conditioning=c,
+                                                             batch_size=opt.n_samples,
+                                                             shape=shape,
+                                                             verbose=False,
+                                                             unconditional_guidance_scale=opt.scale,
+                                                             unconditional_conditioning=uc,
+                                                             eta=opt.ddim_eta,
+                                                             x_T=start_code)
+                        else:
+                            sigmas = model_wrap.get_sigmas(opt.ddim_steps)
+                            if start_code is not None:   # 'if start_code:' would raise on a tensor
+                                x = start_code
+                            else:
+                                x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0]  # for GPU draw
+                            model_wrap_cfg = CFGDenoiser(model_wrap)
+                            extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale}
+                            samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process)
+
                         x_samples_ddim = model.decode_first_stage(samples_ddim)
                         x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+
+                        if opt.klms:
+                            x_samples_ddim = accelerator.gather(x_samples_ddim)  # was assigned to 'x_sample', which the loop below immediately overwrote
 
                         if not opt.skip_save:
                             for x_sample in x_samples_ddim:
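Finally, the klms branch above reduces to one k-diffusion call pattern. A self-contained sketch, with a toy denoiser standing in for `K.external.CompVisDenoiser(model)` and illustrative sigma values:

```python
# Self-contained sketch of the k-diffusion LMS call pattern -- illustrative only.
import torch
import torch.nn as nn
import k_diffusion as K

class ToyDenoiser(nn.Module):
    '''Stands in for CompVisDenoiser(model); a real model returns denoised latents.'''
    def forward(self, x, sigma, **kwargs):
        return torch.zeros_like(x)   # toy: "denoises" everything to zero

denoiser = ToyDenoiser()
sigmas = K.sampling.get_sigmas_karras(n=20, sigma_min=0.03, sigma_max=14.6)  # noise schedule
x = torch.randn(1, 4, 64, 64) * sigmas[0]        # start from noise at the largest sigma
samples = K.sampling.sample_lms(denoiser, x, sigmas)
print(samples.shape)                              # torch.Size([1, 4, 64, 64])
```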