diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index cc4677f47e..c48e533410 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -67,7 +67,7 @@ class KSampler(object): x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw model_wrap_cfg = CFGDenoiser(self.model) extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale} - return (K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process), + return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process), None) def gather(samples_ddim): diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index bd53ebb362..c01981afe4 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -18,7 +18,7 @@ t2i = T2I(outdir = // outputs/txt2img-samples batch_size = // how many images to generate per sampling (1) steps = // 50 seed = // current system time - sampler_name= ['ddim','plms','klms'] // klms + sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms grid = // false width = // image width, multiple of 64 (512) height = // image height, multiple of 64 (512) @@ -448,19 +448,29 @@ The vast majority of these arguments default to reasonable values. except AttributeError: raise SystemExit + msg = f'setting sampler to {self.sampler_name}' if self.sampler_name=='plms': - print("setting sampler to plms") self.sampler = PLMSSampler(self.model) elif self.sampler_name == 'ddim': - print("setting sampler to ddim") self.sampler = DDIMSampler(self.model) - elif self.sampler_name == 'klms': - print("setting sampler to klms") + elif self.sampler_name == 'k_dpm_2_a': + self.sampler = KSampler(self.model,'dpm_2_ancestral') + elif self.sampler_name == 'k_dpm_2': + self.sampler = KSampler(self.model,'dpm_2') + elif self.sampler_name == 'k_euler_a': + self.sampler = KSampler(self.model,'euler_ancestral') + elif self.sampler_name == 'k_euler': + self.sampler = KSampler(self.model,'euler') + elif self.sampler_name == 'k_heun': + self.sampler = KSampler(self.model,'heun') + elif self.sampler_name == 'k_lms': self.sampler = KSampler(self.model,'lms') else: - print(f"unsupported sampler {self.sampler_name}, defaulting to plms") + msg = f'unsupported sampler {self.sampler_name}, defaulting to plms' self.sampler = PLMSSampler(self.model) + print(msg) + return self.model def _load_model_from_config(self, config, ckpt): diff --git a/scripts/dream.py b/scripts/dream.py index c57597f32e..b6a29627aa 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -226,6 +226,7 @@ def _reconstruct_switches(t2i,opt): switches.append(f'-W{opt.width or t2i.width}') switches.append(f'-H{opt.height or t2i.height}') switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}') + switches.append(f'-m{t2i.sampler_name}') if opt.init_img: switches.append(f'-I{opt.init_img}') if opt.strength and opt.init_img is not None: @@ -266,9 +267,9 @@ def create_argv_parser(): help="number of images to produce per iteration (faster, but doesn't generate individual seeds") parser.add_argument('--sampler','-m', dest="sampler_name", - choices=['plms','ddim', 'klms'], - default='klms', - help="which sampler to use (klms) - can only be set on command line") + choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'], + default='k_lms', + help="which sampler to use (k_lms) - can only be set on command line") parser.add_argument('--outdir', '-o', type=str,