diff --git a/invokeai/app/invocations/control_adapter.py b/invokeai/app/invocations/control_adapter.py new file mode 100644 index 0000000000..189c4baed5 --- /dev/null +++ b/invokeai/app/invocations/control_adapter.py @@ -0,0 +1,109 @@ +from builtins import bool, float +from typing import Dict, List, Literal, Optional, Union + +from pydantic import BaseModel, Field, validator +from invokeai.app.invocations.primitives import ImageField + +from ...backend.model_management import BaseModelType + +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + FieldDescriptions, + InputField, + Input, + InvocationContext, + OutputField, + UIType, + tags, + title, +) + + +CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"] +CONTROLNET_RESIZE_VALUES = Literal[ + "just_resize", + "crop_resize", + "fill_resize", + "just_resize_simple", +] + + +class ControlNetModelField(BaseModel): + """ControlNet model field""" + + model_name: str = Field(description="Name of the ControlNet model") + base_model: BaseModelType = Field(description="Base model") + + +class ControlField(BaseModel): + image: ImageField = Field(description="The control image") + control_model: ControlNetModelField = Field(description="The ControlNet model to use") + control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") + begin_step_percent: float = Field( + default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)" + ) + end_step_percent: float = Field( + default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)" + ) + control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use") + resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use") + + @validator("control_weight") + def validate_control_weight(cls, v): + """Validate that all control weights in the valid range""" + if isinstance(v, list): + for i in v: + if i < -1 or i > 2: + raise ValueError("Control weights must be within -1 to 2 range") + else: + if v < -1 or v > 2: + raise ValueError("Control weights must be within -1 to 2 range") + return v + + +class ControlOutput(BaseInvocationOutput): + """node output for ControlNet info""" + + type: Literal["control_output"] = "control_output" + + # Outputs + control: ControlField = OutputField(description=FieldDescriptions.control) + + +@title("ControlNet") +@tags("controlnet") +class ControlNetInvocation(BaseInvocation): + """Collects ControlNet info to pass to other nodes""" + + type: Literal["controlnet"] = "controlnet" + + # Inputs + image: ImageField = InputField(description="The control image") + control_model: ControlNetModelField = InputField( + default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct + ) + control_weight: Union[float, List[float]] = InputField( + default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float + ) + begin_step_percent: float = InputField( + default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)" + ) + end_step_percent: float = InputField( + default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)" + ) + control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used") + resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used") + + def invoke(self, context: InvocationContext) -> ControlOutput: + return ControlOutput( + control=ControlField( + image=self.image, + control_model=self.control_model, + control_weight=self.control_weight, + begin_step_percent=self.begin_step_percent, + end_step_percent=self.end_step_percent, + control_mode=self.control_mode, + resize_mode=self.resize_mode, + ), + )