InvokeAI/invokeai/app/invocations/control_adapter.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

188 lines
8.0 KiB
Python
Raw Normal View History

from builtins import float
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field, root_validator, validator
from invokeai.app.invocations.primitives import ImageField
from ...backend.model_management import BaseModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UIType,
2023-08-31 01:29:06 +00:00
invocation,
invocation_output,
)
# Kinds of control adapter a ControlField can describe.
CONTROL_ADAPTER_TYPES = Literal[
    "ControlNet",
    "IP-Adapter",
    "T2I-Adapter",
]

# Accepted values for the `control_mode` fields.
CONTROLNET_MODE_VALUES = Literal[
    "balanced",
    "more_prompt",
    "more_control",
    "unbalanced",
]

# Accepted values for the `resize_mode` fields.
CONTROLNET_RESIZE_VALUES = Literal["just_resize", "crop_resize", "fill_resize", "just_resize_simple"]
2023-09-04 23:37:12 +00:00
class ControlNetModelField(BaseModel):
    """ControlNet model field"""

    # A ControlNet model is identified by its name together with the base model it belongs to.
    model_name: str = Field(
        description="Name of the ControlNet model",
    )
    base_model: BaseModelType = Field(
        description="Base model",
    )
class ControlField(BaseModel):
    """Settings for one control adapter, as passed between nodes.

    `control_type` selects which adapter the record describes. A "ControlNet"
    record must carry `control_model`; an "IP-Adapter" record must carry
    `ip_adapter_model`. No model requirement is enforced here for "T2I-Adapter".
    """

    control_type: CONTROL_ADAPTER_TYPES = Field(default="ControlNet", description="The type of control adapter")
    image: ImageField = Field(description="The control image")
    control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
    ip_adapter_model: Optional[str] = Field(default=None, description="The IP-Adapter model to use")
    image_encoder_model: Optional[str] = Field(default=None, description="The clip_image_encoder model to use")
    control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
    begin_step_percent: float = Field(
        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
    )
    end_step_percent: float = Field(
        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
    )
    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
    resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")

    @root_validator
    def validate_control_model(cls, values):
        """Validate that an appropriate type of control model is provided.

        Raises:
            ValueError: if `control_type` is "ControlNet" without a `control_model`,
                or "IP-Adapter" without an `ip_adapter_model`.
        """
        # A field that failed its own validation is missing from `values`, and this
        # validator still runs in that case. Indexing values["control_type"] would then
        # raise KeyError, which pydantic does not convert into a validation error.
        control_type = values.get("control_type")
        if control_type == "ControlNet" and values.get("control_model") is None:
            raise ValueError('ControlNet control_type requires "control_model" be provided')
        if control_type == "IP-Adapter" and values.get("ip_adapter_model") is None:
            raise ValueError('IP-Adapter control_type requires "ip_adapter_model" be provided')
        return values

    @validator("control_weight")
    def validate_control_weight(cls, v):
        """Validate that every control weight lies within [-1, 2].

        Accepts either a single weight or a list of weights and returns the value
        unchanged.

        Raises:
            ValueError: if any weight is below -1 or above 2.
        """
        # Treat the scalar form as a one-element list so both shapes share one check.
        weights = v if isinstance(v, list) else [v]
        if any(weight < -1 or weight > 2 for weight in weights):
            raise ValueError("Control weights must be within -1 to 2 range")
        return v
2023-09-04 23:37:12 +00:00
2023-08-31 01:29:06 +00:00
@invocation_output("control_output")
class ControlOutput(BaseInvocationOutput):
    """node output for ControlNet info"""

    # Single output: the ControlField assembled by a control-adapter node.
    control: ControlField = OutputField(
        description=FieldDescriptions.control,
    )
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.0.0")
class ControlNetInvocation(BaseInvocation):
    """Collects ControlNet info to pass to other nodes"""

    # Inputs
    image: ImageField = InputField(description="The control image")
    # NOTE(review): the default is a plain string although the field is typed as
    # ControlNetModelField. Left as-is to keep the node's interface; confirm intent.
    control_model: ControlNetModelField = InputField(
        default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
    )
    control_weight: Union[float, List[float]] = InputField(
        default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
    )
    # Bounds are 0..1, the same as ControlField.begin_step_percent. The previous
    # -1..2 range let values through here that ControlField then rejected in invoke().
    begin_step_percent: float = InputField(
        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
    )
    end_step_percent: float = InputField(
        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
    )
    control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
    resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")

    def invoke(self, context: InvocationContext) -> ControlOutput:
        """Package this node's inputs as a ControlField of type "ControlNet".

        Args:
            context: The invocation context (not used by this node).

        Returns:
            A ControlOutput wrapping the assembled ControlField.
        """
        return ControlOutput(
            control=ControlField(
                control_type="ControlNet",
                image=self.image,
                control_model=self.control_model,
                # ip_adapter_model is left unset: a "ControlNet" ControlField only
                # requires control_model.
                control_weight=self.control_weight,
                begin_step_percent=self.begin_step_percent,
                end_step_percent=self.end_step_percent,
                control_mode=self.control_mode,
                resize_mode=self.resize_mode,
            ),
        )
2023-09-04 23:37:12 +00:00
# IP-Adapter weight files the IP-Adapter node can select.
# The node joins the chosen path onto the configured root directory.
IP_ADAPTER_MODELS = Literal[
    "models/core/ip_adapters/sd-1/ip-adapter_sd15.bin",
    "models/core/ip_adapters/sd-1/ip-adapter-plus_sd15.bin",
    "models/core/ip_adapters/sd-1/ip-adapter-plus-face_sd15.bin",
    "models/core/ip_adapters/sdxl/ip-adapter_sdxl.bin",
]

# Image-encoder directories the IP-Adapter node can select.
# NOTE(review): only the sd-1 entry ends with a trailing slash; kept verbatim
# because these strings are the values callers must pass.
IP_ADAPTER_IMAGE_ENCODER_MODELS = Literal[
    "models/core/ip_adapters/sd-1/image_encoder/",
    "models/core/ip_adapters/sdxl/image_encoder",
]
2023-09-04 23:37:12 +00:00
@invocation("ipadapter", title="IP-Adapter", tags=["ipadapter"], category="ipadapter", version="1.0.0")
class IPAdapterInvocation(BaseInvocation):
    """Collects IP-Adapter info to pass to other nodes"""

    # Inputs
    image: ImageField = InputField(description="The control image")
    ip_adapter_model: IP_ADAPTER_MODELS = InputField(
        default="models/core/ip_adapters/sd-1/ip-adapter_sd15.bin", description="The IP-Adapter model"
    )
    image_encoder_model: IP_ADAPTER_IMAGE_ENCODER_MODELS = InputField(
        default="models/core/ip_adapters/sd-1/image_encoder/", description="The image encoder model"
    )
    # Description fixed: it previously said "ControlNet", copied from the ControlNet node.
    control_weight: Union[float, List[float]] = InputField(
        default=1.0, description="The weight given to the IP-Adapter", ui_type=UIType.Float
    )
    # begin/end step percent, control_mode and resize_mode are not exposed as inputs;
    # invoke() does not pass them, so the ControlField defaults apply.

    def invoke(self, context: InvocationContext) -> ControlOutput:
        """Package this node's inputs as a ControlField of type "IP-Adapter".

        The selected model and image-encoder paths are joined onto the configured
        root directory and stored as POSIX-style strings.

        Args:
            context: The invocation context; used to read the configured root directory.

        Returns:
            A ControlOutput wrapping the assembled ControlField.
        """
        # Look up the root directory once; both model paths are resolved against it.
        root_dir = context.services.configuration.get_config().root_dir
        return ControlOutput(
            control=ControlField(
                control_type="IP-Adapter",
                image=self.image,
                # control_model is left unset: an "IP-Adapter" ControlField only
                # requires ip_adapter_model.
                ip_adapter_model=(root_dir / self.ip_adapter_model).as_posix(),
                image_encoder_model=(root_dir / self.image_encoder_model).as_posix(),
                control_weight=self.control_weight,
            ),
        )