diffusers: enable DPMSolver++ scheduler

This commit is contained in:
Kevin Turner 2022-12-07 19:00:23 -08:00
parent 9bcb3b1bf7
commit 30a8d4c2b3

View File

@ -19,6 +19,7 @@ import transformers
from PIL import Image, ImageOps
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
from diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
from diffusers.schedulers.scheduling_ipndm import IPNDMScheduler
@ -1005,19 +1006,14 @@ class Generate:
def _set_scheduler(self):
default = self.model.scheduler
higher_order_samplers = [
'k_dpm_2',
'k_dpm_2_a',
'k_heun',
'plms', # Its first step is like Heun
]
scheduler_map = dict(
ddim=DDIMScheduler,
ipndm=IPNDMScheduler,
k_euler=EulerDiscreteScheduler,
k_euler_a=EulerAncestralDiscreteScheduler,
k_lms=LMSDiscreteScheduler,
pndm=PNDMScheduler,
plms=PNDMScheduler,
k_dpmpp_2=DPMSolverMultistepScheduler,
)
if self.sampler_name in scheduler_map:
@ -1027,11 +1023,6 @@ class Generate:
self.model_cache.model_name_or_path(self.model_name),
subfolder="scheduler"
)
elif self.sampler_name in higher_order_samplers:
msg = (f'>> Unsupported Sampler: {self.sampler_name} '
f'— diffusers does not yet support higher-order samplers, '
f'Defaulting to {default}')
self.sampler = default
else:
msg = (f'>> Unsupported Sampler: {self.sampler_name} '
f'Defaulting to {default}')