mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
added support for changing sampler on prompt line
This commit is contained in:
commit
1f8bc9482a
@ -344,6 +344,7 @@ repository and associated paper for details and limitations.
|
|||||||
enabled if the GFPGAN directory is located as a sibling to Stable Diffusion.
|
enabled if the GFPGAN directory is located as a sibling to Stable Diffusion.
|
||||||
VRAM requirements are modestly reduced. Thanks to both [Blessedcoolant](https://github.com/blessedcoolant) and
|
VRAM requirements are modestly reduced. Thanks to both [Blessedcoolant](https://github.com/blessedcoolant) and
|
||||||
[Oceanswave](https://github.com/oceanswave) for their work on this.
|
[Oceanswave](https://github.com/oceanswave) for their work on this.
|
||||||
|
- You can now swap samplers on the dream> command line. [Blessedcoolant](https://github.com/blessedcoolant)
|
||||||
|
|
||||||
- v1.11 (26 August 2022)
|
- v1.11 (26 August 2022)
|
||||||
- NEW FEATURE: Support upscaling and face enhancement using the GFPGAN module. (kudos to [Oceanswave](https://github.com/Oceanswave)
|
- NEW FEATURE: Support upscaling and face enhancement using the GFPGAN module. (kudos to [Oceanswave](https://github.com/Oceanswave)
|
||||||
|
@ -106,6 +106,8 @@ class PromptFormatter:
|
|||||||
self.t2i = t2i
|
self.t2i = t2i
|
||||||
self.opt = opt
|
self.opt = opt
|
||||||
|
|
||||||
|
# note: the t2i object should provide all these values.
|
||||||
|
# there should be no need to or against opt values
|
||||||
def normalize_prompt(self):
|
def normalize_prompt(self):
|
||||||
"""Normalize the prompt and switches"""
|
"""Normalize the prompt and switches"""
|
||||||
t2i = self.t2i
|
t2i = self.t2i
|
||||||
@ -118,13 +120,15 @@ class PromptFormatter:
|
|||||||
switches.append(f'-W{opt.width or t2i.width}')
|
switches.append(f'-W{opt.width or t2i.width}')
|
||||||
switches.append(f'-H{opt.height or t2i.height}')
|
switches.append(f'-H{opt.height or t2i.height}')
|
||||||
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
|
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
|
||||||
switches.append(f'-m{t2i.sampler_name}')
|
switches.append(f'-A{opt.sampler_name or t2i.sampler_name}')
|
||||||
if opt.init_img:
|
if opt.init_img:
|
||||||
switches.append(f'-I{opt.init_img}')
|
switches.append(f'-I{opt.init_img}')
|
||||||
if opt.strength and opt.init_img is not None:
|
if opt.strength and opt.init_img is not None:
|
||||||
switches.append(f'-f{opt.strength or t2i.strength}')
|
switches.append(f'-f{opt.strength or t2i.strength}')
|
||||||
if opt.gfpgan_strength:
|
if opt.gfpgan_strength:
|
||||||
switches.append(f'-G{opt.gfpgan_strength}')
|
switches.append(f'-G{opt.gfpgan_strength}')
|
||||||
|
if opt.upscale:
|
||||||
|
switches.append(f'-U {" ".join([str(u) for u in opt.upscale])}')
|
||||||
if t2i.full_precision:
|
if t2i.full_precision:
|
||||||
switches.append('-F')
|
switches.append('-F')
|
||||||
return ' '.join(switches)
|
return ' '.join(switches)
|
||||||
|
@ -212,6 +212,7 @@ class T2I:
|
|||||||
save_original=False,
|
save_original=False,
|
||||||
upscale=None,
|
upscale=None,
|
||||||
variants=None,
|
variants=None,
|
||||||
|
user_sampler=None,
|
||||||
**args,
|
**args,
|
||||||
): # eat up additional cruft
|
): # eat up additional cruft
|
||||||
"""
|
"""
|
||||||
@ -271,6 +272,10 @@ class T2I:
|
|||||||
|
|
||||||
scope = autocast if self.precision == 'autocast' else nullcontext
|
scope = autocast if self.precision == 'autocast' else nullcontext
|
||||||
|
|
||||||
|
if user_sampler and (user_sampler != self.sampler_name):
|
||||||
|
self.sampler_name = user_sampler
|
||||||
|
self._set_sampler()
|
||||||
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
torch.cuda.torch.cuda.reset_peak_memory_stats()
|
torch.cuda.torch.cuda.reset_peak_memory_stats()
|
||||||
results = list()
|
results = list()
|
||||||
@ -537,39 +542,38 @@ class T2I:
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise SystemExit
|
raise SystemExit
|
||||||
|
|
||||||
msg = f'setting sampler to {self.sampler_name}'
|
self._set_sampler()
|
||||||
if self.sampler_name == 'plms':
|
|
||||||
self.sampler = PLMSSampler(self.model, device=self.device)
|
|
||||||
elif self.sampler_name == 'ddim':
|
|
||||||
self.sampler = DDIMSampler(self.model, device=self.device)
|
|
||||||
elif self.sampler_name == 'k_dpm_2_a':
|
|
||||||
self.sampler = KSampler(
|
|
||||||
self.model, 'dpm_2_ancestral', device=self.device
|
|
||||||
)
|
|
||||||
elif self.sampler_name == 'k_dpm_2':
|
|
||||||
self.sampler = KSampler(
|
|
||||||
self.model, 'dpm_2', device=self.device
|
|
||||||
)
|
|
||||||
elif self.sampler_name == 'k_euler_a':
|
|
||||||
self.sampler = KSampler(
|
|
||||||
self.model, 'euler_ancestral', device=self.device
|
|
||||||
)
|
|
||||||
elif self.sampler_name == 'k_euler':
|
|
||||||
self.sampler = KSampler(
|
|
||||||
self.model, 'euler', device=self.device
|
|
||||||
)
|
|
||||||
elif self.sampler_name == 'k_heun':
|
|
||||||
self.sampler = KSampler(self.model, 'heun', device=self.device)
|
|
||||||
elif self.sampler_name == 'k_lms':
|
|
||||||
self.sampler = KSampler(self.model, 'lms', device=self.device)
|
|
||||||
else:
|
|
||||||
msg = f'unsupported sampler {self.sampler_name}, defaulting to plms'
|
|
||||||
self.sampler = PLMSSampler(self.model, device=self.device)
|
|
||||||
|
|
||||||
print(msg)
|
|
||||||
|
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
|
def _set_sampler(self):
|
||||||
|
msg = f'>> Setting Sampler to {self.sampler_name}'
|
||||||
|
if self.sampler_name == 'plms':
|
||||||
|
self.sampler = PLMSSampler(self.model, device=self.device)
|
||||||
|
elif self.sampler_name == 'ddim':
|
||||||
|
self.sampler = DDIMSampler(self.model, device=self.device)
|
||||||
|
elif self.sampler_name == 'k_dpm_2_a':
|
||||||
|
self.sampler = KSampler(
|
||||||
|
self.model, 'dpm_2_ancestral', device=self.device
|
||||||
|
)
|
||||||
|
elif self.sampler_name == 'k_dpm_2':
|
||||||
|
self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
|
||||||
|
elif self.sampler_name == 'k_euler_a':
|
||||||
|
self.sampler = KSampler(
|
||||||
|
self.model, 'euler_ancestral', device=self.device
|
||||||
|
)
|
||||||
|
elif self.sampler_name == 'k_euler':
|
||||||
|
self.sampler = KSampler(self.model, 'euler', device=self.device)
|
||||||
|
elif self.sampler_name == 'k_heun':
|
||||||
|
self.sampler = KSampler(self.model, 'heun', device=self.device)
|
||||||
|
elif self.sampler_name == 'k_lms':
|
||||||
|
self.sampler = KSampler(self.model, 'lms', device=self.device)
|
||||||
|
else:
|
||||||
|
msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms'
|
||||||
|
self.sampler = PLMSSampler(self.model, device=self.device)
|
||||||
|
|
||||||
|
print(msg)
|
||||||
|
|
||||||
def _load_model_from_config(self, config, ckpt):
|
def _load_model_from_config(self, config, ckpt):
|
||||||
print(f'Loading model from {ckpt}')
|
print(f'Loading model from {ckpt}')
|
||||||
pl_sd = torch.load(ckpt, map_location='cpu')
|
pl_sd = torch.load(ckpt, map_location='cpu')
|
||||||
|
@ -51,7 +51,8 @@ def main():
|
|||||||
weights=weights,
|
weights=weights,
|
||||||
full_precision=opt.full_precision,
|
full_precision=opt.full_precision,
|
||||||
config=config,
|
config=config,
|
||||||
latent_diffusion_weights=opt.laion400m, # this is solely for recreating the prompt
|
# this is solely for recreating the prompt
|
||||||
|
latent_diffusion_weights=opt.laion400m,
|
||||||
embedding_path=opt.embedding_path,
|
embedding_path=opt.embedding_path,
|
||||||
device=opt.device,
|
device=opt.device,
|
||||||
)
|
)
|
||||||
@ -281,8 +282,9 @@ def create_argv_parser():
|
|||||||
help='use slower full precision math for calculations',
|
help='use slower full precision math for calculations',
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--sampler',
|
'-A',
|
||||||
'-m',
|
'-m',
|
||||||
|
'--sampler',
|
||||||
dest='sampler_name',
|
dest='sampler_name',
|
||||||
choices=[
|
choices=[
|
||||||
'ddim',
|
'ddim',
|
||||||
@ -295,7 +297,7 @@ def create_argv_parser():
|
|||||||
'plms',
|
'plms',
|
||||||
],
|
],
|
||||||
default='k_lms',
|
default='k_lms',
|
||||||
help='which sampler to use (k_lms) - can only be set on command line',
|
help='which sampler to use (k_lms)',
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--outdir',
|
'--outdir',
|
||||||
@ -447,6 +449,25 @@ def create_cmd_parser():
|
|||||||
action='store_true',
|
action='store_true',
|
||||||
help='skip subprompt weight normalization',
|
help='skip subprompt weight normalization',
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'-A',
|
||||||
|
'-m',
|
||||||
|
'--sampler',
|
||||||
|
dest='sampler_name',
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
choices=[
|
||||||
|
'ddim',
|
||||||
|
'k_dpm_2_a',
|
||||||
|
'k_dpm_2',
|
||||||
|
'k_euler_a',
|
||||||
|
'k_euler',
|
||||||
|
'k_heun',
|
||||||
|
'k_lms',
|
||||||
|
'plms',
|
||||||
|
],
|
||||||
|
help='Change to another supported sampler using this command',
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user