mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Enable users to set sampler using prompts
This commit is contained in:
parent
38ed6393fa
commit
d1551b1bd4
@ -133,7 +133,8 @@ class T2I:
|
|||||||
full_precision=False,
|
full_precision=False,
|
||||||
strength=0.75, # default in scripts/img2img.py
|
strength=0.75, # default in scripts/img2img.py
|
||||||
embedding_path=None,
|
embedding_path=None,
|
||||||
latent_diffusion_weights=False, # just to keep track of this parameter when regenerating prompt
|
# just to keep track of this parameter when regenerating prompt
|
||||||
|
latent_diffusion_weights=False,
|
||||||
device='cuda',
|
device='cuda',
|
||||||
gfpgan=None,
|
gfpgan=None,
|
||||||
):
|
):
|
||||||
@ -175,7 +176,8 @@ class T2I:
|
|||||||
outdir, prompt, kwargs.get('batch_size', self.batch_size)
|
outdir, prompt, kwargs.get('batch_size', self.batch_size)
|
||||||
)
|
)
|
||||||
for r in results:
|
for r in results:
|
||||||
metadata_str = f'prompt2png("{prompt}" {kwargs} seed={r[1]}' # gets written into the PNG
|
# gets written into the PNG
|
||||||
|
metadata_str = f'prompt2png("{prompt}" {kwargs} seed={r[1]}'
|
||||||
pngwriter.write_image(r[0], r[1])
|
pngwriter.write_image(r[0], r[1])
|
||||||
return pngwriter.files_written
|
return pngwriter.files_written
|
||||||
|
|
||||||
@ -210,6 +212,7 @@ class T2I:
|
|||||||
strength=None,
|
strength=None,
|
||||||
gfpgan_strength=None,
|
gfpgan_strength=None,
|
||||||
variants=None,
|
variants=None,
|
||||||
|
user_sampler=None,
|
||||||
**args,
|
**args,
|
||||||
): # eat up additional cruft
|
): # eat up additional cruft
|
||||||
"""
|
"""
|
||||||
@ -269,6 +272,10 @@ class T2I:
|
|||||||
|
|
||||||
scope = autocast if self.precision == 'autocast' else nullcontext
|
scope = autocast if self.precision == 'autocast' else nullcontext
|
||||||
|
|
||||||
|
if user_sampler and (user_sampler != self.sampler_name):
|
||||||
|
self.sampler_name = user_sampler
|
||||||
|
self._set_sampler()
|
||||||
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
results = list()
|
results = list()
|
||||||
|
|
||||||
@ -305,12 +312,15 @@ class T2I:
|
|||||||
iter_images = next(images_iterator)
|
iter_images = next(images_iterator)
|
||||||
for image in iter_images:
|
for image in iter_images:
|
||||||
try:
|
try:
|
||||||
# if gfpgan strength is none or less than or equal to 0.0 then
|
# if gfpgan strength is none or less than or equal to 0.0 then
|
||||||
# don't even attempt to use GFPGAN.
|
# don't even attempt to use GFPGAN.
|
||||||
# if the user specified a value of -G that satisifies the condition and
|
# if the user specified a value of -G that satisifies the condition and
|
||||||
# --gfpgan wasn't specified, at startup then
|
# --gfpgan wasn't specified, at startup then
|
||||||
# the net result is a message gets printed - nothing else happens.
|
# the net result is a message gets printed - nothing else happens.
|
||||||
if gfpgan_strength is not None and gfpgan_strength > 0.0:
|
if (
|
||||||
|
gfpgan_strength is not None
|
||||||
|
and gfpgan_strength > 0.0
|
||||||
|
):
|
||||||
image = self._run_gfpgan(
|
image = self._run_gfpgan(
|
||||||
image, gfpgan_strength
|
image, gfpgan_strength
|
||||||
)
|
)
|
||||||
@ -499,39 +509,38 @@ class T2I:
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise SystemExit
|
raise SystemExit
|
||||||
|
|
||||||
msg = f'setting sampler to {self.sampler_name}'
|
self._set_sampler()
|
||||||
if self.sampler_name == 'plms':
|
|
||||||
self.sampler = PLMSSampler(self.model, device=self.device)
|
|
||||||
elif self.sampler_name == 'ddim':
|
|
||||||
self.sampler = DDIMSampler(self.model, device=self.device)
|
|
||||||
elif self.sampler_name == 'k_dpm_2_a':
|
|
||||||
self.sampler = KSampler(
|
|
||||||
self.model, 'dpm_2_ancestral', device=self.device
|
|
||||||
)
|
|
||||||
elif self.sampler_name == 'k_dpm_2':
|
|
||||||
self.sampler = KSampler(
|
|
||||||
self.model, 'dpm_2', device=self.device
|
|
||||||
)
|
|
||||||
elif self.sampler_name == 'k_euler_a':
|
|
||||||
self.sampler = KSampler(
|
|
||||||
self.model, 'euler_ancestral', device=self.device
|
|
||||||
)
|
|
||||||
elif self.sampler_name == 'k_euler':
|
|
||||||
self.sampler = KSampler(
|
|
||||||
self.model, 'euler', device=self.device
|
|
||||||
)
|
|
||||||
elif self.sampler_name == 'k_heun':
|
|
||||||
self.sampler = KSampler(self.model, 'heun', device=self.device)
|
|
||||||
elif self.sampler_name == 'k_lms':
|
|
||||||
self.sampler = KSampler(self.model, 'lms', device=self.device)
|
|
||||||
else:
|
|
||||||
msg = f'unsupported sampler {self.sampler_name}, defaulting to plms'
|
|
||||||
self.sampler = PLMSSampler(self.model, device=self.device)
|
|
||||||
|
|
||||||
print(msg)
|
|
||||||
|
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
|
def _set_sampler(self):
|
||||||
|
msg = f'>> Setting Sampler to {self.sampler_name}'
|
||||||
|
if self.sampler_name == 'plms':
|
||||||
|
self.sampler = PLMSSampler(self.model, device=self.device)
|
||||||
|
elif self.sampler_name == 'ddim':
|
||||||
|
self.sampler = DDIMSampler(self.model, device=self.device)
|
||||||
|
elif self.sampler_name == 'k_dpm_2_a':
|
||||||
|
self.sampler = KSampler(
|
||||||
|
self.model, 'dpm_2_ancestral', device=self.device
|
||||||
|
)
|
||||||
|
elif self.sampler_name == 'k_dpm_2':
|
||||||
|
self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
|
||||||
|
elif self.sampler_name == 'k_euler_a':
|
||||||
|
self.sampler = KSampler(
|
||||||
|
self.model, 'euler_ancestral', device=self.device
|
||||||
|
)
|
||||||
|
elif self.sampler_name == 'k_euler':
|
||||||
|
self.sampler = KSampler(self.model, 'euler', device=self.device)
|
||||||
|
elif self.sampler_name == 'k_heun':
|
||||||
|
self.sampler = KSampler(self.model, 'heun', device=self.device)
|
||||||
|
elif self.sampler_name == 'k_lms':
|
||||||
|
self.sampler = KSampler(self.model, 'lms', device=self.device)
|
||||||
|
else:
|
||||||
|
msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms'
|
||||||
|
self.sampler = PLMSSampler(self.model, device=self.device)
|
||||||
|
|
||||||
|
print(msg)
|
||||||
|
|
||||||
def _load_model_from_config(self, config, ckpt):
|
def _load_model_from_config(self, config, ckpt):
|
||||||
print(f'Loading model from {ckpt}')
|
print(f'Loading model from {ckpt}')
|
||||||
pl_sd = torch.load(ckpt, map_location='cpu')
|
pl_sd = torch.load(ckpt, map_location='cpu')
|
||||||
|
@ -52,7 +52,8 @@ def main():
|
|||||||
weights=weights,
|
weights=weights,
|
||||||
full_precision=opt.full_precision,
|
full_precision=opt.full_precision,
|
||||||
config=config,
|
config=config,
|
||||||
latent_diffusion_weights=opt.laion400m, # this is solely for recreating the prompt
|
# this is solely for recreating the prompt
|
||||||
|
latent_diffusion_weights=opt.laion400m,
|
||||||
embedding_path=opt.embedding_path,
|
embedding_path=opt.embedding_path,
|
||||||
device=opt.device,
|
device=opt.device,
|
||||||
)
|
)
|
||||||
@ -508,6 +509,23 @@ def create_cmd_parser():
|
|||||||
action='store_true',
|
action='store_true',
|
||||||
help='skip subprompt weight normalization',
|
help='skip subprompt weight normalization',
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-m',
|
||||||
|
'--user_sampler',
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
choices=[
|
||||||
|
'ddim',
|
||||||
|
'k_dpm_2_a',
|
||||||
|
'k_dpm_2',
|
||||||
|
'k_euler_a',
|
||||||
|
'k_euler',
|
||||||
|
'k_heun',
|
||||||
|
'k_lms',
|
||||||
|
'plms',
|
||||||
|
],
|
||||||
|
help='Change to another supported sampler using this command',
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user