Add dpmpp_sde and dpmpp_2m_sde schedulers(with karras)

This commit is contained in:
Sergey Borisov 2023-06-18 23:38:15 +03:00
parent f312e1448f
commit f3d9797ebe
8 changed files with 35 additions and 6 deletions

View File

@ -7,7 +7,7 @@ import einops
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
import torch import torch
from diffusers import ControlNetModel from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
@ -222,6 +222,15 @@ class TextToLatentsInvocation(BaseInvocation):
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name) c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
custom_args = dict(
eta=0.0, #ddim_eta
)
if type(scheduler) is DPMSolverMultistepScheduler and scheduler.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
custom_args.update(
generator=torch.Generator(device=uc.device).manual_seed(0),
)
conditioning_data = ConditioningData( conditioning_data = ConditioningData(
unconditioned_embeddings=uc, unconditioned_embeddings=uc,
text_embeddings=c, text_embeddings=c,
@ -233,7 +242,7 @@ class TextToLatentsInvocation(BaseInvocation):
h_symmetry_time_pct=None,#h_symmetry_time_pct, h_symmetry_time_pct=None,#h_symmetry_time_pct,
v_symmetry_time_pct=None#v_symmetry_time_pct, v_symmetry_time_pct=None#v_symmetry_time_pct,
), ),
).add_scheduler_args_if_applicable(scheduler, eta=0.0)#ddim_eta) ).add_scheduler_args_if_applicable(scheduler, **custom_args)
return conditioning_data return conditioning_data
def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline: def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:

View File

@ -22,6 +22,10 @@ SAMPLER_CHOICES = [
"dpmpp_2s_k", "dpmpp_2s_k",
"dpmpp_2m", "dpmpp_2m",
"dpmpp_2m_k", "dpmpp_2m_k",
"dpmpp_2m_sde",
"dpmpp_2m_sde_k",
"dpmpp_sde",
"dpmpp_sde_k",
"unipc", "unipc",
] ]

View File

@ -1,7 +1,7 @@
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \ from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \
KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \ KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \
HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler, \ HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler, \
DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler, DPMSolverSDEScheduler
SCHEDULER_MAP = dict( SCHEDULER_MAP = dict(
ddim=(DDIMScheduler, dict()), ddim=(DDIMScheduler, dict()),
@ -21,5 +21,9 @@ SCHEDULER_MAP = dict(
dpmpp_2s_k=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)), dpmpp_2s_k=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)),
dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)), dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)), dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
dpmpp_2m_sde=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False, algorithm_type='sde-dpmsolver++')),
dpmpp_2m_sde_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True, algorithm_type='sde-dpmsolver++')),
dpmpp_sde=(DPMSolverSDEScheduler, dict(use_karras_sigmas=False, noise_sampler_seed=0)),
dpmpp_sde_k=(DPMSolverSDEScheduler, dict(use_karras_sigmas=True, noise_sampler_seed=0)),
unipc=(UniPCMultistepScheduler, dict(cpu_only=True)) unipc=(UniPCMultistepScheduler, dict(cpu_only=True))
) )

View File

@ -20,6 +20,10 @@ SAMPLER_CHOICES = [
"dpmpp_2s_k", "dpmpp_2s_k",
"dpmpp_2m", "dpmpp_2m",
"dpmpp_2m_k", "dpmpp_2m_k",
"dpmpp_2m_sde",
"dpmpp_2m_sde_k",
"dpmpp_sde",
"dpmpp_sde_k",
"unipc", "unipc",
] ]

View File

@ -9,6 +9,8 @@ export const SCHEDULER_NAMES_AS_CONST = [
'ddpm', 'ddpm',
'dpmpp_2s', 'dpmpp_2s',
'dpmpp_2m', 'dpmpp_2m',
'dpmpp_2m_sde',
'dpmpp_sde',
'heun', 'heun',
'kdpm_2', 'kdpm_2',
'lms', 'lms',
@ -17,6 +19,8 @@ export const SCHEDULER_NAMES_AS_CONST = [
'euler_k', 'euler_k',
'dpmpp_2s_k', 'dpmpp_2s_k',
'dpmpp_2m_k', 'dpmpp_2m_k',
'dpmpp_2m_sde_k',
'dpmpp_sde_k',
'heun_k', 'heun_k',
'lms_k', 'lms_k',
'euler_a', 'euler_a',
@ -32,16 +36,20 @@ export const SCHEDULER_LABEL_MAP: Record<SchedulerParam, string> = {
deis: 'DEIS', deis: 'DEIS',
ddim: 'DDIM', ddim: 'DDIM',
ddpm: 'DDPM', ddpm: 'DDPM',
dpmpp_sde: 'DPM++ SDE',
dpmpp_2s: 'DPM++ 2S', dpmpp_2s: 'DPM++ 2S',
dpmpp_2m: 'DPM++ 2M', dpmpp_2m: 'DPM++ 2M',
dpmpp_2m_sde: 'DPM++ 2M SDE',
heun: 'Heun', heun: 'Heun',
kdpm_2: 'KDPM 2', kdpm_2: 'KDPM 2',
lms: 'LMS', lms: 'LMS',
pndm: 'PNDM', pndm: 'PNDM',
unipc: 'UniPC', unipc: 'UniPC',
euler_k: 'Euler Karras', euler_k: 'Euler Karras',
dpmpp_sde_k: 'DPM++ SDE Karras',
dpmpp_2s_k: 'DPM++ 2S Karras', dpmpp_2s_k: 'DPM++ 2S Karras',
dpmpp_2m_k: 'DPM++ 2M Karras', dpmpp_2m_k: 'DPM++ 2M Karras',
dpmpp_2m_sde_k: 'DPM++ 2M SDE Karras',
heun_k: 'Heun Karras', heun_k: 'Heun Karras',
lms_k: 'LMS Karras', lms_k: 'LMS Karras',
euler_a: 'Euler Ancestral', euler_a: 'Euler Ancestral',

View File

@ -45,7 +45,7 @@ export type InpaintInvocation = {
/** /**
* The scheduler to use * The scheduler to use
*/ */
scheduler?: 'ddim' | 'ddpm' | 'deis' | 'lms' | 'lms_k' | 'pndm' | 'heun' | 'heun_k' | 'euler' | 'euler_k' | 'euler_a' | 'kdpm_2' | 'kdpm_2_a' | 'dpmpp_2s' | 'dpmpp_2s_k' | 'dpmpp_2m' | 'dpmpp_2m_k' | 'unipc'; scheduler?: 'ddim' | 'ddpm' | 'deis' | 'lms' | 'lms_k' | 'pndm' | 'heun' | 'heun_k' | 'euler' | 'euler_k' | 'euler_a' | 'kdpm_2' | 'kdpm_2_a' | 'dpmpp_2s' | 'dpmpp_2s_k' | 'dpmpp_2m' | 'dpmpp_2m_k' | 'dpmpp_2m_sde' | 'dpmpp_2m_sde_k' | 'dpmpp_sde' | 'dpmpp_sde_k' | 'unipc';
/** /**
* The model to use (currently ignored) * The model to use (currently ignored)
*/ */

View File

@ -42,7 +42,7 @@ export type LatentsToLatentsInvocation = {
/** /**
* The scheduler to use * The scheduler to use
*/ */
scheduler?: 'ddim' | 'ddpm' | 'deis' | 'lms' | 'lms_k' | 'pndm' | 'heun' | 'heun_k' | 'euler' | 'euler_k' | 'euler_a' | 'kdpm_2' | 'kdpm_2_a' | 'dpmpp_2s' | 'dpmpp_2s_k' | 'dpmpp_2m' | 'dpmpp_2m_k' | 'unipc'; scheduler?: 'ddim' | 'ddpm' | 'deis' | 'lms' | 'lms_k' | 'pndm' | 'heun' | 'heun_k' | 'euler' | 'euler_k' | 'euler_a' | 'kdpm_2' | 'kdpm_2_a' | 'dpmpp_2s' | 'dpmpp_2s_k' | 'dpmpp_2m' | 'dpmpp_2m_k' | 'dpmpp_2m_sde' | 'dpmpp_2m_sde_k' | 'dpmpp_sde' | 'dpmpp_sde_k' | 'unipc';
/** /**
* The model to use (currently ignored) * The model to use (currently ignored)
*/ */

View File

@ -42,7 +42,7 @@ export type TextToLatentsInvocation = {
/** /**
* The scheduler to use * The scheduler to use
*/ */
scheduler?: 'ddim' | 'ddpm' | 'deis' | 'lms' | 'lms_k' | 'pndm' | 'heun' | 'heun_k' | 'euler' | 'euler_k' | 'euler_a' | 'kdpm_2' | 'kdpm_2_a' | 'dpmpp_2s' | 'dpmpp_2s_k' | 'dpmpp_2m' | 'dpmpp_2m_k' | 'unipc'; scheduler?: 'ddim' | 'ddpm' | 'deis' | 'lms' | 'lms_k' | 'pndm' | 'heun' | 'heun_k' | 'euler' | 'euler_k' | 'euler_a' | 'kdpm_2' | 'kdpm_2_a' | 'dpmpp_2s' | 'dpmpp_2s_k' | 'dpmpp_2m' | 'dpmpp_2m_k' | 'dpmpp_2m_sde' | 'dpmpp_2m_sde_k' | 'dpmpp_sde' | 'dpmpp_sde_k' | 'unipc';
/** /**
* The model to use (currently ignored) * The model to use (currently ignored)
*/ */