Added support for changing the sampler on the dream> prompt line

This commit is contained in:
Lincoln Stein 2022-08-28 19:03:38 -04:00
commit 1f8bc9482a
4 changed files with 64 additions and 34 deletions

View File

@ -344,6 +344,7 @@ repository and associated paper for details and limitations.
enabled if the GFPGAN directory is located as a sibling to Stable Diffusion. enabled if the GFPGAN directory is located as a sibling to Stable Diffusion.
VRAM requirements are modestly reduced. Thanks to both [Blessedcoolant](https://github.com/blessedcoolant) and VRAM requirements are modestly reduced. Thanks to both [Blessedcoolant](https://github.com/blessedcoolant) and
[Oceanswave](https://github.com/oceanswave) for their work on this. [Oceanswave](https://github.com/oceanswave) for their work on this.
- You can now swap samplers on the dream> command line (kudos to [Blessedcoolant](https://github.com/blessedcoolant))
- v1.11 (26 August 2022) - v1.11 (26 August 2022)
- NEW FEATURE: Support upscaling and face enhancement using the GFPGAN module. (kudos to [Oceanswave](https://github.com/Oceanswave) - NEW FEATURE: Support upscaling and face enhancement using the GFPGAN module. (kudos to [Oceanswave](https://github.com/Oceanswave)

View File

@ -106,6 +106,8 @@ class PromptFormatter:
self.t2i = t2i self.t2i = t2i
self.opt = opt self.opt = opt
# note: the t2i object should provide all these values;
# there should be no need to 'or' against the opt values.
def normalize_prompt(self): def normalize_prompt(self):
"""Normalize the prompt and switches""" """Normalize the prompt and switches"""
t2i = self.t2i t2i = self.t2i
@ -118,13 +120,15 @@ class PromptFormatter:
switches.append(f'-W{opt.width or t2i.width}') switches.append(f'-W{opt.width or t2i.width}')
switches.append(f'-H{opt.height or t2i.height}') switches.append(f'-H{opt.height or t2i.height}')
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}') switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
switches.append(f'-m{t2i.sampler_name}') switches.append(f'-A{opt.sampler_name or t2i.sampler_name}')
if opt.init_img: if opt.init_img:
switches.append(f'-I{opt.init_img}') switches.append(f'-I{opt.init_img}')
if opt.strength and opt.init_img is not None: if opt.strength and opt.init_img is not None:
switches.append(f'-f{opt.strength or t2i.strength}') switches.append(f'-f{opt.strength or t2i.strength}')
if opt.gfpgan_strength: if opt.gfpgan_strength:
switches.append(f'-G{opt.gfpgan_strength}') switches.append(f'-G{opt.gfpgan_strength}')
if opt.upscale:
switches.append(f'-U {" ".join([str(u) for u in opt.upscale])}')
if t2i.full_precision: if t2i.full_precision:
switches.append('-F') switches.append('-F')
return ' '.join(switches) return ' '.join(switches)

View File

@ -212,6 +212,7 @@ class T2I:
save_original=False, save_original=False,
upscale=None, upscale=None,
variants=None, variants=None,
user_sampler=None,
**args, **args,
): # eat up additional cruft ): # eat up additional cruft
""" """
@ -271,6 +272,10 @@ class T2I:
scope = autocast if self.precision == 'autocast' else nullcontext scope = autocast if self.precision == 'autocast' else nullcontext
if user_sampler and (user_sampler != self.sampler_name):
self.sampler_name = user_sampler
self._set_sampler()
tic = time.time() tic = time.time()
torch.cuda.torch.cuda.reset_peak_memory_stats() torch.cuda.torch.cuda.reset_peak_memory_stats()
results = list() results = list()
@ -537,39 +542,38 @@ class T2I:
except AttributeError: except AttributeError:
raise SystemExit raise SystemExit
msg = f'setting sampler to {self.sampler_name}' self._set_sampler()
if self.sampler_name == 'plms':
self.sampler = PLMSSampler(self.model, device=self.device)
elif self.sampler_name == 'ddim':
self.sampler = DDIMSampler(self.model, device=self.device)
elif self.sampler_name == 'k_dpm_2_a':
self.sampler = KSampler(
self.model, 'dpm_2_ancestral', device=self.device
)
elif self.sampler_name == 'k_dpm_2':
self.sampler = KSampler(
self.model, 'dpm_2', device=self.device
)
elif self.sampler_name == 'k_euler_a':
self.sampler = KSampler(
self.model, 'euler_ancestral', device=self.device
)
elif self.sampler_name == 'k_euler':
self.sampler = KSampler(
self.model, 'euler', device=self.device
)
elif self.sampler_name == 'k_heun':
self.sampler = KSampler(self.model, 'heun', device=self.device)
elif self.sampler_name == 'k_lms':
self.sampler = KSampler(self.model, 'lms', device=self.device)
else:
msg = f'unsupported sampler {self.sampler_name}, defaulting to plms'
self.sampler = PLMSSampler(self.model, device=self.device)
print(msg)
return self.model return self.model
def _set_sampler(self):
    """Instantiate ``self.sampler`` from ``self.sampler_name``.

    Called once at model-load time and again whenever the user switches
    samplers from the dream> prompt (see the ``user_sampler`` handling
    in ``prompt2image``). Unrecognized names fall back to plms with a
    warning printed to stdout.
    """
    # The six k-diffusion variants differ only in the schedule name
    # passed to KSampler, so dispatch them through one table instead of
    # six near-identical elif branches.
    k_schedules = {
        'k_dpm_2_a': 'dpm_2_ancestral',
        'k_dpm_2': 'dpm_2',
        'k_euler_a': 'euler_ancestral',
        'k_euler': 'euler',
        'k_heun': 'heun',
        'k_lms': 'lms',
    }
    msg = f'>> Setting Sampler to {self.sampler_name}'
    if self.sampler_name == 'plms':
        self.sampler = PLMSSampler(self.model, device=self.device)
    elif self.sampler_name == 'ddim':
        self.sampler = DDIMSampler(self.model, device=self.device)
    elif self.sampler_name in k_schedules:
        self.sampler = KSampler(
            self.model, k_schedules[self.sampler_name], device=self.device
        )
    else:
        # Keep the warning text identical to the original so log
        # scrapers and users see the same message.
        msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms'
        self.sampler = PLMSSampler(self.model, device=self.device)
    print(msg)
def _load_model_from_config(self, config, ckpt): def _load_model_from_config(self, config, ckpt):
print(f'Loading model from {ckpt}') print(f'Loading model from {ckpt}')
pl_sd = torch.load(ckpt, map_location='cpu') pl_sd = torch.load(ckpt, map_location='cpu')

View File

@ -51,7 +51,8 @@ def main():
weights=weights, weights=weights,
full_precision=opt.full_precision, full_precision=opt.full_precision,
config=config, config=config,
latent_diffusion_weights=opt.laion400m, # this is solely for recreating the prompt # this is solely for recreating the prompt
latent_diffusion_weights=opt.laion400m,
embedding_path=opt.embedding_path, embedding_path=opt.embedding_path,
device=opt.device, device=opt.device,
) )
@ -281,8 +282,9 @@ def create_argv_parser():
help='use slower full precision math for calculations', help='use slower full precision math for calculations',
) )
parser.add_argument( parser.add_argument(
'--sampler', '-A',
'-m', '-m',
'--sampler',
dest='sampler_name', dest='sampler_name',
choices=[ choices=[
'ddim', 'ddim',
@ -295,7 +297,7 @@ def create_argv_parser():
'plms', 'plms',
], ],
default='k_lms', default='k_lms',
help='which sampler to use (k_lms) - can only be set on command line', help='which sampler to use (k_lms)',
) )
parser.add_argument( parser.add_argument(
'--outdir', '--outdir',
@ -447,6 +449,25 @@ def create_cmd_parser():
action='store_true', action='store_true',
help='skip subprompt weight normalization', help='skip subprompt weight normalization',
) )
parser.add_argument(
'-A',
'-m',
'--sampler',
dest='sampler_name',
default=None,
type=str,
choices=[
'ddim',
'k_dpm_2_a',
'k_dpm_2',
'k_euler_a',
'k_euler',
'k_heun',
'k_lms',
'plms',
],
help='Change to another supported sampler using this command',
)
return parser return parser