diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py
index 5e8ad06bc2..9a841af7c8 100644
--- a/invokeai/app/invocations/controlnet_image_processors.py
+++ b/invokeai/app/invocations/controlnet_image_processors.py
@@ -9,12 +9,24 @@ from .baseinvocation import (
     InvocationConfig,
 )
 
-from controlnet_aux import CannyDetector
+from controlnet_aux import (
+    CannyDetector,
+    HEDdetector,
+    LineartDetector,
+    LineartAnimeDetector,
+    MidasDetector,
+    MLSDdetector,
+    NormalBaeDetector,
+    OpenposeDetector,
+    PidiNetDetector,
+    ContentShuffleDetector,
+    # StyleShuffleDetector,
+    ZoeDetector)
+
 from .image import ImageOutput, build_image_output, PILInvocationConfig
 
 
 class ControlField(BaseModel):
-    image: ImageField = Field(default=None, description="processed image")
     # width: Optional[int] = Field(default=None, description="The width of the image in pixels")
     # height: Optional[int] = Field(default=None, description="The height of the image in pixels")
 
@@ -38,66 +50,69 @@ class ControlOutput(BaseInvocationOutput):
     # image: ImageField = Field(default=None, description="outputs just them image info (which is also included in control output)")
     # fmt: on
 
+
 class PreprocessedControlInvocation(BaseInvocation, PILInvocationConfig):
-    """Base class for invocations that preprocess images for ControlNet"""
+    """Base class for invocations that preprocess images for ControlNet"""
 
-    # fmt: off
-    type: Literal["preprocessed_control"] = "preprocessed_control"
+    # fmt: off
+    type: Literal["preprocessed_control"] = "preprocessed_control"
 
-    # Inputs
-    image: ImageField = Field(default=None, description="image to process")
-    control_model: str = Field(default=None, description="control model to use")
-    control_weight: float = Field(default=0.5, ge=0, le=1, description="control weight")
-    # begin_step_percent: float = Field(default=0, ge=0, le=1,
-    #                                   description="% of total steps at which controlnet is first applied")
-    # end_step_percent: float = Field(default=1, ge=0, le=1,
-    #                                 description="% of total steps at which controlnet is last applied")
-    # guess_mode: bool = Field(default=False, description="use guess mode (controlnet ignores prompt)")
-    # fmt: on
+    # Inputs
+    image: ImageField = Field(default=None, description="image to process")
+    control_model: str = Field(default=None, description="control model to use")
+    control_weight: float = Field(default=0.5, ge=0, le=1, description="control weight")
 
-    # This super class handles invoke() call, which in turn calls run_processor(image)
-    # subclasses override run_processor instead of implementing their own invoke()
-    def run_processor(self, image):
-        # super class pass through of image
-        return image
+    # begin_step_percent: float = Field(default=0, ge=0, le=1,
+    #                                   description="% of total steps at which controlnet is first applied")
+    # end_step_percent: float = Field(default=1, ge=0, le=1,
+    #                                 description="% of total steps at which controlnet is last applied")
+    # guess_mode: bool = Field(default=False, description="use guess mode (controlnet ignores prompt)")
+    # fmt: on
 
-    def invoke(self, context: InvocationContext) -> ControlOutput:
-        image = context.services.images.get(
-            self.image.image_type, self.image.image_name
-        )
-        # image type should be PIL.PngImagePlugin.PngImageFile ?
-        processed_image = self.run_processor(image)
-        image_type = ImageType.INTERMEDIATE
-        image_name = context.services.images.create_name(
-            context.graph_execution_state_id, self.id
-        )
-        metadata = context.services.metadata.build_metadata(
-            session_id=context.graph_execution_state_id, node=self
-        )
-        context.services.images.save(image_type, image_name, processed_image, metadata)
+    # This super class handles invoke() call, which in turn calls run_processor(image)
+    # subclasses override run_processor instead of implementing their own invoke()
+    def run_processor(self, image):
+        # superclass just passes through image without processing
+        return image
 
-        """Builds an ImageOutput and its ImageField"""
-        image_field = ImageField(
-            image_name=image_name,
-            image_type=image_type,
-        )
-        return ControlOutput(
-            control=ControlField(
-                image=image_field,
-                control_model=self.control_model,
-                control_weight=self.control_weight,
-            )
-        )
+    def invoke(self, context: InvocationContext) -> ControlOutput:
+        image = context.services.images.get(
+            self.image.image_type, self.image.image_name
+        )
+        # image type should be PIL.PngImagePlugin.PngImageFile ?
+        processed_image = self.run_processor(image)
+        image_type = ImageType.INTERMEDIATE
+        image_name = context.services.images.create_name(
+            context.graph_execution_state_id, self.id
+        )
+        metadata = context.services.metadata.build_metadata(
+            session_id=context.graph_execution_state_id, node=self
+        )
+        context.services.images.save(image_type, image_name, processed_image, metadata)
+
+        """Builds an ImageOutput and its ImageField"""
+        image_field = ImageField(
+            image_name=image_name,
+            image_type=image_type,
+        )
+        return ControlOutput(
+            control=ControlField(
+                image=image_field,
+                control_model=self.control_model,
+                control_weight=self.control_weight,
+            )
+        )
 
 
 class CannyControlInvocation(PreprocessedControlInvocation, PILInvocationConfig):
     """Canny edge detection for ControlNet"""
     # fmt: off
-    type: Literal["cannycontrol"] = "cannycontrol"
+    type: Literal["canny_control"] = "canny_control"
     # Inputs
     low_threshold: float = Field(default=100, ge=0, description="low threshold of Canny pixel gradient")
     high_threshold: float = Field(default=200, ge=0, description="high threshold of Canny pixel gradient")
+    # fmt: on
 
     def run_processor(self, image):
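Note on the pattern this patch establishes: PreprocessedControlInvocation owns invoke() (image loading, saving, and ControlOutput construction), so each new processor only overrides run_processor(), exactly as CannyControlInvocation does. Below is a minimal sketch of wiring one of the newly imported detectors into that pattern; the HedControlInvocation class name, its "hed_control" type literal, and the "lllyasviel/Annotators" checkpoint id are assumptions for illustration and are not part of this diff.

from typing import Literal

from controlnet_aux import HEDdetector


class HedControlInvocation(PreprocessedControlInvocation, PILInvocationConfig):
    """HED soft-edge detection for ControlNet (illustrative sketch, not from this patch)"""
    # fmt: off
    type: Literal["hed_control"] = "hed_control"
    # fmt: on

    def run_processor(self, image):
        # from_pretrained downloads the HED model weights on first use; the
        # checkpoint id is an assumption -- verify against the installed
        # controlnet_aux release
        hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
        # calling the detector on a PIL image returns the processed PIL image,
        # which the base class's invoke() then saves and wraps in a ControlField
        return hed_processor(image)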