diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py
index baf558ac24..8cfe35598d 100644
--- a/invokeai/app/invocations/controlnet_image_processors.py
+++ b/invokeai/app/invocations/controlnet_image_processors.py
@@ -1,10 +1,11 @@
-# InvokeAI nodes for ControlNet image preprocessors
+# Invocations for ControlNet image preprocessors
 # initial implementation by Gregg Helt, 2023
 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
 from builtins import float, bool

+import cv2
 import numpy as np
-from typing import Literal, Optional, Union, List
+from typing import Literal, Optional, Union, List, Dict
 from PIL import Image, ImageFilter, ImageOps
 from pydantic import BaseModel, Field, validator

@@ -29,8 +30,13 @@ from controlnet_aux import (
     ContentShuffleDetector,
     ZoeDetector,
     MediapipeFaceDetector,
+    SamDetector,
+    LeresDetector,
 )
+from controlnet_aux.util import HWC3, ade_palette
+
+
 from .image import ImageOutput, PILInvocationConfig

 CONTROLNET_DEFAULT_MODELS = [
@@ -95,6 +101,9 @@ CONTROLNET_DEFAULT_MODELS = [

 CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
 CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
+# crop and fill options not ready yet
+# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
+

 class ControlField(BaseModel):
     image: ImageField = Field(default=None, description="The control image")
@@ -105,7 +114,8 @@ class ControlField(BaseModel):
                                      description="When the ControlNet is first applied (% of total steps)")
     end_step_percent: float = Field(default=1, ge=0, le=1,
                                     description="When the ControlNet is last applied (% of total steps)")
-    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The contorl mode to use")
+    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
+    # resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")

     @validator("control_weight")
     def abs_le_one(cls, v):
@@ -180,7 +190,7 @@ class ControlNetInvocation(BaseInvocation):
         ),
     )

-# TODO: move image processors to separate file (image_analysis.py
+
 class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
     """Base class for invocations that preprocess images for ControlNet"""

@@ -452,6 +462,104 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
     # fmt: on

     def run_processor(self, image):
+        # MediaPipeFaceDetector throws an error if image has alpha channel
+        # so convert to RGB if needed
+        if image.mode == 'RGBA':
+            image = image.convert('RGB')
         mediapipe_face_processor = MediapipeFaceDetector()
         processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
         return processed_image
+
+class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+    """Applies leres processing to image"""
+    # fmt: off
+    type: Literal["leres_image_processor"] = "leres_image_processor"
+    # Inputs
+    thr_a: float = Field(default=0, description="Leres parameter `thr_a`")
+    thr_b: float = Field(default=0, description="Leres parameter `thr_b`")
+    boost: bool = Field(default=False, description="Whether to use boost mode")
+    detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
+    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
pixel resolution for the output image") + # fmt: on + + def run_processor(self, image): + leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators") + processed_image = leres_processor(image, + thr_a=self.thr_a, + thr_b=self.thr_b, + boost=self.boost, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution) + return processed_image + + +class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + + # fmt: off + type: Literal["tile_image_processor"] = "tile_image_processor" + # Inputs + #res: int = Field(default=512, ge=0, le=1024, description="The pixel resolution for each tile") + down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate") + # fmt: on + + # tile_resample copied from sd-webui-controlnet/scripts/processor.py + def tile_resample(self, + np_img: np.ndarray, + res=512, # never used? + down_sampling_rate=1.0, + ): + np_img = HWC3(np_img) + if down_sampling_rate < 1.1: + return np_img + H, W, C = np_img.shape + H = int(float(H) / float(down_sampling_rate)) + W = int(float(W) / float(down_sampling_rate)) + np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA) + return np_img + + def run_processor(self, img): + np_img = np.array(img, dtype=np.uint8) + processed_np_image = self.tile_resample(np_img, + #res=self.tile_size, + down_sampling_rate=self.down_sampling_rate + ) + processed_image = Image.fromarray(processed_np_image) + return processed_image + + + + +class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies segment anything processing to image""" + # fmt: off + type: Literal["segment_anything_processor"] = "segment_anything_processor" + # fmt: on + + def run_processor(self, image): + # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") + segment_anything_processor = SamDetectorReproducibleColors.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") + np_img = np.array(image, dtype=np.uint8) + processed_image = segment_anything_processor(np_img) + return processed_image + +class SamDetectorReproducibleColors(SamDetector): + + # overriding SamDetector.show_anns() method to use reproducible colors for segmentation image + # base class show_anns() method randomizes colors, + # which seems to also lead to non-reproducible image generation + # so using ADE20k color palette instead + def show_anns(self, anns: List[Dict]): + if len(anns) == 0: + return + sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) + h, w = anns[0]['segmentation'].shape + final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB") + palette = ade_palette() + for i, ann in enumerate(sorted_anns): + m = ann['segmentation'] + img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8) + # doing modulo just in case number of annotated regions exceeds number of colors in palette + ann_color = palette[i % len(palette)] + img[:, :] = ann_color + final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255))) + return np.array(final_img, dtype=np.uint8) diff --git a/pyproject.toml b/pyproject.toml index 03396312ac..6e5b8f4e22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ dependencies = [ "click", "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "compel>=1.2.1", - "controlnet-aux>=0.0.4", + "controlnet-aux>=0.0.6", 
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26 "datasets", "diffusers[torch]~=0.17.1",