mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): add control_type validation & fix types
This commit is contained in:
parent
30ab81b6bb
commit
bd15874cf6
@ -1,7 +1,7 @@
|
|||||||
from builtins import float
|
from builtins import float
|
||||||
from typing import List, Literal, Optional, Union
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, validator
|
from pydantic import BaseModel, Field, root_validator, validator
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import ImageField
|
from invokeai.app.invocations.primitives import ImageField
|
||||||
|
|
||||||
@ -40,13 +40,9 @@ class ControlNetModelField(BaseModel):
|
|||||||
class ControlField(BaseModel):
|
class ControlField(BaseModel):
|
||||||
control_type: CONTROL_ADAPTER_TYPES = Field(default="ControlNet", description="The type of control adapter")
|
control_type: CONTROL_ADAPTER_TYPES = Field(default="ControlNet", description="The type of control adapter")
|
||||||
image: ImageField = Field(description="The control image")
|
image: ImageField = Field(description="The control image")
|
||||||
# control_model and ip_adapter_models are both optional
|
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
|
||||||
# but must be on the two present
|
ip_adapter_model: Optional[str] = Field(default=None, description="The IP-Adapter model to use")
|
||||||
# if control_type == "ControlNet", then must be control_model
|
image_encoder_model: Optional[str] = Field(default=None, description="The clip_image_encoder model to use")
|
||||||
# if control_type == "IP-Adapter", then must be ip_adapter_model
|
|
||||||
control_model: Optional[ControlNetModelField] = Field(description="The ControlNet model to use")
|
|
||||||
ip_adapter_model: Optional[str] = Field(description="The IP-Adapter model to use")
|
|
||||||
image_encoder_model: Optional[str] = Field(description="The clip_image_encoder model to use")
|
|
||||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||||
begin_step_percent: float = Field(
|
begin_step_percent: float = Field(
|
||||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||||
@ -57,6 +53,17 @@ class ControlField(BaseModel):
|
|||||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||||
|
|
||||||
|
@root_validator
|
||||||
|
def validate_control_model(cls, values):
|
||||||
|
"""Validate that an appropriate type of control model is provided"""
|
||||||
|
if values["control_type"] == "ControlNet":
|
||||||
|
if values.get("control_model") is None:
|
||||||
|
raise ValueError('ControlNet control_type requires "control_model" be provided')
|
||||||
|
elif values["control_type"] == "IP-Adapter":
|
||||||
|
if values.get("ip_adapter_model") is None:
|
||||||
|
raise ValueError('IP-Adapter control_type requires "ip_adapter_model" be provided')
|
||||||
|
return values
|
||||||
|
|
||||||
@validator("control_weight")
|
@validator("control_weight")
|
||||||
def validate_control_weight(cls, v):
|
def validate_control_weight(cls, v):
|
||||||
"""Validate that all control weights in the valid range"""
|
"""Validate that all control weights in the valid range"""
|
||||||
|
Loading…
Reference in New Issue
Block a user