mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix: changed validation to not error on connection
This commit is contained in:
parent
71c3197eab
commit
f262b9032d
@ -25,7 +25,7 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.schedulers import DPMSolverSDEScheduler
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from PIL import Image, ImageFilter
|
||||
from pydantic import ValidationInfo, field_validator
|
||||
from pydantic import field_validator
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from transformers import CLIPVisionModelWithProjection
|
||||
|
||||
@ -397,14 +397,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
@field_validator("cfg_scale")
|
||||
def ge_one(cls, v: Union[List[float], float], info: ValidationInfo) -> Union[List[float], float]:
|
||||
def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]:
|
||||
"""validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
if len(v) != info.data["steps"]:
|
||||
raise ValueError("cfg_scale (list) must have the same length as the number of steps")
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
@ -565,6 +563,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
dtype=unet.dtype,
|
||||
)
|
||||
|
||||
if isinstance(self.cfg_scale, list):
|
||||
assert (
|
||||
len(self.cfg_scale) == self.steps
|
||||
), "cfg_scale (list) must have the same length as the number of steps"
|
||||
|
||||
conditioning_data = TextConditioningData(
|
||||
uncond_text=uncond_text_embedding,
|
||||
cond_text=cond_text_embedding,
|
||||
|
Loading…
Reference in New Issue
Block a user