Enable users to set sampler using prompts

This commit is contained in:
blessedcoolant 2022-08-28 16:05:00 +12:00
parent 38ed6393fa
commit d1551b1bd4
2 changed files with 63 additions and 36 deletions

View File

@ -133,7 +133,8 @@ class T2I:
full_precision=False,
strength=0.75, # default in scripts/img2img.py
embedding_path=None,
latent_diffusion_weights=False, # just to keep track of this parameter when regenerating prompt
# just to keep track of this parameter when regenerating prompt
latent_diffusion_weights=False,
device='cuda',
gfpgan=None,
):
@ -175,7 +176,8 @@ class T2I:
outdir, prompt, kwargs.get('batch_size', self.batch_size)
)
for r in results:
metadata_str = f'prompt2png("{prompt}" {kwargs} seed={r[1]}' # gets written into the PNG
# gets written into the PNG
metadata_str = f'prompt2png("{prompt}" {kwargs} seed={r[1]}'
pngwriter.write_image(r[0], r[1])
return pngwriter.files_written
@ -210,6 +212,7 @@ class T2I:
strength=None,
gfpgan_strength=None,
variants=None,
user_sampler=None,
**args,
): # eat up additional cruft
"""
@ -269,6 +272,10 @@ class T2I:
scope = autocast if self.precision == 'autocast' else nullcontext
if user_sampler and (user_sampler != self.sampler_name):
self.sampler_name = user_sampler
self._set_sampler()
tic = time.time()
results = list()
@ -305,12 +312,15 @@ class T2I:
iter_images = next(images_iterator)
for image in iter_images:
try:
# if gfpgan strength is none or less than or equal to 0.0 then
# if gfpgan strength is none or less than or equal to 0.0 then
# don't even attempt to use GFPGAN.
# if the user specified a value of -G that satisifies the condition and
# if the user specified a value of -G that satisifies the condition and
# --gfpgan wasn't specified, at startup then
# the net result is a message gets printed - nothing else happens.
if gfpgan_strength is not None and gfpgan_strength > 0.0:
if (
gfpgan_strength is not None
and gfpgan_strength > 0.0
):
image = self._run_gfpgan(
image, gfpgan_strength
)
@ -499,39 +509,38 @@ class T2I:
except AttributeError:
raise SystemExit
msg = f'setting sampler to {self.sampler_name}'
if self.sampler_name == 'plms':
self.sampler = PLMSSampler(self.model, device=self.device)
elif self.sampler_name == 'ddim':
self.sampler = DDIMSampler(self.model, device=self.device)
elif self.sampler_name == 'k_dpm_2_a':
self.sampler = KSampler(
self.model, 'dpm_2_ancestral', device=self.device
)
elif self.sampler_name == 'k_dpm_2':
self.sampler = KSampler(
self.model, 'dpm_2', device=self.device
)
elif self.sampler_name == 'k_euler_a':
self.sampler = KSampler(
self.model, 'euler_ancestral', device=self.device
)
elif self.sampler_name == 'k_euler':
self.sampler = KSampler(
self.model, 'euler', device=self.device
)
elif self.sampler_name == 'k_heun':
self.sampler = KSampler(self.model, 'heun', device=self.device)
elif self.sampler_name == 'k_lms':
self.sampler = KSampler(self.model, 'lms', device=self.device)
else:
msg = f'unsupported sampler {self.sampler_name}, defaulting to plms'
self.sampler = PLMSSampler(self.model, device=self.device)
print(msg)
self._set_sampler()
return self.model
def _set_sampler(self):
msg = f'>> Setting Sampler to {self.sampler_name}'
if self.sampler_name == 'plms':
self.sampler = PLMSSampler(self.model, device=self.device)
elif self.sampler_name == 'ddim':
self.sampler = DDIMSampler(self.model, device=self.device)
elif self.sampler_name == 'k_dpm_2_a':
self.sampler = KSampler(
self.model, 'dpm_2_ancestral', device=self.device
)
elif self.sampler_name == 'k_dpm_2':
self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
elif self.sampler_name == 'k_euler_a':
self.sampler = KSampler(
self.model, 'euler_ancestral', device=self.device
)
elif self.sampler_name == 'k_euler':
self.sampler = KSampler(self.model, 'euler', device=self.device)
elif self.sampler_name == 'k_heun':
self.sampler = KSampler(self.model, 'heun', device=self.device)
elif self.sampler_name == 'k_lms':
self.sampler = KSampler(self.model, 'lms', device=self.device)
else:
msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms'
self.sampler = PLMSSampler(self.model, device=self.device)
print(msg)
def _load_model_from_config(self, config, ckpt):
print(f'Loading model from {ckpt}')
pl_sd = torch.load(ckpt, map_location='cpu')

View File

@ -52,7 +52,8 @@ def main():
weights=weights,
full_precision=opt.full_precision,
config=config,
latent_diffusion_weights=opt.laion400m, # this is solely for recreating the prompt
# this is solely for recreating the prompt
latent_diffusion_weights=opt.laion400m,
embedding_path=opt.embedding_path,
device=opt.device,
)
@ -508,6 +509,23 @@ def create_cmd_parser():
action='store_true',
help='skip subprompt weight normalization',
)
parser.add_argument(
'-m',
'--user_sampler',
default=None,
type=str,
choices=[
'ddim',
'k_dpm_2_a',
'k_dpm_2',
'k_euler_a',
'k_euler',
'k_heun',
'k_lms',
'plms',
],
help='Change to another supported sampler using this command',
)
return parser