fix(nodes): fix constraints/validation for controlnet

- Fix `weight` and `begin_step_percent`, the constraints were mixed up
- Add model validatort to ensure `begin_step_percent < end_step_percent`
- Bump version
This commit is contained in:
psychedelicious 2024-01-02 11:13:49 +11:00 committed by Kent Keirsey
parent d256d93a2a
commit 2700d0e769
4 changed files with 90 additions and 22 deletions

View File

@ -24,9 +24,10 @@ from controlnet_aux import (
)
from controlnet_aux.util import HWC3, ade_palette
from PIL import Image
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.shared.fields import FieldDescriptions
@ -75,17 +76,16 @@ class ControlField(BaseModel):
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@field_validator("control_weight")
@classmethod
def validate_control_weight(cls, v):
"""Validate that all control weights in the valid range"""
if isinstance(v, list):
for i in v:
if i < -1 or i > 2:
raise ValueError("Control weights must be within -1 to 2 range")
else:
if v < -1 or v > 2:
raise ValueError("Control weights must be within -1 to 2 range")
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
@invocation_output("control_output")
class ControlOutput(BaseInvocationOutput):
@ -95,17 +95,17 @@ class ControlOutput(BaseInvocationOutput):
control: ControlField = OutputField(description=FieldDescriptions.control)
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.0")
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.1")
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image")
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
control_weight: Union[float, List[float]] = InputField(
default=1.0, description="The weight given to the ControlNet"
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
)
begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
@ -113,6 +113,17 @@ class ControlNetInvocation(BaseInvocation):
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
@field_validator("control_weight")
@classmethod
def validate_control_weight(cls, v):
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self) -> "ControlNetInvocation":
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput(
control=ControlField(

View File

@ -2,7 +2,7 @@ import os
from builtins import float
from typing import List, Union
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
@ -15,6 +15,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output,
)
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
@ -39,7 +40,6 @@ class IPAdapterField(BaseModel):
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
# weight: float = Field(default=1.0, ge=0, description="The weight of the IP-Adapter.")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
)
@ -47,6 +47,17 @@ class IPAdapterField(BaseModel):
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
)
@field_validator("weight")
@classmethod
def validate_ip_adapter_weight(cls, v):
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
@invocation_output("ip_adapter_output")
class IPAdapterOutput(BaseInvocationOutput):
@ -54,7 +65,7 @@ class IPAdapterOutput(BaseInvocationOutput):
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.0")
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.1")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes."""
@ -64,18 +75,27 @@ class IPAdapterInvocation(BaseInvocation):
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
)
# weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float)
weight: Union[float, List[float]] = InputField(
default=1, ge=-1, description="The weight given to the IP-Adapter", title="Weight"
default=1, description="The weight given to the IP-Adapter", title="Weight"
)
begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the IP-Adapter is first applied (% of total steps)"
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
)
@field_validator("weight")
@classmethod
def validate_ip_adapter_weight(cls, v):
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.services.model_manager.model_info(

View File

@ -1,6 +1,6 @@
from typing import Union
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
@ -14,6 +14,7 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.backend.model_management.models.base import BaseModelType
@ -37,6 +38,17 @@ class T2IAdapterField(BaseModel):
)
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@field_validator("weight")
@classmethod
def validate_ip_adapter_weight(cls, v):
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
@invocation_output("t2i_adapter_output")
class T2IAdapterOutput(BaseInvocationOutput):
@ -44,7 +56,7 @@ class T2IAdapterOutput(BaseInvocationOutput):
@invocation(
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.0"
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.1"
)
class T2IAdapterInvocation(BaseInvocation):
"""Collects T2I-Adapter info to pass to other nodes."""
@ -61,7 +73,7 @@ class T2IAdapterInvocation(BaseInvocation):
default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight"
)
begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the T2I-Adapter is first applied (% of total steps)"
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
@ -71,6 +83,17 @@ class T2IAdapterInvocation(BaseInvocation):
description="The resize mode applied to the T2I-Adapter input image so that it matches the target output size.",
)
@field_validator("weight")
@classmethod
def validate_ip_adapter_weight(cls, v):
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
def invoke(self, context: InvocationContext) -> T2IAdapterOutput:
return T2IAdapterOutput(
t2i_adapter=T2IAdapterField(

View File

@ -0,0 +1,14 @@
from typing import Union
def validate_weights(weights: Union[float, list[float]]) -> None:
"""Validate that all control weights in the valid range"""
to_validate = weights if isinstance(weights, list) else [weights]
if any(i < -1 or i > 2 for i in to_validate):
raise ValueError("Control weights must be within -1 to 2 range")
def validate_begin_end_step(begin_step_percent: float, end_step_percent: float) -> None:
"""Validate that begin_step_percent is less than end_step_percent"""
if begin_step_percent >= end_step_percent:
raise ValueError("Begin step percent must be less than or equal to end step percent")