fix: changed validation to not error on connection

This commit is contained in:
dunkeroni 2024-04-27 15:12:06 -04:00 committed by Kent Keirsey
parent 71c3197eab
commit f262b9032d

View File

@ -25,7 +25,7 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers import DPMSolverSDEScheduler from diffusers.schedulers import DPMSolverSDEScheduler
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
from PIL import Image, ImageFilter from PIL import Image, ImageFilter
from pydantic import ValidationInfo, field_validator from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPVisionModelWithProjection from transformers import CLIPVisionModelWithProjection
@ -397,14 +397,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
) )
@field_validator("cfg_scale") @field_validator("cfg_scale")
def ge_one(cls, v: Union[List[float], float], info: ValidationInfo) -> Union[List[float], float]: def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]:
"""validate that all cfg_scale values are >= 1""" """validate that all cfg_scale values are >= 1"""
if isinstance(v, list): if isinstance(v, list):
for i in v: for i in v:
if i < 1: if i < 1:
raise ValueError("cfg_scale must be greater than 1") raise ValueError("cfg_scale must be greater than 1")
if len(v) != info.data["steps"]:
raise ValueError("cfg_scale (list) must have the same length as the number of steps")
else: else:
if v < 1: if v < 1:
raise ValueError("cfg_scale must be greater than 1") raise ValueError("cfg_scale must be greater than 1")
@ -565,6 +563,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
dtype=unet.dtype, dtype=unet.dtype,
) )
if isinstance(self.cfg_scale, list):
assert (
len(self.cfg_scale) == self.steps
), "cfg_scale (list) must have the same length as the number of steps"
conditioning_data = TextConditioningData( conditioning_data = TextConditioningData(
uncond_text=uncond_text_embedding, uncond_text=uncond_text_embedding,
cond_text=cond_text_embedding, cond_text=cond_text_embedding,