feat(nodes): add control_type validation & fix types

This commit is contained in:
psychedelicious 2023-09-05 12:24:54 +10:00
parent 30ab81b6bb
commit bd15874cf6

View File

@ -1,7 +1,7 @@
from builtins import float
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, root_validator, validator
from invokeai.app.invocations.primitives import ImageField
@ -40,13 +40,9 @@ class ControlNetModelField(BaseModel):
class ControlField(BaseModel):
control_type: CONTROL_ADAPTER_TYPES = Field(default="ControlNet", description="The type of control adapter")
image: ImageField = Field(description="The control image")
# control_model and ip_adapter_models are both optional
# but must be on the two present
# if control_type == "ControlNet", then must be control_model
# if control_type == "IP-Adapter", then must be ip_adapter_model
control_model: Optional[ControlNetModelField] = Field(description="The ControlNet model to use")
ip_adapter_model: Optional[str] = Field(description="The IP-Adapter model to use")
image_encoder_model: Optional[str] = Field(description="The clip_image_encoder model to use")
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
ip_adapter_model: Optional[str] = Field(default=None, description="The IP-Adapter model to use")
image_encoder_model: Optional[str] = Field(default=None, description="The clip_image_encoder model to use")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
@ -57,6 +53,17 @@ class ControlField(BaseModel):
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@root_validator
def validate_control_model(cls, values):
"""Validate that an appropriate type of control model is provided"""
if values["control_type"] == "ControlNet":
if values.get("control_model") is None:
raise ValueError('ControlNet control_type requires "control_model" be provided')
elif values["control_type"] == "IP-Adapter":
if values.get("ip_adapter_model") is None:
raise ValueError('IP-Adapter control_type requires "ip_adapter_model" be provided')
return values
@validator("control_weight")
def validate_control_weight(cls, v):
"""Validate that all control weights in the valid range"""