# Invocations for ControlNet image preprocessors
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import bool, float
from typing import Dict, List, Literal, Optional, Union

import cv2
import numpy as np
from controlnet_aux import (
    CannyDetector,
    ContentShuffleDetector,
    HEDdetector,
    LeresDetector,
    LineartAnimeDetector,
    LineartDetector,
    MediapipeFaceDetector,
    MidasDetector,
    MLSDdetector,
    NormalBaeDetector,
    OpenposeDetector,
    PidiNetDetector,
    SamDetector,
    ZoeDetector,
)
from controlnet_aux.util import HWC3, ade_palette
from PIL import Image

from invokeai.app.invocations.primitives import ImageField, ImageOutput

from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    FieldDescriptions,
    Input,
    InputField,
    InvocationContext,
    OutputField,
    UIType,
    invocation,
)


@invocation(
    "image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet", version="1.0.0"
)
class ImageProcessorInvocation(BaseInvocation):
    """Base class for invocations that preprocess images for ControlNet.

    Subclasses override `run_processor()`; `invoke()` handles loading the input
    image, storing the processed result and building the `ImageOutput`.
    """

    image: ImageField = InputField(description="The image to process")

    def run_processor(self, image: Image.Image) -> Image.Image:
        """Return the processed version of `image`.

        The base implementation passes the image through without processing.
        """
        return image

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Load the input image, run the processor on it and save the result.

        Returns an `ImageOutput` that refers to the newly stored image.
        """
        raw_image = context.services.images.get_pil_image(self.image.image_name)
        # image type should be PIL.PngImagePlugin.PngImageFile ?
        processed_image = self.run_processor(raw_image)

        # The result is stored as an INTERNAL-origin, CONTROL-category image.
        # Whether it counts as an intermediate image is taken from this node's
        # own `is_intermediate` value.
        image_dto = context.services.images.create(
            image=processed_image,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.CONTROL,
            session_id=context.graph_execution_state_id,
            node_id=self.id,
            is_intermediate=self.is_intermediate,
            workflow=self.workflow,
        )

        # Build an ImageOutput and its ImageField. Width and height are read from
        # the stored image's DTO rather than from `processed_image` itself.
        processed_image_field = ImageField(image_name=image_dto.image_name)
        return ImageOutput(
            image=processed_image_field,
            width=image_dto.width,
            height=image_dto.height,
        )


@invocation(
    "canny_image_processor",
    title="Canny Processor",
    tags=["controlnet", "canny"],
    category="controlnet",
    version="1.0.0",
)
class CannyImageProcessorInvocation(ImageProcessorInvocation):
    """Canny edge detection for ControlNet"""

    low_threshold: int = InputField(
        default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
    )
    high_threshold: int = InputField(
        default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)"
    )

    def run_processor(self, image):
        """Run `CannyDetector` on `image` with the configured thresholds."""
        canny_processor = CannyDetector()
        processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
        return processed_image


@invocation(
    "hed_image_processor",
    title="HED (softedge) Processor",
    tags=["controlnet", "hed", "softedge"],
    category="controlnet",
    version="1.0.0",
)
class HedImageProcessorInvocation(ImageProcessorInvocation):
    """Applies HED edge detection to image"""

    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
    # safe not supported in controlnet_aux v0.0.3
    # safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
    scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)

    def run_processor(self, image):
        """Run `HEDdetector` (loaded from "lllyasviel/Annotators") on `image`."""
        hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
        processed_image = hed_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            # safe not supported in controlnet_aux v0.0.3
            # safe=self.safe,
            scribble=self.scribble,
        )
        return processed_image


@invocation(
    "lineart_image_processor",
    title="Lineart Processor",
    tags=["controlnet", "lineart"],
    category="controlnet",
    version="1.0.0",
)
class LineartImageProcessorInvocation(ImageProcessorInvocation):
    """Applies line art processing to image"""

    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
    coarse: bool = InputField(default=False, description="Whether to use coarse mode")

    def run_processor(self, image):
        """Run `LineartDetector` (loaded from "lllyasviel/Annotators") on `image`."""
        lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = lineart_processor(
            image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
        )
        return processed_image


@invocation(
    "lineart_anime_image_processor",
    title="Lineart Anime Processor",
    tags=["controlnet", "lineart", "anime"],
    category="controlnet",
    version="1.0.0",
)
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
    """Applies line art anime processing to image"""

    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

    def run_processor(self, image):
        """Run `LineartAnimeDetector` (loaded from "lllyasviel/Annotators") on `image`."""
        processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
        )
        return processed_image


@invocation(
    "openpose_image_processor",
    title="Openpose Processor",
    tags=["controlnet", "openpose", "pose"],
    category="controlnet",
    version="1.0.0",
)
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
    """Applies Openpose processing to image"""

    hand_and_face: bool = InputField(default=False, description="Whether to use hands and face mode")
    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

    def run_processor(self, image):
        """Run `OpenposeDetector` (loaded from "lllyasviel/Annotators") on `image`."""
        openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = openpose_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            hand_and_face=self.hand_and_face,
        )
        return processed_image


@invocation(
    "midas_depth_image_processor",
    title="Midas Depth Processor",
    tags=["controlnet", "midas"],
    category="controlnet",
    version="1.0.0",
)
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
    """Applies Midas depth processing to image"""

    a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
    bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
    # depth_and_normal not supported in controlnet_aux v0.0.3
    # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")

    def run_processor(self, image):
        """Run `MidasDetector` (loaded from "lllyasviel/Annotators") on `image`.

        The detector's `a` argument is computed as `pi * a_mult`.
        """
        midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = midas_processor(
            image,
            a=np.pi * self.a_mult,
            bg_th=self.bg_th,
            # depth_and_normal not supported in controlnet_aux v0.0.3
            # depth_and_normal=self.depth_and_normal,
        )
        return processed_image


@invocation(
    "normalbae_image_processor",
    title="Normal BAE Processor",
    tags=["controlnet"],
    category="controlnet",
    version="1.0.0",
)
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
    """Applies NormalBae processing to image"""

    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

    def run_processor(self, image):
        """Run `NormalBaeDetector` (loaded from "lllyasviel/Annotators") on `image`."""
        normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = normalbae_processor(
            image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
        )
        return processed_image


@invocation(
    "mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.0.0"
)
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
    """Applies MLSD processing to image"""

    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
    thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
    thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")

    def run_processor(self, image):
        """Run `MLSDdetector` (loaded from "lllyasviel/Annotators") on `image`."""
        mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
        processed_image = mlsd_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            thr_v=self.thr_v,
            thr_d=self.thr_d,
        )
        return processed_image


@invocation(
    "pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.0.0"
)
class PidiImageProcessorInvocation(ImageProcessorInvocation):
    """Applies PIDI processing to image"""

    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
    safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
    scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)

    def run_processor(self, image):
        """Run `PidiNetDetector` (loaded from "lllyasviel/Annotators") on `image`."""
        pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = pidi_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            safe=self.safe,
            scribble=self.scribble,
        )
        return processed_image


@invocation(
    "content_shuffle_image_processor",
    title="Content Shuffle Processor",
    tags=["controlnet", "contentshuffle"],
    category="controlnet",
    version="1.0.0",
)
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
    """Applies content shuffle processing to image"""

    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
    h: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
    w: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
    f: Optional[int] = InputField(default=256, ge=0, description="Content shuffle `f` parameter")

    def run_processor(self, image):
        """Run `ContentShuffleDetector` on `image` with the configured parameters."""
        content_shuffle_processor = ContentShuffleDetector()
        processed_image = content_shuffle_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            h=self.h,
            w=self.w,
            f=self.f,
        )
        return processed_image


# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
@invocation(
    "zoe_depth_image_processor",
    title="Zoe (Depth) Processor",
    tags=["controlnet", "zoe", "depth"],
    category="controlnet",
    version="1.0.0",
)
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
    """Applies Zoe depth processing to image"""

    def run_processor(self, image):
        """Run `ZoeDetector` (loaded from "lllyasviel/Annotators") on `image`."""
        zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = zoe_depth_processor(image)
        return processed_image


@invocation(
    "mediapipe_face_processor",
    title="Mediapipe Face Processor",
    tags=["controlnet", "mediapipe", "face"],
    category="controlnet",
    version="1.0.0",
)
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
    """Applies mediapipe face processing to image"""

    max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
    min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")

    def run_processor(self, image):
        """Run `MediapipeFaceDetector` on `image`, converting RGBA input to RGB first."""
        # MediaPipeFaceDetector throws an error if image has alpha channel
        #     so convert to RGB if needed
        if image.mode == "RGBA":
            image = image.convert("RGB")
        mediapipe_face_processor = MediapipeFaceDetector()
        processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
        return processed_image


@invocation(
    "leres_image_processor",
    title="Leres (Depth) Processor",
    tags=["controlnet", "leres", "depth"],
    category="controlnet",
    version="1.0.0",
)
class LeresImageProcessorInvocation(ImageProcessorInvocation):
    """Applies leres processing to image"""

    thr_a: float = InputField(default=0, description="Leres parameter `thr_a`")
    thr_b: float = InputField(default=0, description="Leres parameter `thr_b`")
    boost: bool = InputField(default=False, description="Whether to use boost mode")
    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

    def run_processor(self, image):
        """Run `LeresDetector` (loaded from "lllyasviel/Annotators") on `image`."""
        leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = leres_processor(
            image,
            thr_a=self.thr_a,
            thr_b=self.thr_b,
            boost=self.boost,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
        )
        return processed_image


@invocation(
    "tile_image_processor",
    title="Tile Resample Processor",
    tags=["controlnet", "tile"],
    category="controlnet",
    version="1.0.0",
)
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
    """Tile resampler processor"""

    # res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
    down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")

    # tile_resample copied from sd-webui-controlnet/scripts/processor.py
    def tile_resample(
        self,
        np_img: np.ndarray,
        res=512,  # never used?
        down_sampling_rate=1.0,
    ) -> np.ndarray:
        """Downscale `np_img` by `down_sampling_rate` using area interpolation.

        Args:
            np_img: Image array; it is first normalized with `HWC3`.
            res: Unused; kept so the signature stays compatible with callers.
            down_sampling_rate: Factor both image dimensions are divided by.
                Rates below 1.1 leave the image size unchanged.

        Returns:
            The `HWC3`-normalized array, resized when the rate is at least 1.1.
        """
        np_img = HWC3(np_img)
        if down_sampling_rate < 1.1:
            return np_img
        height, width = np_img.shape[:2]
        # Clamp to at least one pixel: an image smaller than the down-sampling
        # rate would otherwise yield a zero dimension, which cv2.resize rejects.
        height = max(1, int(float(height) / float(down_sampling_rate)))
        width = max(1, int(float(width) / float(down_sampling_rate)))
        np_img = cv2.resize(np_img, (width, height), interpolation=cv2.INTER_AREA)
        return np_img

    def run_processor(self, img):
        """Convert `img` to a uint8 array, resample it and return a PIL image."""
        np_img = np.array(img, dtype=np.uint8)
        processed_np_image = self.tile_resample(
            np_img,
            # res=self.tile_size,
            down_sampling_rate=self.down_sampling_rate,
        )
        processed_image = Image.fromarray(processed_np_image)
        return processed_image


@invocation(
    "segment_anything_processor",
    title="Segment Anything Processor",
    tags=["controlnet", "segmentanything"],
    category="controlnet",
    version="1.0.0",
)
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
    """Applies segment anything processing to image"""

    def run_processor(self, image):
        """Run `SamDetectorReproducibleColors` on `image` (passed as a uint8 array)."""
        # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
        segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
            "ybelkada/segment-anything", subfolder="checkpoints"
        )
        np_img = np.array(image, dtype=np.uint8)
        processed_image = segment_anything_processor(np_img)
        return processed_image


class SamDetectorReproducibleColors(SamDetector):
    """`SamDetector` variant that colors segments deterministically."""

    # overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
    #     base class show_anns() method randomizes colors,
    #     which seems to also lead to non-reproducible image generation
    # so using ADE20k color palette instead
    def show_anns(self, anns: List[Dict]) -> Optional[np.ndarray]:
        """Render the annotations as one RGB image using the ADE20k palette.

        Annotations are painted from largest to smallest `area`, so smaller
        regions are drawn on top of larger ones.

        Args:
            anns: Annotation dicts; each is read for its `"area"` and
                `"segmentation"` entries.

        Returns:
            A uint8 array of shape `(h, w, 3)`, or `None` when `anns` is empty.
        """
        if len(anns) == 0:
            return None
        sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
        h, w = anns[0]["segmentation"].shape
        final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
        palette = ade_palette()
        for i, ann in enumerate(sorted_anns):
            m = ann["segmentation"]
            img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
            # doing modulo just in case number of annotated regions exceeds number of colors in palette
            ann_color = palette[i % len(palette)]
            img[:, :] = ann_color
            # The segmentation mask, scaled to 0/255, is the paste mask, so only
            # this annotation's pixels are overwritten.
            final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
        return np.array(final_img, dtype=np.uint8)