Add DPMPP Single, Euler Karras and DPMPP2 Multi Karras Schedulers

This commit is contained in:
blessedcoolant 2023-05-12 02:23:33 +12:00
parent 3487388788
commit 8a836247c8
7 changed files with 64 additions and 46 deletions

View File

@ -60,8 +60,13 @@ SAMPLER_NAME_VALUES = Literal[
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler: def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim']) scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
scheduler_config = {**model.scheduler.config, **scheduler_extra_config}
scheduler_config = model.scheduler.config
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
scheduler = scheduler_class.from_config(scheduler_config) scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py # hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'): if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False scheduler.uses_inpainting_model = lambda: False

View File

@ -108,18 +108,18 @@ APP_VERSION = invokeai.version.__version__
SAMPLER_CHOICES = [ SAMPLER_CHOICES = [
"ddim", "ddim",
"k_dpm_2_a",
"k_dpm_2",
"k_dpmpp_2_a",
"k_dpmpp_2",
"k_euler_a",
"k_euler",
"k_heun",
"k_lms", "k_lms",
"plms", "plms",
# diffusers: "k_heun",
"pndm", "k_euler",
"unipc" "euler_karras",
"k_euler_a",
"k_dpm_2",
"k_dpm_2_a",
"dpmpp_2s",
"k_dpmpp_2",
"k_dpmpp_2_karras",
"unipc",
] ]
PRECISION_CHOICES = [ PRECISION_CHOICES = [

View File

@ -170,8 +170,13 @@ class InvokeAIGenerator(metaclass=ABCMeta):
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler: def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim']) scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
scheduler_config = {**model.scheduler.config, **scheduler_extra_config}
scheduler_config = model.scheduler.config
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
scheduler = scheduler_class.from_config(scheduler_config) scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py # hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'): if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False scheduler.uses_inpainting_model = lambda: False

View File

@ -1,17 +1,20 @@
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \ from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \
KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \ KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \
HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler, \
DPMSolverSinglestepScheduler
SCHEDULER_MAP = dict( SCHEDULER_MAP = dict(
ddim=(DDIMScheduler, dict(cpu_only=False)), ddim=(DDIMScheduler, dict()),
dpmpp_2=(DPMSolverMultistepScheduler, dict(cpu_only=False)), k_lms=(LMSDiscreteScheduler, dict()),
k_dpm_2=(KDPM2DiscreteScheduler, dict(cpu_only=False)), plms=(PNDMScheduler, dict()),
k_dpm_2_a=(KDPM2AncestralDiscreteScheduler, dict(cpu_only=False)), k_euler=(EulerDiscreteScheduler, dict(use_karras_sigmas=False)),
k_dpmpp_2=(DPMSolverMultistepScheduler, dict(cpu_only=False)), euler_karras=(EulerDiscreteScheduler, dict(use_karras_sigmas=True)),
k_euler=(EulerDiscreteScheduler, dict(cpu_only=False)), k_euler_a=(EulerAncestralDiscreteScheduler, dict()),
k_euler_a=(EulerAncestralDiscreteScheduler, dict(cpu_only=False)), k_dpm_2=(KDPM2DiscreteScheduler, dict()),
k_heun=(HeunDiscreteScheduler, dict(cpu_only=False)), k_dpm_2_a=(KDPM2AncestralDiscreteScheduler, dict()),
k_lms=(LMSDiscreteScheduler, dict(cpu_only=False)), dpmpp_2s=(DPMSolverSinglestepScheduler, dict()),
plms=(PNDMScheduler, dict(cpu_only=False)), k_dpmpp_2=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
k_dpmpp_2_karras=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
k_heun=(HeunDiscreteScheduler, dict()),
unipc=(UniPCMultistepScheduler, dict(cpu_only=True)) unipc=(UniPCMultistepScheduler, dict(cpu_only=True))
) )

View File

@ -4,18 +4,18 @@ from .parse_seed_weights import parse_seed_weights
SAMPLER_CHOICES = [ SAMPLER_CHOICES = [
"ddim", "ddim",
"k_dpm_2_a",
"k_dpm_2",
"k_dpmpp_2_a",
"k_dpmpp_2",
"k_euler_a",
"k_euler",
"k_heun",
"k_lms", "k_lms",
"plms", "plms",
# diffusers: "k_heun",
"pndm", "k_euler",
"unipc" "euler_karras",
"k_euler_a",
"k_dpm_2",
"k_dpm_2_a",
"dpmpp_2s",
"k_dpmpp_2",
"k_dpmpp_2_karras",
"unipc",
] ]

View File

@ -2,15 +2,17 @@
export const DIFFUSERS_SCHEDULERS: Array<string> = [ export const DIFFUSERS_SCHEDULERS: Array<string> = [
'ddim', 'ddim',
'plms',
'k_lms', 'k_lms',
'dpmpp_2', 'plms',
'k_heun',
'k_euler',
'euler_karras',
'k_euler_a',
'k_dpm_2', 'k_dpm_2',
'k_dpm_2_a', 'k_dpm_2_a',
'dpmpp_2s',
'k_dpmpp_2', 'k_dpmpp_2',
'k_euler', 'k_dpmpp_2_karras',
'k_euler_a',
'k_heun',
'unipc', 'unipc',
]; ];

View File

@ -47,15 +47,18 @@ export type CommonGeneratedImageMetadata = {
postprocessing: null | Array<ESRGANMetadata | FacetoolMetadata>; postprocessing: null | Array<ESRGANMetadata | FacetoolMetadata>;
sampler: sampler:
| 'ddim' | 'ddim'
| 'k_dpm_2_a'
| 'k_dpm_2'
| 'k_dpmpp_2_a'
| 'k_dpmpp_2'
| 'k_euler_a'
| 'k_euler'
| 'k_heun'
| 'k_lms' | 'k_lms'
| 'plms'; | 'plms'
| 'k_heun'
| 'k_euler'
| 'euler_karras'
| 'k_euler_a'
| 'k_dpm_2'
| 'k_dpm_2_a'
| 'dpmpp_2s'
| 'k_dpmpp_2'
| 'k_dpmpp_2_karras'
| 'unipc';
prompt: Prompt; prompt: Prompt;
seed: number; seed: number;
variations: SeedWeights; variations: SeedWeights;