# Invocations for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import bool, float
from typing import Dict, List, Literal, Optional, Union

import cv2
import numpy as np
from controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector,
                            LeresDetector, LineartAnimeDetector,
                            LineartDetector, MediapipeFaceDetector,
                            MidasDetector, MLSDdetector, NormalBaeDetector,
                            OpenposeDetector, PidiNetDetector, SamDetector,
                            ZoeDetector)
from controlnet_aux.util import HWC3, ade_palette
from PIL import Image
from pydantic import BaseModel, Field, validator

from ...backend.model_management import BaseModelType, ModelType
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
                             InvocationConfig, InvocationContext)
from ..models.image import ImageOutput, PILInvocationConfig

# Names of the ControlNet models offered by default. An entry of the form
# "model,subfolder" carries a subfolder name after the comma (see the
# CrucibleAI entries at the end of the list).
CONTROLNET_DEFAULT_MODELS = [
    ###########################################
    # lllyasviel sd v1.5, ControlNet v1.0 models
    ##############################################
    "lllyasviel/sd-controlnet-canny",
    "lllyasviel/sd-controlnet-depth",
    "lllyasviel/sd-controlnet-hed",
    "lllyasviel/sd-controlnet-seg",
    "lllyasviel/sd-controlnet-openpose",
    "lllyasviel/sd-controlnet-scribble",
    "lllyasviel/sd-controlnet-normal",
    "lllyasviel/sd-controlnet-mlsd",

    #############################################
    # lllyasviel sd v1.5, ControlNet v1.1 models
    #############################################
    "lllyasviel/control_v11p_sd15_canny",
    "lllyasviel/control_v11p_sd15_openpose",
    "lllyasviel/control_v11p_sd15_seg",
    # "lllyasviel/control_v11p_sd15_depth",  # broken
    "lllyasviel/control_v11f1p_sd15_depth",
    "lllyasviel/control_v11p_sd15_normalbae",
    "lllyasviel/control_v11p_sd15_scribble",
    "lllyasviel/control_v11p_sd15_mlsd",
    "lllyasviel/control_v11p_sd15_softedge",
    "lllyasviel/control_v11p_sd15s2_lineart_anime",
    "lllyasviel/control_v11p_sd15_lineart",
    "lllyasviel/control_v11p_sd15_inpaint",
    # "lllyasviel/control_v11u_sd15_tile",
    # There is a (temporary?) problem with "lllyasviel/control_v11u_sd15_tile"
    # on Hugging Face, so for now it is replaced by
    # "lllyasviel/control_v11f1e_sd15_tile" (listed below).
    "lllyasviel/control_v11e_sd15_shuffle",
    "lllyasviel/control_v11e_sd15_ip2p",
    "lllyasviel/control_v11f1e_sd15_tile",

    #################################################
    # thibaud sd v2.1 models (ControlNet v1.0? or v1.1?)
    ##################################################
    "thibaud/controlnet-sd21-openpose-diffusers",
    "thibaud/controlnet-sd21-canny-diffusers",
    "thibaud/controlnet-sd21-depth-diffusers",
    "thibaud/controlnet-sd21-scribble-diffusers",
    "thibaud/controlnet-sd21-hed-diffusers",
    "thibaud/controlnet-sd21-zoedepth-diffusers",
    "thibaud/controlnet-sd21-color-diffusers",
    "thibaud/controlnet-sd21-openposev2-diffusers",
    "thibaud/controlnet-sd21-lineart-diffusers",
    "thibaud/controlnet-sd21-normalbae-diffusers",
    "thibaud/controlnet-sd21-ade20k-diffusers",

    ##############################################
    # ControlNetMediaPipeface, ControlNet v1.1
    ##############################################
    # ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"],  # SD 1.5
    #    diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg
    #    hacked t2l to split to model & subfolder if format is "model,subfolder"
    "CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15",  # SD 1.5
    "CrucibleAI/ControlNetMediaPipeFace",  # SD 2.1?
]

# Literal types enumerating the accepted model names, control modes and
# resize modes; used as field annotations below.
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
CONTROLNET_MODE_VALUES = Literal[tuple(
    ["balanced", "more_prompt", "more_control", "unbalanced"])]
CONTROLNET_RESIZE_VALUES = Literal[tuple(
    ["just_resize", "crop_resize", "fill_resize", "just_resize_simple",])]


class ControlNetModelField(BaseModel):
    """ControlNet model field"""

    model_name: str = Field(description="Name of the ControlNet model")
    base_model: BaseModelType = Field(description="Base model")


# Bundle of everything describing one ControlNet application: the control
# image, the model, its weight(s), the step window in which it is active, and
# the control/resize modes. Produced by ControlNetInvocation.invoke().
class ControlField(BaseModel):
    image: ImageField = Field(default=None, description="The control image")
    control_model: Optional[ControlNetModelField] = Field(
        default=None, description="The ControlNet model to use")
    # Either one weight, or a list of weights.
    control_weight: Union[float, List[float]] = Field(
        default=1, description="The weight given to the ControlNet")
    begin_step_percent: float = Field(
        default=0, ge=0, le=1,
        description="When the ControlNet is first applied (% of total steps)")
    end_step_percent: float = Field(
        default=1, ge=0, le=1,
        description="When the ControlNet is last applied (% of total steps)")
    control_mode: CONTROLNET_MODE_VALUES = Field(
        default="balanced", description="The control mode to use")
    resize_mode: CONTROLNET_RESIZE_VALUES = Field(
        default="just_resize", description="The resize mode to use")

    @validator("control_weight")
    def validate_control_weight(cls, v):
        """Validate that all control weights in the valid range"""
        # Treat the scalar and list forms uniformly: every weight must lie in
        # the closed interval [-1, 2].
        weights = v if isinstance(v, list) else [v]
        if any(weight < -1 or weight > 2 for weight in weights):
            raise ValueError('Control weights must be within -1 to 2 range')
        return v

    class Config:
        schema_extra = {
            "required": ["image", "control_model", "control_weight",
                         "begin_step_percent", "end_step_percent"],
            "ui": {
                "type_hints": {
                    "control_weight": "float",
                    "control_model": "controlnet_model",
                    # "control_weight": "number",
                }
            }
        }


class ControlOutput(BaseInvocationOutput):
    """node output for ControlNet info"""
    # fmt: off
    type: Literal["control_output"] = "control_output"
    control: ControlField = Field(default=None, description="The control info")
    # fmt: on


class ControlNetInvocation(BaseInvocation):
    """Collects ControlNet info to pass to other nodes"""
    # fmt: off
    type: Literal["controlnet"] = "controlnet"
    # Inputs
    image: ImageField = Field(default=None, description="The control image")
    # NOTE(review): the default below is a plain string although the field is
    # typed as ControlNetModelField -- confirm whether callers always supply
    # an explicit model before relying on this default.
    control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny",
                                                description="control model used")
    control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
    # Bounds match ControlField.begin_step_percent (0..1), so out-of-range
    # values are rejected here instead of failing later inside invoke().
    begin_step_percent: float = Field(default=0, ge=0, le=1,
                                      description="When the ControlNet is first applied (% of total steps)")
    end_step_percent: float = Field(default=1, ge=0, le=1,
                                    description="When the ControlNet is last applied (% of total steps)")
    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
    resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode used")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "ControlNet",
                "tags": ["controlnet", "latents"],
                "type_hints": {
                    "model": "model",
                    "control": "control",
                    # "cfg_scale": "float",
                    "cfg_scale": "number",
                    "control_weight": "float",
                }
            },
        }

    def invoke(self, context: InvocationContext) -> ControlOutput:
        """Package this node's inputs into a ControlField and return it.

        The context is not used; the node only forwards its own field values.
        """
        return ControlOutput(
            control=ControlField(
                image=self.image,
                control_model=self.control_model,
                control_weight=self.control_weight,
                begin_step_percent=self.begin_step_percent,
                end_step_percent=self.end_step_percent,
                control_mode=self.control_mode,
                resize_mode=self.resize_mode,
            ),
        )


class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
    """Base class for invocations that preprocess images for ControlNet"""
    # fmt: off
    type: Literal["image_processor"] = "image_processor"
    # Inputs
    image: ImageField = Field(default=None, description="The image to process")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Image Processor",
                "tags": ["image", "processor"]
            },
        }

    def run_processor(self, image):
        """Return the processed version of ``image``.

        The base implementation passes the image through unchanged;
        subclasses override this with the actual preprocessing.
        """
        return image

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Load the input image, run the processor and store the result.

        Returns an ImageOutput referencing the newly created image, with the
        width and height reported by the image service.
        """
        raw_image = context.services.images.get_pil_image(self.image.image_name)
        # image type should be PIL.PngImagePlugin.PngImageFile ?
        processed_image = self.run_processor(raw_image)

        # FIXME: what happened to image metadata?
        # metadata = context.services.metadata.build_metadata(
        #     session_id=context.graph_execution_state_id, node=self
        # )

        # The result is stored in the CONTROL image category; whether it is
        # flagged as intermediate is taken from this node's is_intermediate.
        image_dto = context.services.images.create(
            image=processed_image,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.CONTROL,
            session_id=context.graph_execution_state_id,
            node_id=self.id,
            is_intermediate=self.is_intermediate
        )

        # Build an ImageOutput and its ImageField from the stored image.
        processed_image_field = ImageField(image_name=image_dto.image_name)
        return ImageOutput(
            image=processed_image_field,
            # dimensions come from the stored image record rather than from
            # the in-memory processed_image
            width=image_dto.width,
            height=image_dto.height,
        )


class CannyImageProcessorInvocation(
        ImageProcessorInvocation, PILInvocationConfig):
    """Canny edge detection for ControlNet"""
    # fmt: off
    type: Literal["canny_image_processor"] = "canny_image_processor"
    # Input
    low_threshold: int = Field(default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)")
    high_threshold: int = Field(default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Canny Processor",
                "tags": ["controlnet", "canny", "image", "processor"]
            },
        }

    def run_processor(self, image):
        """Run controlnet_aux's CannyDetector with this node's thresholds."""
        canny_processor = CannyDetector()
        processed_image = canny_processor(
            image, self.low_threshold, self.high_threshold)
        return processed_image


class HedImageProcessorInvocation(
        ImageProcessorInvocation, PILInvocationConfig):
    """Applies HED edge detection to image"""
    # fmt: off
    type: Literal["hed_image_processor"] = "hed_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
    # safe not supported in controlnet_aux v0.0.3
    # safe: bool = Field(default=False, description="whether to use safe mode")
    scribble: bool = Field(default=False, description="Whether to use scribble mode")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Softedge(HED) Processor",
                "tags": ["controlnet", "softedge", "hed", "image", "processor"]
            },
        }

    def run_processor(self, image):
        """Run HEDdetector (from "lllyasviel/Annotators") on the image."""
        hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
        processed_image = hed_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            # safe not supported in controlnet_aux v0.0.3
            # safe=self.safe,
            scribble=self.scribble,
        )
        return processed_image


class LineartImageProcessorInvocation(
        ImageProcessorInvocation, PILInvocationConfig):
    """Applies line art processing to image"""
    # fmt: off
    type: Literal["lineart_image_processor"] = "lineart_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
    coarse: bool = Field(default=False, description="Whether to use coarse mode")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Lineart Processor",
                "tags": ["controlnet", "lineart", "image", "processor"]
            },
        }

    def run_processor(self, image):
        """Run LineartDetector (from "lllyasviel/Annotators") on the image."""
        lineart_processor = LineartDetector.from_pretrained(
            "lllyasviel/Annotators")
        processed_image = lineart_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            coarse=self.coarse)
        return processed_image


class LineartAnimeImageProcessorInvocation(
        ImageProcessorInvocation, PILInvocationConfig):
    """Applies line art anime processing to image"""
    # fmt: off
    type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Lineart Anime Processor",
                "tags": ["controlnet", "lineart", "anime", "image", "processor"]
            },
        }

    def run_processor(self, image):
        """Run LineartAnimeDetector (from "lllyasviel/Annotators") on the image."""
        processor = LineartAnimeDetector.from_pretrained(
            "lllyasviel/Annotators")
        processed_image = processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
        )
        return processed_image


class OpenposeImageProcessorInvocation(
        ImageProcessorInvocation, PILInvocationConfig):
    """Applies Openpose processing to image"""
    # fmt: off
    type: Literal["openpose_image_processor"] = "openpose_image_processor"
    # Inputs
    hand_and_face: bool = Field(default=False, description="Whether to use hands and face mode")
    detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Openpose Processor",
                "tags": ["controlnet", "openpose", "image", "processor"]
            },
        }

    def run_processor(self, image):
        """Run OpenposeDetector (from "lllyasviel/Annotators") on the image."""
        openpose_processor = OpenposeDetector.from_pretrained(
            "lllyasviel/Annotators")
        processed_image = openpose_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            hand_and_face=self.hand_and_face,
        )
        return processed_image


class MidasDepthImageProcessorInvocation(
        ImageProcessorInvocation, PILInvocationConfig):
    """Applies Midas depth processing to image"""
    # fmt: off
    type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
    # Inputs
    a_mult: float = Field(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
    bg_th: float = Field(default=0.1, ge=0, description="Midas parameter `bg_th`")
    # depth_and_normal not supported in controlnet_aux v0.0.3
    # depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Midas (Depth) Processor",
                "tags": ["controlnet", "midas", "depth", "image", "processor"]
            },
        }

    def run_processor(self, image):
        """Run MidasDetector (from "lllyasviel/Annotators") on the image."""
        midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = midas_processor(
            image,
            # the detector's `a` argument is given as a multiple of pi
            a=np.pi * self.a_mult,
            bg_th=self.bg_th,
            # depth_and_normal not supported in controlnet_aux v0.0.3
            # depth_and_normal=self.depth_and_normal,
        )
        return processed_image


class NormalbaeImageProcessorInvocation(
        ImageProcessorInvocation, PILInvocationConfig):
    """Applies NormalBae processing to image"""
    # fmt: off
    type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Normal BAE Processor",
                "tags": ["controlnet", "normal", "bae", "image", "processor"]
            },
        }

    def run_processor(self, image):
        """Run NormalBaeDetector (from "lllyasviel/Annotators") on the image."""
        normalbae_processor = NormalBaeDetector.from_pretrained(
            "lllyasviel/Annotators")
        processed_image = normalbae_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution)
        return processed_image


class MlsdImageProcessorInvocation(
        ImageProcessorInvocation, PILInvocationConfig):
    """Applies MLSD processing to image"""
    # fmt: off
    type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
    thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_v`")
    thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_d`")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "MLSD Processor",
                "tags": ["controlnet", "mlsd", "image", "processor"]
            },
        }

    def run_processor(self, image):
        """Run MLSDdetector (from "lllyasviel/Annotators") on the image."""
        mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
        processed_image = mlsd_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            thr_v=self.thr_v,
            thr_d=self.thr_d)
        return processed_image


class PidiImageProcessorInvocation(
        ImageProcessorInvocation, PILInvocationConfig):
    """Applies PIDI processing to image"""
    # fmt: off
    type: Literal["pidi_image_processor"] = "pidi_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
    safe: bool = Field(default=False, description="Whether to use safe mode")
    scribble: bool = Field(default=False, description="Whether to use scribble mode")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "PIDI Processor",
                "tags": ["controlnet", "pidi", "image", "processor"]
            },
        }

    def run_processor(self, image):
        """Run PidiNetDetector (from "lllyasviel/Annotators") on the image."""
        pidi_processor = PidiNetDetector.from_pretrained(
            "lllyasviel/Annotators")
        processed_image = pidi_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            safe=self.safe,
            scribble=self.scribble)
        return processed_image


class ContentShuffleImageProcessorInvocation(
        ImageProcessorInvocation, PILInvocationConfig):
    """Applies content shuffle processing to image"""
    # fmt: off
    type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
    h: Optional[int] = Field(default=512, ge=0, description="Content shuffle `h` parameter")
    w: Optional[int] = Field(default=512, ge=0, description="Content shuffle `w` parameter")
    f: Optional[int] = Field(default=256, ge=0, description="Content shuffle `f` parameter")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Content Shuffle Processor",
                "tags": ["controlnet", "contentshuffle", "image", "processor"]
            },
        }

    def run_processor(self, image):
        """Run controlnet_aux's ContentShuffleDetector on the image."""
        content_shuffle_processor = ContentShuffleDetector()
        processed_image = content_shuffle_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            h=self.h,
            w=self.w,
            f=self.f
        )
        return processed_image


# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
class ZoeDepthImageProcessorInvocation(
        ImageProcessorInvocation, PILInvocationConfig):
    """Applies Zoe depth processing to image"""
    # fmt: off
    type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Zoe (Depth) Processor",
                "tags": ["controlnet", "zoe", "depth", "image", "processor"]
            },
        }

    def run_processor(self, image):
        """Run ZoeDetector (from "lllyasviel/Annotators") on the image."""
        zoe_depth_processor = ZoeDetector.from_pretrained(
            "lllyasviel/Annotators")
        processed_image = zoe_depth_processor(image)
        return processed_image


class MediapipeFaceProcessorInvocation(
        ImageProcessorInvocation, PILInvocationConfig):
    """Applies mediapipe face processing to image"""
    # fmt: off
    type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
    # Inputs
    max_faces: int = Field(default=1, ge=1, description="Maximum number of faces to detect")
    min_confidence: float = Field(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Mediapipe Processor",
                "tags": ["controlnet", "mediapipe", "image", "processor"]
            },
        }

    def run_processor(self, image):
        """Run controlnet_aux's MediapipeFaceDetector on the image."""
        # MediaPipeFaceDetector throws an error if image has alpha channel
        #     so convert to RGB if needed
        if image.mode == 'RGBA':
            image = image.convert('RGB')
        mediapipe_face_processor = MediapipeFaceDetector()
        processed_image = mediapipe_face_processor(
            image,
            max_faces=self.max_faces,
            min_confidence=self.min_confidence)
        return processed_image


class LeresImageProcessorInvocation(
        ImageProcessorInvocation, PILInvocationConfig):
    """Applies leres processing to image"""
    # fmt: off
    type: Literal["leres_image_processor"] = "leres_image_processor"
    # Inputs
    thr_a: float = Field(default=0, description="Leres parameter `thr_a`")
    thr_b: float = Field(default=0, description="Leres parameter `thr_b`")
    boost: bool = Field(default=False, description="Whether to use boost mode")
    detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Leres (Depth) Processor",
                "tags": ["controlnet", "leres", "depth", "image", "processor"]
            },
        }

    def run_processor(self, image):
        """Run LeresDetector (from "lllyasviel/Annotators") on the image."""
        leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = leres_processor(
            image,
            thr_a=self.thr_a,
            thr_b=self.thr_b,
            boost=self.boost,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution)
        return processed_image


class TileResamplerProcessorInvocation(
        ImageProcessorInvocation, PILInvocationConfig):
    """Tile resampler processor"""
    # fmt: off
    type: Literal["tile_image_processor"] = "tile_image_processor"
    # Inputs
    # res: int = Field(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
    down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Tile Resample Processor",
                "tags": ["controlnet", "tile", "resample", "image", "processor"]
            },
        }

    # tile_resample copied from sd-webui-controlnet/scripts/processor.py
    def tile_resample(self,
                      np_img: np.ndarray,
                      res=512,  # never used; kept so existing callers still work
                      down_sampling_rate=1.0,
                      ):
        """Downscale ``np_img`` by ``down_sampling_rate`` using area interpolation.

        Args:
            np_img: image array; passed through controlnet_aux's HWC3 first.
            res: unused.
            down_sampling_rate: divisor applied to height and width. Rates
                below 1.1 return the HWC3-converted image without resizing.

        Returns:
            The (possibly resized) image array.
        """
        np_img = HWC3(np_img)
        if down_sampling_rate < 1.1:
            return np_img
        height, width = np_img.shape[:2]
        # Keep each side at least 1 px; the understanding is that cv2.resize
        # rejects an empty target size, which very small inputs could
        # otherwise produce at high rates.
        height = max(1, int(float(height) / float(down_sampling_rate)))
        width = max(1, int(float(width) / float(down_sampling_rate)))
        # cv2.resize takes the target size as (width, height)
        np_img = cv2.resize(np_img, (width, height), interpolation=cv2.INTER_AREA)
        return np_img

    def run_processor(self, img):
        """Convert ``img`` to a uint8 array, resample it, and return a PIL image."""
        np_img = np.array(img, dtype=np.uint8)
        processed_np_image = self.tile_resample(
            np_img,
            # res=self.tile_size,
            down_sampling_rate=self.down_sampling_rate
        )
        processed_image = Image.fromarray(processed_np_image)
        return processed_image


class SegmentAnythingProcessorInvocation(
        ImageProcessorInvocation, PILInvocationConfig):
    """Applies segment anything processing to image"""
    # fmt: off
    type: Literal["segment_anything_processor"] = "segment_anything_processor"
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Segment Anything Processor",
                "tags": ["controlnet", "segment", "anything", "sam", "image", "processor"]
            },
        }

    def run_processor(self, image):
        """Run the reproducible-color SAM detector on the image."""
        # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
        # SamDetectorReproducibleColors is defined later in this module; the
        # name is only looked up when this method runs.
        segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
            "ybelkada/segment-anything", subfolder="checkpoints")
        np_img = np.array(image, dtype=np.uint8)
        processed_image = segment_anything_processor(np_img)
        return processed_image


class SamDetectorReproducibleColors(SamDetector):
    """SamDetector variant that colors segments deterministically."""

    # overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
    #     base class show_anns() method randomizes colors,
    #     which seems to also lead to non-reproducible image generation
    # so using ADE20k color palette instead
    def show_anns(self, anns: List[Dict]):
        """Render annotations as one RGB image using the ADE20k palette.

        Args:
            anns: annotation dicts; each is read for its 'area' and
                'segmentation' (a 2-D mask) entries.

        Returns:
            A uint8 array of the painted image, or None when ``anns`` is empty.
        """
        if not anns:
            return None
        # Largest regions are painted first so smaller ones end up on top.
        sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
        h, w = anns[0]['segmentation'].shape
        final_img = Image.fromarray(
            np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
        palette = ade_palette()
        for i, ann in enumerate(sorted_anns):
            m = ann['segmentation']
            # doing modulo just in case number of annotated regions exceeds number of colors in palette
            ann_color = palette[i % len(palette)]
            # solid-color layer, pasted through the segment mask below
            img = np.full((m.shape[0], m.shape[1], 3), ann_color, dtype=np.uint8)
            final_img.paste(
                Image.fromarray(img, mode="RGB"),
                (0, 0),
                Image.fromarray(np.uint8(m * 255)))
        return np.array(final_img, dtype=np.uint8)