diffusers: enable DPMSolver++ scheduler

This commit is contained in:
Kevin Turner 2022-12-07 19:00:23 -08:00
parent 9bcb3b1bf7
commit 30a8d4c2b3

View File

@ -19,6 +19,7 @@ import transformers
from PIL import Image, ImageOps from PIL import Image, ImageOps
from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_ddim import DDIMScheduler from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
from diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler from diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
from diffusers.schedulers.scheduling_ipndm import IPNDMScheduler from diffusers.schedulers.scheduling_ipndm import IPNDMScheduler
@ -1005,19 +1006,14 @@ class Generate:
def _set_scheduler(self): def _set_scheduler(self):
default = self.model.scheduler default = self.model.scheduler
higher_order_samplers = [
'k_dpm_2',
'k_dpm_2_a',
'k_heun',
'plms', # Its first step is like Heun
]
scheduler_map = dict( scheduler_map = dict(
ddim=DDIMScheduler, ddim=DDIMScheduler,
ipndm=IPNDMScheduler, ipndm=IPNDMScheduler,
k_euler=EulerDiscreteScheduler, k_euler=EulerDiscreteScheduler,
k_euler_a=EulerAncestralDiscreteScheduler, k_euler_a=EulerAncestralDiscreteScheduler,
k_lms=LMSDiscreteScheduler, k_lms=LMSDiscreteScheduler,
pndm=PNDMScheduler, plms=PNDMScheduler,
k_dpmpp_2=DPMSolverMultistepScheduler,
) )
if self.sampler_name in scheduler_map: if self.sampler_name in scheduler_map:
@ -1027,11 +1023,6 @@ class Generate:
self.model_cache.model_name_or_path(self.model_name), self.model_cache.model_name_or_path(self.model_name),
subfolder="scheduler" subfolder="scheduler"
) )
elif self.sampler_name in higher_order_samplers:
msg = (f'>> Unsupported Sampler: {self.sampler_name} '
f'— diffusers does not yet support higher-order samplers, '
f'Defaulting to {default}')
self.sampler = default
else: else:
msg = (f'>> Unsupported Sampler: {self.sampler_name} ' msg = (f'>> Unsupported Sampler: {self.sampler_name} '
f'Defaulting to {default}') f'Defaulting to {default}')