From bd15874cf6fe60b2ae852b306678b40737ad3d56 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 5 Sep 2023 12:24:54 +1000 Subject: [PATCH] feat(nodes): add control_type validation & fix types --- invokeai/app/invocations/control_adapter.py | 23 ++++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/invokeai/app/invocations/control_adapter.py b/invokeai/app/invocations/control_adapter.py index 34561623d5..cbf3d78318 100644 --- a/invokeai/app/invocations/control_adapter.py +++ b/invokeai/app/invocations/control_adapter.py @@ -1,7 +1,7 @@ from builtins import float from typing import List, Literal, Optional, Union -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, root_validator, validator from invokeai.app.invocations.primitives import ImageField @@ -40,13 +40,9 @@ class ControlNetModelField(BaseModel): class ControlField(BaseModel): control_type: CONTROL_ADAPTER_TYPES = Field(default="ControlNet", description="The type of control adapter") image: ImageField = Field(description="The control image") - # control_model and ip_adapter_models are both optional - # but must be on the two present - # if control_type == "ControlNet", then must be control_model - # if control_type == "IP-Adapter", then must be ip_adapter_model - control_model: Optional[ControlNetModelField] = Field(description="The ControlNet model to use") - ip_adapter_model: Optional[str] = Field(description="The IP-Adapter model to use") - image_encoder_model: Optional[str] = Field(description="The clip_image_encoder model to use") + control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use") + ip_adapter_model: Optional[str] = Field(default=None, description="The IP-Adapter model to use") + image_encoder_model: Optional[str] = Field(default=None, description="The clip_image_encoder model to use") control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") begin_step_percent: float = Field( default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)" @@ -57,6 +53,17 @@ class ControlField(BaseModel): control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use") resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use") + @root_validator + def validate_control_model(cls, values): + """Validate that an appropriate type of control model is provided""" + if values["control_type"] == "ControlNet": + if values.get("control_model") is None: + raise ValueError('ControlNet control_type requires "control_model" be provided') + elif values["control_type"] == "IP-Adapter": + if values.get("ip_adapter_model") is None: + raise ValueError('IP-Adapter control_type requires "ip_adapter_model" be provided') + return values + @validator("control_weight") def validate_control_weight(cls, v): """Validate that all control weights in the valid range"""