mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
diffusers: upgrade to diffusers 0.10, add Heun scheduler
This commit is contained in:
parent
30a8d4c2b3
commit
9199d698f8
@ -4,7 +4,7 @@
|
|||||||
--trusted-host https://download.pytorch.org
|
--trusted-host https://download.pytorch.org
|
||||||
accelerate~=0.14
|
accelerate~=0.14
|
||||||
albumentations
|
albumentations
|
||||||
diffusers[torch]~=0.9
|
diffusers[torch]~=0.10
|
||||||
einops
|
einops
|
||||||
eventlet
|
eventlet
|
||||||
flask_cors
|
flask_cors
|
||||||
|
@ -17,6 +17,7 @@ import skimage
|
|||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
|
from diffusers import HeunDiscreteScheduler
|
||||||
from diffusers.pipeline_utils import DiffusionPipeline
|
from diffusers.pipeline_utils import DiffusionPipeline
|
||||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||||
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
|
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
|
||||||
@ -1008,12 +1009,17 @@ class Generate:
|
|||||||
|
|
||||||
scheduler_map = dict(
|
scheduler_map = dict(
|
||||||
ddim=DDIMScheduler,
|
ddim=DDIMScheduler,
|
||||||
|
dpmpp_2=DPMSolverMultistepScheduler,
|
||||||
ipndm=IPNDMScheduler,
|
ipndm=IPNDMScheduler,
|
||||||
|
# DPMSolverMultistepScheduler is technically not `k_` anything, as it is neither
|
||||||
|
# the k-diffusers implementation nor included in EDM (Karras 2022), but we can
|
||||||
|
# provide an alias for compatibility.
|
||||||
|
k_dpmpp_2=DPMSolverMultistepScheduler,
|
||||||
k_euler=EulerDiscreteScheduler,
|
k_euler=EulerDiscreteScheduler,
|
||||||
k_euler_a=EulerAncestralDiscreteScheduler,
|
k_euler_a=EulerAncestralDiscreteScheduler,
|
||||||
|
k_heun=HeunDiscreteScheduler,
|
||||||
k_lms=LMSDiscreteScheduler,
|
k_lms=LMSDiscreteScheduler,
|
||||||
plms=PNDMScheduler,
|
plms=PNDMScheduler,
|
||||||
k_dpmpp_2=DPMSolverMultistepScheduler,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.sampler_name in scheduler_map:
|
if self.sampler_name in scheduler_map:
|
||||||
|
@ -407,7 +407,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
batch_size = initial_latents.size(0)
|
batch_size = initial_latents.size(0)
|
||||||
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
|
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
|
||||||
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
|
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||||
timesteps = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
|
timesteps, _ = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
|
||||||
latent_timestep = timesteps[:1].repeat(batch_size)
|
latent_timestep = timesteps[:1].repeat(batch_size)
|
||||||
latents = self.noise_latents_for_time(initial_latents, latent_timestep, noise_func=noise_func)
|
latents = self.noise_latents_for_time(initial_latents, latent_timestep, noise_func=noise_func)
|
||||||
|
|
||||||
@ -454,7 +454,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
|
|
||||||
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
|
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
|
||||||
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
|
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||||
timesteps = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
|
timesteps, _ = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
|
||||||
|
|
||||||
assert img2img_pipeline.scheduler is self.scheduler
|
assert img2img_pipeline.scheduler is self.scheduler
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user