fix: denoise latents accepts CFG lists as input

This commit is contained in:
dunkeroni 2024-04-27 14:40:52 -04:00 committed by Kent Keirsey
parent 241a1fdb57
commit 71c3197eab

View File

@ -25,7 +25,7 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers import DPMSolverSDEScheduler from diffusers.schedulers import DPMSolverSDEScheduler
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
from PIL import Image, ImageFilter from PIL import Image, ImageFilter
from pydantic import field_validator from pydantic import ValidationInfo, field_validator
from torchvision.transforms.functional import resize as tv_resize from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPVisionModelWithProjection from transformers import CLIPVisionModelWithProjection
@ -341,7 +341,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
) )
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps) steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
cfg_scale: Union[float, List[float]] = InputField( cfg_scale: Union[float, List[float]] = InputField(
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, title="CFG Scale" default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale"
) )
denoising_start: float = InputField( denoising_start: float = InputField(
default=0.0, default=0.0,
@ -397,12 +397,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
) )
@field_validator("cfg_scale") @field_validator("cfg_scale")
def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]: def ge_one(cls, v: Union[List[float], float], info: ValidationInfo) -> Union[List[float], float]:
"""validate that all cfg_scale values are >= 1""" """validate that all cfg_scale values are >= 1"""
if isinstance(v, list): if isinstance(v, list):
for i in v: for i in v:
if i < 1: if i < 1:
raise ValueError("cfg_scale must be greater than 1") raise ValueError("cfg_scale must be greater than 1")
if len(v) != info.data["steps"]:
raise ValueError("cfg_scale (list) must have the same length as the number of steps")
else: else:
if v < 1: if v < 1:
raise ValueError("cfg_scale must be greater than 1") raise ValueError("cfg_scale must be greater than 1")