feat(nodes): add control_type validation & fix types

This commit is contained in:
psychedelicious 2023-09-05 12:24:54 +10:00
parent 30ab81b6bb
commit bd15874cf6

View File

@ -1,7 +1,7 @@
from builtins import float from builtins import float
from typing import List, Literal, Optional, Union from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, root_validator, validator
from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.primitives import ImageField
@ -40,13 +40,9 @@ class ControlNetModelField(BaseModel):
class ControlField(BaseModel): class ControlField(BaseModel):
control_type: CONTROL_ADAPTER_TYPES = Field(default="ControlNet", description="The type of control adapter") control_type: CONTROL_ADAPTER_TYPES = Field(default="ControlNet", description="The type of control adapter")
image: ImageField = Field(description="The control image") image: ImageField = Field(description="The control image")
# control_model and ip_adapter_models are both optional control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
# but must be on the two present ip_adapter_model: Optional[str] = Field(default=None, description="The IP-Adapter model to use")
# if control_type == "ControlNet", then must be control_model image_encoder_model: Optional[str] = Field(default=None, description="The clip_image_encoder model to use")
# if control_type == "IP-Adapter", then must be ip_adapter_model
control_model: Optional[ControlNetModelField] = Field(description="The ControlNet model to use")
ip_adapter_model: Optional[str] = Field(description="The IP-Adapter model to use")
image_encoder_model: Optional[str] = Field(description="The clip_image_encoder model to use")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field( begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)" default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
@ -57,6 +53,17 @@ class ControlField(BaseModel):
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use") control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use") resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@root_validator
def validate_control_model(cls, values):
"""Validate that an appropriate type of control model is provided"""
if values["control_type"] == "ControlNet":
if values.get("control_model") is None:
raise ValueError('ControlNet control_type requires "control_model" be provided')
elif values["control_type"] == "IP-Adapter":
if values.get("ip_adapter_model") is None:
raise ValueError('IP-Adapter control_type requires "ip_adapter_model" be provided')
return values
@validator("control_weight") @validator("control_weight")
def validate_control_weight(cls, v): def validate_control_weight(cls, v):
"""Validate that all control weights in the valid range""" """Validate that all control weights in the valid range"""