fix: change eta only for TCD Scheduler

Author: blessedcoolant
Date:   2024-05-01 12:30:06 +05:30
Parent: 1bdcbe3284
Commit: dce8b88aaf


@@ -11,7 +11,6 @@ import numpy.typing as npt
 import torch
 import torchvision
 import torchvision.transforms as T
-from diffusers import AutoencoderKL, AutoencoderTiny
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.adapter import T2IAdapter
@@ -21,9 +20,12 @@ from diffusers.models.attention_processor import (
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
+from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
+from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
-from diffusers.schedulers import DPMSolverSDEScheduler
-from diffusers.schedulers import SchedulerMixin as Scheduler
+from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
+from diffusers.schedulers.scheduling_tcd import TCDScheduler
+from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
 from PIL import Image, ImageFilter
 from pydantic import field_validator
 from torchvision.transforms.functional import resize as tv_resize
@@ -861,7 +863,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
             # all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
             # reproducibility.
             scheduler_step_kwargs.update({"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)})
-        if "eta" in scheduler_step_signature.parameters:
+        if isinstance(scheduler, TCDScheduler):
             scheduler_step_kwargs.update({"eta": 1.0})
         return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
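
For context: diffusers' DDIMScheduler.step() also accepts an eta argument (default 0.0, i.e. deterministic sampling), so the old signature-based check passed eta=1.0 to it as well, not just to TCDScheduler. Restricting the override to an isinstance check keeps other schedulers on their own defaults. The following is a minimal standalone sketch, not part of this commit: the helper names step_kwargs_old/step_kwargs_new are hypothetical, and it assumes a diffusers version that ships TCDScheduler.

# Minimal sketch contrasting the old and new checks (hypothetical helpers, not InvokeAI code).
import inspect

from diffusers import DDIMScheduler, TCDScheduler


def step_kwargs_old(scheduler) -> dict:
    # Old behavior: pass eta=1.0 to any scheduler whose step() accepts an `eta` argument.
    kwargs = {}
    if "eta" in inspect.signature(scheduler.step).parameters:
        kwargs["eta"] = 1.0
    return kwargs


def step_kwargs_new(scheduler) -> dict:
    # New behavior: only TCDScheduler gets eta=1.0; other schedulers keep their defaults.
    kwargs = {}
    if isinstance(scheduler, TCDScheduler):
        kwargs["eta"] = 1.0
    return kwargs


if __name__ == "__main__":
    ddim, tcd = DDIMScheduler(), TCDScheduler()
    # The old check also matches DDIM, because DDIMScheduler.step() has an `eta` parameter.
    print(step_kwargs_old(ddim), step_kwargs_old(tcd))  # {'eta': 1.0} {'eta': 1.0}
    print(step_kwargs_new(ddim), step_kwargs_new(tcd))  # {} {'eta': 1.0}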