mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
diffusers: enable DPMSolver++ scheduler
This commit is contained in:
parent
9bcb3b1bf7
commit
30a8d4c2b3
@ -19,6 +19,7 @@ import transformers
|
|||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from diffusers.pipeline_utils import DiffusionPipeline
|
from diffusers.pipeline_utils import DiffusionPipeline
|
||||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||||
|
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
|
||||||
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
|
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
|
||||||
from diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
|
from diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
|
||||||
from diffusers.schedulers.scheduling_ipndm import IPNDMScheduler
|
from diffusers.schedulers.scheduling_ipndm import IPNDMScheduler
|
||||||
@ -1005,19 +1006,14 @@ class Generate:
|
|||||||
def _set_scheduler(self):
|
def _set_scheduler(self):
|
||||||
default = self.model.scheduler
|
default = self.model.scheduler
|
||||||
|
|
||||||
higher_order_samplers = [
|
|
||||||
'k_dpm_2',
|
|
||||||
'k_dpm_2_a',
|
|
||||||
'k_heun',
|
|
||||||
'plms', # Its first step is like Heun
|
|
||||||
]
|
|
||||||
scheduler_map = dict(
|
scheduler_map = dict(
|
||||||
ddim=DDIMScheduler,
|
ddim=DDIMScheduler,
|
||||||
ipndm=IPNDMScheduler,
|
ipndm=IPNDMScheduler,
|
||||||
k_euler=EulerDiscreteScheduler,
|
k_euler=EulerDiscreteScheduler,
|
||||||
k_euler_a=EulerAncestralDiscreteScheduler,
|
k_euler_a=EulerAncestralDiscreteScheduler,
|
||||||
k_lms=LMSDiscreteScheduler,
|
k_lms=LMSDiscreteScheduler,
|
||||||
pndm=PNDMScheduler,
|
plms=PNDMScheduler,
|
||||||
|
k_dpmpp_2=DPMSolverMultistepScheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.sampler_name in scheduler_map:
|
if self.sampler_name in scheduler_map:
|
||||||
@ -1027,11 +1023,6 @@ class Generate:
|
|||||||
self.model_cache.model_name_or_path(self.model_name),
|
self.model_cache.model_name_or_path(self.model_name),
|
||||||
subfolder="scheduler"
|
subfolder="scheduler"
|
||||||
)
|
)
|
||||||
elif self.sampler_name in higher_order_samplers:
|
|
||||||
msg = (f'>> Unsupported Sampler: {self.sampler_name} '
|
|
||||||
f'— diffusers does not yet support higher-order samplers, '
|
|
||||||
f'Defaulting to {default}')
|
|
||||||
self.sampler = default
|
|
||||||
else:
|
else:
|
||||||
msg = (f'>> Unsupported Sampler: {self.sampler_name} '
|
msg = (f'>> Unsupported Sampler: {self.sampler_name} '
|
||||||
f'Defaulting to {default}')
|
f'Defaulting to {default}')
|
||||||
|
Loading…
Reference in New Issue
Block a user