feat(samplers): add ability use all k_* samplers

Signed-off-by: Ben Alkov <ben.alkov@gmail.com>
This commit is contained in:
Ben Alkov 2022-08-23 17:25:39 -04:00
parent a21156e3e3
commit 050dffd269
3 changed files with 20 additions and 10 deletions

View File

@ -67,7 +67,7 @@ class KSampler(object):
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw
model_wrap_cfg = CFGDenoiser(self.model)
extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}
return (K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
None)
def gather(samples_ddim):

View File

@ -11,7 +11,7 @@ t2i = T2I(outdir = <path> // outputs/txt2img-samples
batch_size = <integer> // how many images to generate per sampling (1)
steps = <integer> // 50
seed = <integer> // current system time
sampler_name= ['ddim','plms','klms'] // klms
sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
grid = <boolean> // false
width = <integer> // image width, multiple of 64 (512)
height = <integer> // image height, multiple of 64 (512)
@ -435,19 +435,29 @@ The vast majority of these arguments default to reasonable values.
except AttributeError:
raise SystemExit
msg = f'setting sampler to {self.sampler_name}'
if self.sampler_name=='plms':
print("setting sampler to plms")
self.sampler = PLMSSampler(self.model)
elif self.sampler_name == 'ddim':
print("setting sampler to ddim")
self.sampler = DDIMSampler(self.model)
elif self.sampler_name == 'klms':
print("setting sampler to klms")
elif self.sampler_name == 'k_dpm_2_a':
self.sampler = KSampler(self.model,'dpm_2_ancestral')
elif self.sampler_name == 'k_dpm_2':
self.sampler = KSampler(self.model,'dpm_2')
elif self.sampler_name == 'k_euler_a':
self.sampler = KSampler(self.model,'euler_ancestral')
elif self.sampler_name == 'k_euler':
self.sampler = KSampler(self.model,'euler')
elif self.sampler_name == 'k_heun':
self.sampler = KSampler(self.model,'heun')
elif self.sampler_name == 'k_lms':
self.sampler = KSampler(self.model,'lms')
else:
print(f"unsupported sampler {self.sampler_name}, defaulting to plms")
msg = f'unsupported sampler {self.sampler_name}, defaulting to plms'
self.sampler = PLMSSampler(self.model)
print(msg)
return self.model
def _load_model_from_config(self, config, ckpt):

View File

@ -260,9 +260,9 @@ def create_argv_parser():
help="number of images to produce per iteration (faster, but doesn't generate individual seeds")
parser.add_argument('--sampler','-m',
dest="sampler_name",
choices=['plms','ddim', 'klms'],
default='klms',
help="which sampler to use (klms) - can only be set on command line")
choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'],
default='k_lms',
help="which sampler to use (k_lms) - can only be set on command line")
parser.add_argument('--outdir',
'-o',
type=str,