mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(samplers): add ability use all k_* samplers
Signed-off-by: Ben Alkov <ben.alkov@gmail.com>
This commit is contained in:
parent
a21156e3e3
commit
050dffd269
@ -67,7 +67,7 @@ class KSampler(object):
|
||||
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw
|
||||
model_wrap_cfg = CFGDenoiser(self.model)
|
||||
extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}
|
||||
return (K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
|
||||
return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
|
||||
None)
|
||||
|
||||
def gather(samples_ddim):
|
||||
|
@ -11,7 +11,7 @@ t2i = T2I(outdir = <path> // outputs/txt2img-samples
|
||||
batch_size = <integer> // how many images to generate per sampling (1)
|
||||
steps = <integer> // 50
|
||||
seed = <integer> // current system time
|
||||
sampler_name= ['ddim','plms','klms'] // klms
|
||||
sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
|
||||
grid = <boolean> // false
|
||||
width = <integer> // image width, multiple of 64 (512)
|
||||
height = <integer> // image height, multiple of 64 (512)
|
||||
@ -435,19 +435,29 @@ The vast majority of these arguments default to reasonable values.
|
||||
except AttributeError:
|
||||
raise SystemExit
|
||||
|
||||
msg = f'setting sampler to {self.sampler_name}'
|
||||
if self.sampler_name=='plms':
|
||||
print("setting sampler to plms")
|
||||
self.sampler = PLMSSampler(self.model)
|
||||
elif self.sampler_name == 'ddim':
|
||||
print("setting sampler to ddim")
|
||||
self.sampler = DDIMSampler(self.model)
|
||||
elif self.sampler_name == 'klms':
|
||||
print("setting sampler to klms")
|
||||
elif self.sampler_name == 'k_dpm_2_a':
|
||||
self.sampler = KSampler(self.model,'dpm_2_ancestral')
|
||||
elif self.sampler_name == 'k_dpm_2':
|
||||
self.sampler = KSampler(self.model,'dpm_2')
|
||||
elif self.sampler_name == 'k_euler_a':
|
||||
self.sampler = KSampler(self.model,'euler_ancestral')
|
||||
elif self.sampler_name == 'k_euler':
|
||||
self.sampler = KSampler(self.model,'euler')
|
||||
elif self.sampler_name == 'k_heun':
|
||||
self.sampler = KSampler(self.model,'heun')
|
||||
elif self.sampler_name == 'k_lms':
|
||||
self.sampler = KSampler(self.model,'lms')
|
||||
else:
|
||||
print(f"unsupported sampler {self.sampler_name}, defaulting to plms")
|
||||
msg = f'unsupported sampler {self.sampler_name}, defaulting to plms'
|
||||
self.sampler = PLMSSampler(self.model)
|
||||
|
||||
print(msg)
|
||||
|
||||
return self.model
|
||||
|
||||
def _load_model_from_config(self, config, ckpt):
|
||||
|
@ -260,9 +260,9 @@ def create_argv_parser():
|
||||
help="number of images to produce per iteration (faster, but doesn't generate individual seeds")
|
||||
parser.add_argument('--sampler','-m',
|
||||
dest="sampler_name",
|
||||
choices=['plms','ddim', 'klms'],
|
||||
default='klms',
|
||||
help="which sampler to use (klms) - can only be set on command line")
|
||||
choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'],
|
||||
default='k_lms',
|
||||
help="which sampler to use (k_lms) - can only be set on command line")
|
||||
parser.add_argument('--outdir',
|
||||
'-o',
|
||||
type=str,
|
||||
|
Loading…
Reference in New Issue
Block a user