mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): add control_type validation & fix types
This commit is contained in:
parent
30ab81b6bb
commit
bd15874cf6
@ -1,7 +1,7 @@
|
||||
from builtins import float
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic import BaseModel, Field, root_validator, validator
|
||||
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
|
||||
@ -40,13 +40,9 @@ class ControlNetModelField(BaseModel):
|
||||
class ControlField(BaseModel):
|
||||
control_type: CONTROL_ADAPTER_TYPES = Field(default="ControlNet", description="The type of control adapter")
|
||||
image: ImageField = Field(description="The control image")
|
||||
# control_model and ip_adapter_models are both optional
|
||||
# but must be on the two present
|
||||
# if control_type == "ControlNet", then must be control_model
|
||||
# if control_type == "IP-Adapter", then must be ip_adapter_model
|
||||
control_model: Optional[ControlNetModelField] = Field(description="The ControlNet model to use")
|
||||
ip_adapter_model: Optional[str] = Field(description="The IP-Adapter model to use")
|
||||
image_encoder_model: Optional[str] = Field(description="The clip_image_encoder model to use")
|
||||
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
|
||||
ip_adapter_model: Optional[str] = Field(default=None, description="The IP-Adapter model to use")
|
||||
image_encoder_model: Optional[str] = Field(default=None, description="The clip_image_encoder model to use")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
@ -57,6 +53,17 @@ class ControlField(BaseModel):
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
@root_validator
|
||||
def validate_control_model(cls, values):
|
||||
"""Validate that an appropriate type of control model is provided"""
|
||||
if values["control_type"] == "ControlNet":
|
||||
if values.get("control_model") is None:
|
||||
raise ValueError('ControlNet control_type requires "control_model" be provided')
|
||||
elif values["control_type"] == "IP-Adapter":
|
||||
if values.get("ip_adapter_model") is None:
|
||||
raise ValueError('IP-Adapter control_type requires "ip_adapter_model" be provided')
|
||||
return values
|
||||
|
||||
@validator("control_weight")
|
||||
def validate_control_weight(cls, v):
|
||||
"""Validate that all control weights in the valid range"""
|
||||
|
Loading…
Reference in New Issue
Block a user