From 050dffd269b9dff3899bebcc779dd78b87106303 Mon Sep 17 00:00:00 2001 From: Ben Alkov Date: Tue, 23 Aug 2022 17:25:39 -0400 Subject: [PATCH 1/2] feat(samplers): add ability use all k_* samplers Signed-off-by: Ben Alkov --- ldm/models/diffusion/ksampler.py | 2 +- ldm/simplet2i.py | 22 ++++++++++++++++------ scripts/dream.py | 6 +++--- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index cc4677f47e..c48e533410 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -67,7 +67,7 @@ class KSampler(object): x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw model_wrap_cfg = CFGDenoiser(self.model) extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale} - return (K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process), + return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process), None) def gather(samples_ddim): diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index c7f6263816..f142413b1f 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -11,7 +11,7 @@ t2i = T2I(outdir = // outputs/txt2img-samples batch_size = // how many images to generate per sampling (1) steps = // 50 seed = // current system time - sampler_name= ['ddim','plms','klms'] // klms + sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms grid = // false width = // image width, multiple of 64 (512) height = // image height, multiple of 64 (512) @@ -435,19 +435,29 @@ The vast majority of these arguments default to reasonable values. except AttributeError: raise SystemExit + msg = f'setting sampler to {self.sampler_name}' if self.sampler_name=='plms': - print("setting sampler to plms") self.sampler = PLMSSampler(self.model) elif self.sampler_name == 'ddim': - print("setting sampler to ddim") self.sampler = DDIMSampler(self.model) - elif self.sampler_name == 'klms': - print("setting sampler to klms") + elif self.sampler_name == 'k_dpm_2_a': + self.sampler = KSampler(self.model,'dpm_2_ancestral') + elif self.sampler_name == 'k_dpm_2': + self.sampler = KSampler(self.model,'dpm_2') + elif self.sampler_name == 'k_euler_a': + self.sampler = KSampler(self.model,'euler_ancestral') + elif self.sampler_name == 'k_euler': + self.sampler = KSampler(self.model,'euler') + elif self.sampler_name == 'k_heun': + self.sampler = KSampler(self.model,'heun') + elif self.sampler_name == 'k_lms': self.sampler = KSampler(self.model,'lms') else: - print(f"unsupported sampler {self.sampler_name}, defaulting to plms") + msg = f'unsupported sampler {self.sampler_name}, defaulting to plms' self.sampler = PLMSSampler(self.model) + print(msg) + return self.model def _load_model_from_config(self, config, ckpt): diff --git a/scripts/dream.py b/scripts/dream.py index fb8fec2384..0b7f9abd70 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -260,9 +260,9 @@ def create_argv_parser(): help="number of images to produce per iteration (faster, but doesn't generate individual seeds") parser.add_argument('--sampler','-m', dest="sampler_name", - choices=['plms','ddim', 'klms'], - default='klms', - help="which sampler to use (klms) - can only be set on command line") + choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'], + default='k_lms', + help="which sampler to use (k_lms) - can only be set on command line") parser.add_argument('--outdir', '-o', type=str, From 4bc64a6affb07732c8d5939a6dab7f16504a002b Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 24 Aug 2022 11:18:51 -0400 Subject: [PATCH 2/2] sampler now written to PNG metadata --- scripts/dream.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/dream.py b/scripts/dream.py index 0b7f9abd70..704635ccdf 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -220,6 +220,7 @@ def _reconstruct_switches(t2i,opt): switches.append(f'-W{opt.width or t2i.width}') switches.append(f'-H{opt.height or t2i.height}') switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}') + switches.append(f'-m{t2i.sampler_name}') if opt.init_img: switches.append(f'-I{opt.init_img}') if opt.strength and opt.init_img is not None: