Revert ControlNetInvocation changes.

Author: Ryan Dick
Date:   2023-09-06 19:30:30 -04:00
parent 46c9dcb113
commit cdbf40c9b2

5 changed files with 118 additions and 115 deletions

invokeai/app/invocations/control_adapter.py (deleted)

@@ -1,104 +0,0 @@
-from builtins import float
-from typing import List, Literal, Optional, Union
-
-from pydantic import BaseModel, Field, root_validator, validator
-
-from invokeai.app.invocations.primitives import ImageField
-
-from ...backend.model_management import BaseModelType
-from .baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    FieldDescriptions,
-    Input,
-    InputField,
-    InvocationContext,
-    OutputField,
-    UIType,
-    invocation,
-    invocation_output,
-)
-
-CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
-CONTROLNET_RESIZE_VALUES = Literal[
-    "just_resize",
-    "crop_resize",
-    "fill_resize",
-    "just_resize_simple",
-]
-
-
-class ControlNetModelField(BaseModel):
-    """ControlNet model field"""
-
-    model_name: str = Field(description="Name of the ControlNet model")
-    base_model: BaseModelType = Field(description="Base model")
-
-
-class ControlField(BaseModel):
-    image: ImageField = Field(description="The control image")
-    control_model: ControlNetModelField = Field(description="The ControlNet model to use")
-    control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
-    begin_step_percent: float = Field(
-        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
-    )
-    end_step_percent: float = Field(
-        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
-    )
-    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
-    resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
-
-    @validator("control_weight")
-    def validate_control_weight(cls, v):
-        """Validate that all control weights in the valid range"""
-        if isinstance(v, list):
-            for i in v:
-                if i < -1 or i > 2:
-                    raise ValueError("Control weights must be within -1 to 2 range")
-        else:
-            if v < -1 or v > 2:
-                raise ValueError("Control weights must be within -1 to 2 range")
-        return v
-
-
-@invocation_output("control_output")
-class ControlOutput(BaseInvocationOutput):
-    """node output for ControlNet info"""
-
-    # Outputs
-    control: ControlField = OutputField(description=FieldDescriptions.control)
-
-
-@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.0.0")
-class ControlNetInvocation(BaseInvocation):
-    """Collects ControlNet info to pass to other nodes"""
-
-    # Inputs
-    image: ImageField = InputField(description="The control image")
-    control_model: ControlNetModelField = InputField(
-        default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
-    )
-    control_weight: Union[float, List[float]] = InputField(
-        default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
-    )
-    begin_step_percent: float = InputField(
-        default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
-    )
-    end_step_percent: float = InputField(
-        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
-    )
-    control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
-    resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
-
-    def invoke(self, context: InvocationContext) -> ControlOutput:
-        return ControlOutput(
-            control=ControlField(
-                image=self.image,
-                control_model=self.control_model,
-                control_weight=self.control_weight,
-                begin_step_percent=self.begin_step_percent,
-                end_step_percent=self.end_step_percent,
-                control_mode=self.control_mode,
-                resize_mode=self.resize_mode,
-            ),
-        )
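An aside on the validator deleted above (and re-added to controlnet_image_processors.py below): control_weight may be either a single float or a per-step list, and both forms are range-checked against [-1, 2]. A minimal, self-contained sketch of the same check, assuming pydantic v1's @validator API as used in this file; the WeightedControl model name is invented for illustration and is not part of this diff:

from typing import List, Union

from pydantic import BaseModel, validator


class WeightedControl(BaseModel):  # hypothetical stand-in for ControlField
    # Accepts a single weight or a per-step schedule of weights.
    control_weight: Union[float, List[float]] = 1.0

    @validator("control_weight")
    def validate_control_weight(cls, v):
        # Normalize to a list so scalars and schedules share one check.
        weights = v if isinstance(v, list) else [v]
        if any(w < -1 or w > 2 for w in weights):
            raise ValueError("Control weights must be within -1 to 2 range")
        return v


WeightedControl(control_weight=[0.5, 1.5])  # passes validation
# WeightedControl(control_weight=3.0)      # would raise a ValidationError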

invokeai/app/invocations/controlnet_image_processors.py

@@ -1,7 +1,8 @@
 # Invocations for ControlNet image preprocessors
+# initial implementation by Gregg Helt, 2023
 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
 from builtins import bool, float
-from typing import Dict, List, Optional
+from typing import Dict, List, Literal, Optional, Union

 import cv2
 import numpy as np
@@ -23,11 +24,105 @@ from controlnet_aux import (
 )
 from controlnet_aux.util import HWC3, ade_palette
 from PIL import Image
+from pydantic import BaseModel, Field, validator

 from invokeai.app.invocations.primitives import ImageField, ImageOutput
+from ...backend.model_management import BaseModelType

 from ..models.image import ImageCategory, ResourceOrigin
-from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
+from .baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    FieldDescriptions,
+    Input,
+    InputField,
+    InvocationContext,
+    OutputField,
+    UIType,
+    invocation,
+    invocation_output,
+)
+
+CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
+CONTROLNET_RESIZE_VALUES = Literal[
+    "just_resize",
+    "crop_resize",
+    "fill_resize",
+    "just_resize_simple",
+]
+
+
+class ControlNetModelField(BaseModel):
+    """ControlNet model field"""
+
+    model_name: str = Field(description="Name of the ControlNet model")
+    base_model: BaseModelType = Field(description="Base model")
+
+
+class ControlField(BaseModel):
+    image: ImageField = Field(description="The control image")
+    control_model: ControlNetModelField = Field(description="The ControlNet model to use")
+    control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
+    begin_step_percent: float = Field(
+        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
+    )
+    end_step_percent: float = Field(
+        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
+    )
+    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
+    resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
+
+    @validator("control_weight")
+    def validate_control_weight(cls, v):
+        """Validate that all control weights in the valid range"""
+        if isinstance(v, list):
+            for i in v:
+                if i < -1 or i > 2:
+                    raise ValueError("Control weights must be within -1 to 2 range")
+        else:
+            if v < -1 or v > 2:
+                raise ValueError("Control weights must be within -1 to 2 range")
+        return v
+
+
+@invocation_output("control_output")
+class ControlOutput(BaseInvocationOutput):
+    """node output for ControlNet info"""
+
+    # Outputs
+    control: ControlField = OutputField(description=FieldDescriptions.control)
+
+
+@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.0.0")
+class ControlNetInvocation(BaseInvocation):
+    """Collects ControlNet info to pass to other nodes"""
+
+    image: ImageField = InputField(description="The control image")
+    control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
+    control_weight: Union[float, List[float]] = InputField(
+        default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
+    )
+    begin_step_percent: float = InputField(
+        default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
+    )
+    end_step_percent: float = InputField(
+        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
+    )
+    control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
+    resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
+
+    def invoke(self, context: InvocationContext) -> ControlOutput:
+        return ControlOutput(
+            control=ControlField(
+                image=self.image,
+                control_model=self.control_model,
+                control_weight=self.control_weight,
+                begin_step_percent=self.begin_step_percent,
+                end_step_percent=self.end_step_percent,
+                control_mode=self.control_mode,
+                resize_mode=self.resize_mode,
+            ),
+        )
+

 @invocation(
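One pattern worth noting in the restored module: the mode and resize enumerations are typing.Literal aliases rather than Enum subclasses, so pydantic rejects unknown strings at model construction with no extra code. A standalone sketch of that behavior, assuming pydantic v1; ResizeSettings is an invented name, not part of this diff:

from typing import Literal

from pydantic import BaseModel, ValidationError

RESIZE_VALUES = Literal["just_resize", "crop_resize", "fill_resize", "just_resize_simple"]


class ResizeSettings(BaseModel):
    # Only the four literal strings above are accepted.
    resize_mode: RESIZE_VALUES = "just_resize"


print(ResizeSettings(resize_mode="crop_resize").resize_mode)  # crop_resize

try:
    ResizeSettings(resize_mode="stretch")  # not one of the permitted literals
except ValidationError as err:
    print("rejected:", err.errors()[0]["loc"])  # ('resize_mode',)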

invokeai/app/invocations/latent.py

@@ -64,7 +64,7 @@ from .baseinvocation import (
     invocation_output,
 )
 from .compel import ConditioningField
-from .control_adapter import ControlField
+from .controlnet_image_processors import ControlField
 from .model import ModelInfo, UNetField, VaeField

 DEFAULT_PRECISION = choose_precision(choose_torch_device())

invokeai/app/invocations/metadata.py

@@ -11,7 +11,7 @@ from invokeai.app.invocations.baseinvocation import (
     invocation,
     invocation_output,
 )
-from invokeai.app.invocations.control_adapter import ControlField
+from invokeai.app.invocations.controlnet_image_processors import ControlField
 from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
 from invokeai.app.util.model_exclude_null import BaseModelExcludeNull

invokeai/app/invocations/onnx.py

@@ -13,7 +13,12 @@ from pydantic import BaseModel, Field, validator
 from tqdm import tqdm

 from invokeai.app.invocations.metadata import CoreMetadata
-from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
+from invokeai.app.invocations.primitives import (
+    ConditioningField,
+    ConditioningOutput,
+    ImageField,
+    ImageOutput,
+)
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend import BaseModelType, ModelType, SubModelType
@@ -25,8 +30,8 @@ from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
     FieldDescriptions,
-    InputField,
     Input,
+    InputField,
     InvocationContext,
     OutputField,
     UIComponent,
@@ -34,8 +39,14 @@ from .baseinvocation import (
     invocation,
     invocation_output,
 )
-from .control_adapter import ControlField
+from .controlnet_image_processors import ControlField
-from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
+from .latent import (
+    SAMPLER_NAME_VALUES,
+    LatentsField,
+    LatentsOutput,
+    build_latents_output,
+    get_scheduler,
+)
 from .model import ClipField, ModelInfo, UNetField, VaeField

 ORT_TO_NP_TYPE = {
@@ -95,9 +106,10 @@ class ONNXPromptInvocation(BaseInvocation):
                 print(f'Warn: trigger: "{trigger}" not found')

         if loras or ti_list:
             text_encoder.release_session()
-        with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras), ONNXModelPatcher.apply_ti(
-            orig_tokenizer, text_encoder, ti_list
-        ) as (tokenizer, ti_manager):
+        with (
+            ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),
+            ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager),
+        ):
             text_encoder.create_session()

             # copy from
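The final hunk also rewrites a two-manager with statement from the implicit line-continuation style (continued through the open call parenthesis) to the parenthesized multi-manager form, which was formally specified in Python 3.10 and is also accepted by CPython 3.9's PEG parser. A small self-contained sketch of the two equivalent spellings; the patch context manager here is a hypothetical stand-in for the ONNXModelPatcher.apply_* calls, not the real API:

from contextlib import contextmanager


@contextmanager
def patch(name):
    # Hypothetical stand-in; prints on enter/exit and yields a token.
    print(f"enter {name}")
    try:
        yield name
    finally:
        print(f"exit {name}")


# Older style: the statement continues through the first manager's parentheses.
with patch("lora"), patch("ti") as token:
    print("patched with", token)

# Parenthesized style used in the hunk above: one manager per line,
# trailing comma allowed, `as` bindings attach to individual managers.
with (
    patch("lora"),
    patch("ti") as token,
):
    print("patched with", token)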