mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'tildebyte-feat-samplers-add-remaining-k' into main
This adds the remaining k_* samplers to the dream.py script.
This commit is contained in:
commit
7f4a5e946d
@ -67,7 +67,7 @@ class KSampler(object):
|
|||||||
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw
|
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw
|
||||||
model_wrap_cfg = CFGDenoiser(self.model)
|
model_wrap_cfg = CFGDenoiser(self.model)
|
||||||
extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}
|
extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}
|
||||||
return (K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
|
return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
|
||||||
None)
|
None)
|
||||||
|
|
||||||
def gather(samples_ddim):
|
def gather(samples_ddim):
|
||||||
|
@ -18,7 +18,7 @@ t2i = T2I(outdir = <path> // outputs/txt2img-samples
|
|||||||
batch_size = <integer> // how many images to generate per sampling (1)
|
batch_size = <integer> // how many images to generate per sampling (1)
|
||||||
steps = <integer> // 50
|
steps = <integer> // 50
|
||||||
seed = <integer> // current system time
|
seed = <integer> // current system time
|
||||||
sampler_name= ['ddim','plms','klms'] // klms
|
sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
|
||||||
grid = <boolean> // false
|
grid = <boolean> // false
|
||||||
width = <integer> // image width, multiple of 64 (512)
|
width = <integer> // image width, multiple of 64 (512)
|
||||||
height = <integer> // image height, multiple of 64 (512)
|
height = <integer> // image height, multiple of 64 (512)
|
||||||
@ -448,19 +448,29 @@ The vast majority of these arguments default to reasonable values.
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise SystemExit
|
raise SystemExit
|
||||||
|
|
||||||
|
msg = f'setting sampler to {self.sampler_name}'
|
||||||
if self.sampler_name=='plms':
|
if self.sampler_name=='plms':
|
||||||
print("setting sampler to plms")
|
|
||||||
self.sampler = PLMSSampler(self.model)
|
self.sampler = PLMSSampler(self.model)
|
||||||
elif self.sampler_name == 'ddim':
|
elif self.sampler_name == 'ddim':
|
||||||
print("setting sampler to ddim")
|
|
||||||
self.sampler = DDIMSampler(self.model)
|
self.sampler = DDIMSampler(self.model)
|
||||||
elif self.sampler_name == 'klms':
|
elif self.sampler_name == 'k_dpm_2_a':
|
||||||
print("setting sampler to klms")
|
self.sampler = KSampler(self.model,'dpm_2_ancestral')
|
||||||
|
elif self.sampler_name == 'k_dpm_2':
|
||||||
|
self.sampler = KSampler(self.model,'dpm_2')
|
||||||
|
elif self.sampler_name == 'k_euler_a':
|
||||||
|
self.sampler = KSampler(self.model,'euler_ancestral')
|
||||||
|
elif self.sampler_name == 'k_euler':
|
||||||
|
self.sampler = KSampler(self.model,'euler')
|
||||||
|
elif self.sampler_name == 'k_heun':
|
||||||
|
self.sampler = KSampler(self.model,'heun')
|
||||||
|
elif self.sampler_name == 'k_lms':
|
||||||
self.sampler = KSampler(self.model,'lms')
|
self.sampler = KSampler(self.model,'lms')
|
||||||
else:
|
else:
|
||||||
print(f"unsupported sampler {self.sampler_name}, defaulting to plms")
|
msg = f'unsupported sampler {self.sampler_name}, defaulting to plms'
|
||||||
self.sampler = PLMSSampler(self.model)
|
self.sampler = PLMSSampler(self.model)
|
||||||
|
|
||||||
|
print(msg)
|
||||||
|
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def _load_model_from_config(self, config, ckpt):
|
def _load_model_from_config(self, config, ckpt):
|
||||||
|
@ -226,6 +226,7 @@ def _reconstruct_switches(t2i,opt):
|
|||||||
switches.append(f'-W{opt.width or t2i.width}')
|
switches.append(f'-W{opt.width or t2i.width}')
|
||||||
switches.append(f'-H{opt.height or t2i.height}')
|
switches.append(f'-H{opt.height or t2i.height}')
|
||||||
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
|
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
|
||||||
|
switches.append(f'-m{t2i.sampler_name}')
|
||||||
if opt.init_img:
|
if opt.init_img:
|
||||||
switches.append(f'-I{opt.init_img}')
|
switches.append(f'-I{opt.init_img}')
|
||||||
if opt.strength and opt.init_img is not None:
|
if opt.strength and opt.init_img is not None:
|
||||||
@ -266,9 +267,9 @@ def create_argv_parser():
|
|||||||
help="number of images to produce per iteration (faster, but doesn't generate individual seeds")
|
help="number of images to produce per iteration (faster, but doesn't generate individual seeds")
|
||||||
parser.add_argument('--sampler','-m',
|
parser.add_argument('--sampler','-m',
|
||||||
dest="sampler_name",
|
dest="sampler_name",
|
||||||
choices=['plms','ddim', 'klms'],
|
choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'],
|
||||||
default='klms',
|
default='k_lms',
|
||||||
help="which sampler to use (klms) - can only be set on command line")
|
help="which sampler to use (k_lms) - can only be set on command line")
|
||||||
parser.add_argument('--outdir',
|
parser.add_argument('--outdir',
|
||||||
'-o',
|
'-o',
|
||||||
type=str,
|
type=str,
|
||||||
|
Loading…
Reference in New Issue
Block a user