fix(nodes): fix constraints/validation for controlnet

- Fix `weight` and `begin_step_percent`, the constraints were mixed up
- Add model validatort to ensure `begin_step_percent < end_step_percent`
- Bump version
This commit is contained in:
psychedelicious 2024-01-02 11:13:49 +11:00 committed by Kent Keirsey
parent d256d93a2a
commit 2700d0e769
4 changed files with 90 additions and 22 deletions

View File

@ -24,9 +24,10 @@ from controlnet_aux import (
) )
from controlnet_aux.util import HWC3, ade_palette from controlnet_aux.util import HWC3, ade_palette
from PIL import Image from PIL import Image
from pydantic import BaseModel, ConfigDict, Field, field_validator from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.shared.fields import FieldDescriptions from invokeai.app.shared.fields import FieldDescriptions
@ -75,17 +76,16 @@ class ControlField(BaseModel):
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use") resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@field_validator("control_weight") @field_validator("control_weight")
@classmethod
def validate_control_weight(cls, v): def validate_control_weight(cls, v):
"""Validate that all control weights in the valid range""" validate_weights(v)
if isinstance(v, list):
for i in v:
if i < -1 or i > 2:
raise ValueError("Control weights must be within -1 to 2 range")
else:
if v < -1 or v > 2:
raise ValueError("Control weights must be within -1 to 2 range")
return v return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
@invocation_output("control_output") @invocation_output("control_output")
class ControlOutput(BaseInvocationOutput): class ControlOutput(BaseInvocationOutput):
@ -95,17 +95,17 @@ class ControlOutput(BaseInvocationOutput):
control: ControlField = OutputField(description=FieldDescriptions.control) control: ControlField = OutputField(description=FieldDescriptions.control)
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.0") @invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.1")
class ControlNetInvocation(BaseInvocation): class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes""" """Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image") image: ImageField = InputField(description="The control image")
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct) control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
control_weight: Union[float, List[float]] = InputField( control_weight: Union[float, List[float]] = InputField(
default=1.0, description="The weight given to the ControlNet" default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
) )
begin_step_percent: float = InputField( begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)" default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
) )
end_step_percent: float = InputField( end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)" default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
@ -113,6 +113,17 @@ class ControlNetInvocation(BaseInvocation):
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used") control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used") resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
@field_validator("control_weight")
@classmethod
def validate_control_weight(cls, v):
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self) -> "ControlNetInvocation":
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
def invoke(self, context: InvocationContext) -> ControlOutput: def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput( return ControlOutput(
control=ControlField( control=ControlField(

View File

@ -2,7 +2,7 @@ import os
from builtins import float from builtins import float
from typing import List, Union from typing import List, Union
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from invokeai.app.invocations.baseinvocation import ( from invokeai.app.invocations.baseinvocation import (
BaseInvocation, BaseInvocation,
@ -15,6 +15,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output, invocation_output,
) )
from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.shared.fields import FieldDescriptions from invokeai.app.shared.fields import FieldDescriptions
from invokeai.backend.model_management.models.base import BaseModelType, ModelType from invokeai.backend.model_management.models.base import BaseModelType, ModelType
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
@ -39,7 +40,6 @@ class IPAdapterField(BaseModel):
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.") ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.") image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
# weight: float = Field(default=1.0, ge=0, description="The weight of the IP-Adapter.")
begin_step_percent: float = Field( begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)" default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
) )
@ -47,6 +47,17 @@ class IPAdapterField(BaseModel):
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)" default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
) )
@field_validator("weight")
@classmethod
def validate_ip_adapter_weight(cls, v):
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
@invocation_output("ip_adapter_output") @invocation_output("ip_adapter_output")
class IPAdapterOutput(BaseInvocationOutput): class IPAdapterOutput(BaseInvocationOutput):
@ -54,7 +65,7 @@ class IPAdapterOutput(BaseInvocationOutput):
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter") ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.0") @invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.1")
class IPAdapterInvocation(BaseInvocation): class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes.""" """Collects IP-Adapter info to pass to other nodes."""
@ -64,18 +75,27 @@ class IPAdapterInvocation(BaseInvocation):
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1 description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
) )
# weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float)
weight: Union[float, List[float]] = InputField( weight: Union[float, List[float]] = InputField(
default=1, ge=-1, description="The weight given to the IP-Adapter", title="Weight" default=1, description="The weight given to the IP-Adapter", title="Weight"
) )
begin_step_percent: float = InputField( begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the IP-Adapter is first applied (% of total steps)" default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
) )
end_step_percent: float = InputField( end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)" default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
) )
@field_validator("weight")
@classmethod
def validate_ip_adapter_weight(cls, v):
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
def invoke(self, context: InvocationContext) -> IPAdapterOutput: def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.services.model_manager.model_info( ip_adapter_info = context.services.model_manager.model_info(

View File

@ -1,6 +1,6 @@
from typing import Union from typing import Union
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from invokeai.app.invocations.baseinvocation import ( from invokeai.app.invocations.baseinvocation import (
BaseInvocation, BaseInvocation,
@ -14,6 +14,7 @@ from invokeai.app.invocations.baseinvocation import (
) )
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.shared.fields import FieldDescriptions from invokeai.app.shared.fields import FieldDescriptions
from invokeai.backend.model_management.models.base import BaseModelType from invokeai.backend.model_management.models.base import BaseModelType
@ -37,6 +38,17 @@ class T2IAdapterField(BaseModel):
) )
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use") resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@field_validator("weight")
@classmethod
def validate_ip_adapter_weight(cls, v):
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
@invocation_output("t2i_adapter_output") @invocation_output("t2i_adapter_output")
class T2IAdapterOutput(BaseInvocationOutput): class T2IAdapterOutput(BaseInvocationOutput):
@ -44,7 +56,7 @@ class T2IAdapterOutput(BaseInvocationOutput):
@invocation( @invocation(
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.0" "t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.1"
) )
class T2IAdapterInvocation(BaseInvocation): class T2IAdapterInvocation(BaseInvocation):
"""Collects T2I-Adapter info to pass to other nodes.""" """Collects T2I-Adapter info to pass to other nodes."""
@ -61,7 +73,7 @@ class T2IAdapterInvocation(BaseInvocation):
default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight" default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight"
) )
begin_step_percent: float = InputField( begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the T2I-Adapter is first applied (% of total steps)" default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
) )
end_step_percent: float = InputField( end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)" default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
@ -71,6 +83,17 @@ class T2IAdapterInvocation(BaseInvocation):
description="The resize mode applied to the T2I-Adapter input image so that it matches the target output size.", description="The resize mode applied to the T2I-Adapter input image so that it matches the target output size.",
) )
@field_validator("weight")
@classmethod
def validate_ip_adapter_weight(cls, v):
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
def invoke(self, context: InvocationContext) -> T2IAdapterOutput: def invoke(self, context: InvocationContext) -> T2IAdapterOutput:
return T2IAdapterOutput( return T2IAdapterOutput(
t2i_adapter=T2IAdapterField( t2i_adapter=T2IAdapterField(

View File

@ -0,0 +1,14 @@
from typing import Union
def validate_weights(weights: Union[float, list[float]]) -> None:
"""Validate that all control weights in the valid range"""
to_validate = weights if isinstance(weights, list) else [weights]
if any(i < -1 or i > 2 for i in to_validate):
raise ValueError("Control weights must be within -1 to 2 range")
def validate_begin_end_step(begin_step_percent: float, end_step_percent: float) -> None:
"""Validate that begin_step_percent is less than end_step_percent"""
if begin_step_percent >= end_step_percent:
raise ValueError("Begin step percent must be less than or equal to end step percent")