From f2f4c37f19498b36ef1f3d656b3da172a9edd4d6 Mon Sep 17 00:00:00 2001 From: user1 Date: Thu, 4 May 2023 16:01:22 -0700 Subject: [PATCH] Refactored ControlNet nodes so they subclass from PreprocessedControlInvocation, and only need to override run_processor(image) (instead of reimplementing invoke()) --- .../controlnet_image_processors.py | 94 ++++++++++++------- 1 file changed, 58 insertions(+), 36 deletions(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index c72d064f11..5e8ad06bc2 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -30,7 +30,7 @@ class ControlField(BaseModel): class ControlOutput(BaseInvocationOutput): - """Base class for invocations that output ControlNet info""" + """node output for ControlNet info""" # fmt: off type: Literal["control_output"] = "control_output" @@ -38,52 +38,74 @@ class ControlOutput(BaseInvocationOutput): # image: ImageField = Field(default=None, description="outputs just them image info (which is also included in control output)") # fmt: on +class PreprocessedControlInvocation(BaseInvocation, PILInvocationConfig): + """Base class for invocations that preprocess images for ControlNet""" -class CannyControlInvocation(BaseInvocation, PILInvocationConfig): + # fmt: off + type: Literal["preprocessed_control"] = "preprocessed_control" + + # Inputs + image: ImageField = Field(default=None, description="image to process") + control_model: str = Field(default=None, description="control model to use") + control_weight: float = Field(default=0.5, ge=0, le=1, description="control weight") + # begin_step_percent: float = Field(default=0, ge=0, le=1, + # description="% of total steps at which controlnet is first applied") + # end_step_percent: float = Field(default=1, ge=0, le=1, + # description="% of total steps at which controlnet is last applied") + # guess_mode: bool = Field(default=False, description="use guess mode (controlnet ignores prompt)") + # fmt: on + + # This super class handles invoke() call, which in turn calls run_processor(image) + # subclasses override run_processor instead of implementing their own invoke() + def run_processor(self, image): + # super class pass through of image + return image + + def invoke(self, context: InvocationContext) -> ControlOutput: + image = context.services.images.get( + self.image.image_type, self.image.image_name + ) + # image type should be PIL.PngImagePlugin.PngImageFile ? + processed_image = self.run_processor(image) + image_type = ImageType.INTERMEDIATE + image_name = context.services.images.create_name( + context.graph_execution_state_id, self.id + ) + metadata = context.services.metadata.build_metadata( + session_id=context.graph_execution_state_id, node=self + ) + context.services.images.save(image_type, image_name, processed_image, metadata) + + """Builds an ImageOutput and its ImageField""" + image_field = ImageField( + image_name=image_name, + image_type=image_type, + ) + return ControlOutput( + control=ControlField( + image=image_field, + control_model=self.control_model, + control_weight=self.control_weight, + ) + ) + + +class CannyControlInvocation(PreprocessedControlInvocation, PILInvocationConfig): """Canny edge detection for ControlNet""" # fmt: off type: Literal["cannycontrol"] = "cannycontrol" - # Inputs - image: ImageField = Field(default=None, description="image to process") - control_model: str = Field(default=None, description="control model to use") - control_weight: float = Field(default=0.5, ge=0, le=1, description="control weight") - # begin_step_percent: float = Field(default=0, ge=0, le=1, - # description="% of total steps at which controlnet is first applied") - # end_step_percent: float = Field(default=1, ge=0, le=1, - # description="% of total steps at which controlnet is last applied") - # guess_mode: bool = Field(default=False, description="use guess mode (controlnet ignores prompt)") - low_threshold: float = Field(default=100, ge=0, description="low threshold of Canny pixel gradient") high_threshold: float = Field(default=200, ge=0, description="high threshold of Canny pixel gradient") # fmt: on - def invoke(self, context: InvocationContext) -> ControlOutput: - image = context.services.images.get( - self.image.image_type, self.image.image_name - ) + def run_processor(self, image): + print("**** running Canny processor ****") + print("image type: ", type(image)) canny_processor = CannyDetector() processed_image = canny_processor(image, self.low_threshold, self.high_threshold) - image_type = ImageType.INTERMEDIATE - image_name = context.services.images.create_name( - context.graph_execution_state_id, self.id - ) - metadata = context.services.metadata.build_metadata( - session_id=context.graph_execution_state_id, node=self - ) - context.services.images.save(image_type, image_name, processed_image, metadata) + print("processed image type: ", type(image)) + return processed_image - """Builds an ImageOutput and its ImageField""" - image_field = ImageField( - image_name=image_name, - image_type=image_type, - ) - return ControlOutput( - control=ControlField( - image=image_field, - control_model=self.control_model, - control_weight=self.control_weight, - ) - )