mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix: denoise latents accepts CFG lists as input
This commit is contained in:
parent
241a1fdb57
commit
71c3197eab
@ -25,7 +25,7 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
|||||||
from diffusers.schedulers import DPMSolverSDEScheduler
|
from diffusers.schedulers import DPMSolverSDEScheduler
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
from PIL import Image, ImageFilter
|
from PIL import Image, ImageFilter
|
||||||
from pydantic import field_validator
|
from pydantic import ValidationInfo, field_validator
|
||||||
from torchvision.transforms.functional import resize as tv_resize
|
from torchvision.transforms.functional import resize as tv_resize
|
||||||
from transformers import CLIPVisionModelWithProjection
|
from transformers import CLIPVisionModelWithProjection
|
||||||
|
|
||||||
@ -341,7 +341,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||||
cfg_scale: Union[float, List[float]] = InputField(
|
cfg_scale: Union[float, List[float]] = InputField(
|
||||||
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, title="CFG Scale"
|
default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale"
|
||||||
)
|
)
|
||||||
denoising_start: float = InputField(
|
denoising_start: float = InputField(
|
||||||
default=0.0,
|
default=0.0,
|
||||||
@ -397,12 +397,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@field_validator("cfg_scale")
|
@field_validator("cfg_scale")
|
||||||
def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]:
|
def ge_one(cls, v: Union[List[float], float], info: ValidationInfo) -> Union[List[float], float]:
|
||||||
"""validate that all cfg_scale values are >= 1"""
|
"""validate that all cfg_scale values are >= 1"""
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
for i in v:
|
for i in v:
|
||||||
if i < 1:
|
if i < 1:
|
||||||
raise ValueError("cfg_scale must be greater than 1")
|
raise ValueError("cfg_scale must be greater than 1")
|
||||||
|
if len(v) != info.data["steps"]:
|
||||||
|
raise ValueError("cfg_scale (list) must have the same length as the number of steps")
|
||||||
else:
|
else:
|
||||||
if v < 1:
|
if v < 1:
|
||||||
raise ValueError("cfg_scale must be greater than 1")
|
raise ValueError("cfg_scale must be greater than 1")
|
||||||
|
Loading…
Reference in New Issue
Block a user