feat(samplers): add ability use all k_* samplers

Signed-off-by: Ben Alkov <ben.alkov@gmail.com>
This commit is contained in:
Ben Alkov
2022-08-23 17:25:39 -04:00
parent a21156e3e3
commit 050dffd269
3 changed files with 20 additions and 10 deletions

View File

@ -67,7 +67,7 @@ class KSampler(object):
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw
model_wrap_cfg = CFGDenoiser(self.model)
extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}
return (K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
None)
def gather(samples_ddim):