InvokeAI/invokeai/app/invocations/control_adapter.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

185 lines
7.6 KiB
Python
Raw Normal View History

from builtins import bool, float
from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.primitives import ImageField
from ...backend.model_management import BaseModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
InputField,
Input,
InvocationContext,
OutputField,
UIType,
2023-08-31 01:29:06 +00:00
invocation,
invocation_output,
)
# Kinds of control adapter that a ControlField can describe.
CONTROL_ADAPTER_TYPES = Literal[
    "ControlNet",
    "IP-Adapter",
    "T2I-Adapter",
]

# Allowed values for a ControlNet's `control_mode` setting.
CONTROLNET_MODE_VALUES = Literal[
    "balanced",
    "more_prompt",
    "more_control",
    "unbalanced",
]

# Allowed values for a ControlNet's `resize_mode` setting.
CONTROLNET_RESIZE_VALUES = Literal[
    "just_resize",
    "crop_resize",
    "fill_resize",
    "just_resize_simple",
]
2023-09-04 23:37:12 +00:00
class ControlNetModelField(BaseModel):
    """ControlNet model field"""

    # Name used to identify the ControlNet model.
    model_name: str = Field(
        description="Name of the ControlNet model",
    )
    # Base model type associated with this ControlNet model.
    base_model: BaseModelType = Field(
        description="Base model",
    )
class ControlField(BaseModel):
    # Bundles a control image with the adapter model(s) and the settings used to apply it.
    control_type: CONTROL_ADAPTER_TYPES = Field(default="ControlNet", description="The type of control adapter")
    image: ImageField = Field(description="The control image")
    # control_model and ip_adapter_model are both optional, but one of the two
    # must be present:
    #   - if control_type == "ControlNet", control_model must be set
    #   - if control_type == "IP-Adapter", ip_adapter_model must be set
    # NOTE(review): this pairing is not enforced by a validator on this model; the
    # invocations in this module populate the matching field. "T2I-Adapter" is an
    # allowed control_type but has no dedicated model field here yet.
    # The explicit `default=None` keeps these fields optional without relying on
    # pydantic's implicit handling of bare Optional annotations.
    control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
    ip_adapter_model: Optional[str] = Field(default=None, description="The IP-Adapter model to use")
    image_encoder_model: Optional[str] = Field(default=None, description="The clip_image_encoder model to use")
    control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
    begin_step_percent: float = Field(
        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
    )
    end_step_percent: float = Field(
        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
    )
    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
    resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")

    @validator("control_weight")
    def validate_control_weight(cls, v):
        """Validate that every control weight lies in the inclusive range [-1, 2].

        Accepts either a single weight or a list of weights and returns the value
        unchanged when it is valid.

        Raises:
            ValueError: if any weight is below -1 or above 2.
        """
        # Treat a scalar as a one-element list so both shapes share one check.
        weights = v if isinstance(v, list) else [v]
        if any(w < -1 or w > 2 for w in weights):
            raise ValueError("Control weights must be within -1 to 2 range")
        return v
2023-09-04 23:37:12 +00:00
2023-08-31 01:29:06 +00:00
@invocation_output("control_output")
class ControlOutput(BaseInvocationOutput):
    """node output for ControlNet info"""

    # Fixed type tag; matches the name registered via @invocation_output above.
    type: Literal["control_output"] = "control_output"

    # Outputs
    # The assembled control-adapter description emitted by a control node.
    control: ControlField = OutputField(
        description=FieldDescriptions.control,
    )
2023-08-31 01:29:06 +00:00
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet")
class ControlNetInvocation(BaseInvocation):
    """Collects ControlNet info to pass to other nodes"""

    # Fixed type tag; matches the name registered via @invocation above.
    type: Literal["controlnet"] = "controlnet"

    # Inputs
    image: ImageField = InputField(description="The control image")
    # NOTE(review): the default is a plain string although the field is typed as
    # ControlNetModelField — confirm whether this default is ever used as-is.
    control_model: ControlNetModelField = InputField(
        default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
    )
    control_weight: Union[float, List[float]] = InputField(
        default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
    )
    # Bounds must match ControlField.begin_step_percent (ge=0, le=1); a wider
    # range here would let values through that then fail when invoke() builds
    # the ControlField.
    begin_step_percent: float = InputField(
        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
    )
    end_step_percent: float = InputField(
        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
    )
    control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
    resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")

    def invoke(self, context: InvocationContext) -> ControlOutput:
        """Package this node's inputs into a ControlNet-typed ControlField.

        ip_adapter_model is not passed; a ControlNet control only needs
        control_model. ``context`` is not used.
        """
        return ControlOutput(
            control=ControlField(
                control_type="ControlNet",
                image=self.image,
                control_model=self.control_model,
                control_weight=self.control_weight,
                begin_step_percent=self.begin_step_percent,
                end_step_percent=self.end_step_percent,
                control_mode=self.control_mode,
                resize_mode=self.resize_mode,
            ),
        )
2023-09-04 23:37:12 +00:00
# IP-Adapter model checkpoint paths selectable on the IP-Adapter node.
IP_ADAPTER_MODELS = Literal[
    "models_ip_adapter/models/ip-adapter_sd15.bin",
    # "./"-prefixed alias: IPAdapterInvocation's default is spelled this way, so it
    # must be a member of this Literal for that default to be a legal value
    # (mirrors the "./" entry in IP_ADAPTER_IMAGE_ENCODER_MODELS).
    "./models_ip_adapter/models/ip-adapter_sd15.bin",
    "models_ip_adapter/models/ip-adapter-plus_sd15.bin",
    "models_ip_adapter/models/ip-adapter-plus-face_sd15.bin",
    "models_ip_adapter/sdxl_models/ip-adapter_sdxl.bin",
]

# Image encoder directory paths selectable on the IP-Adapter node.
IP_ADAPTER_IMAGE_ENCODER_MODELS = Literal[
    "models_ip_adapter/models/image_encoder/",
    "./models_ip_adapter/models/image_encoder/",
    "models_ip_adapter/sdxl_models/image_encoder/",
]
2023-09-04 23:37:12 +00:00
@invocation("ipadapter", title="IP-Adapter", tags=["ipadapter"], category="ipadapter")
class IPAdapterInvocation(BaseInvocation):
    """Collects IP-Adapter info to pass to other nodes"""

    # Fixed type tag; matches the name registered via @invocation above.
    type: Literal["ipadapter"] = "ipadapter"

    # Inputs
    image: ImageField = InputField(description="The control image")
    # NOTE(review): this default must also be a member of IP_ADAPTER_MODELS to be a
    # legal value for this Literal-typed field — keep the two in sync.
    ip_adapter_model: IP_ADAPTER_MODELS = InputField(
        default="./models_ip_adapter/models/ip-adapter_sd15.bin", description="The IP-Adapter model"
    )
    image_encoder_model: IP_ADAPTER_IMAGE_ENCODER_MODELS = InputField(
        default="./models_ip_adapter/models/image_encoder/", description="The image encoder model"
    )
    control_weight: Union[float, List[float]] = InputField(
        default=1.0, description="The weight given to the IP-Adapter", ui_type=UIType.Float
    )
    # Step-range, control-mode and resize-mode inputs are not exposed on this node;
    # the ControlField built in invoke() falls back to its own defaults for them.

    def invoke(self, context: InvocationContext) -> ControlOutput:
        """Package this node's inputs into an IP-Adapter-typed ControlField.

        control_model is not passed; an IP-Adapter control only needs
        ip_adapter_model and image_encoder_model. ``context`` is not used.
        """
        return ControlOutput(
            control=ControlField(
                control_type="IP-Adapter",
                image=self.image,
                ip_adapter_model=self.ip_adapter_model,
                image_encoder_model=self.image_encoder_model,
                control_weight=self.control_weight,
            ),
        )