Added TileResampler ControlNet preprocessor node.

Also includes fixes to the SegmentAnything ControlNet preprocessor node.
This commit is contained in:
user1 2023-06-26 04:27:26 -07:00
parent 10e8389fa4
commit 873c18bc4b

View File

@ -1,8 +1,9 @@
# InvokeAI nodes for ControlNet image preprocessors # Invocations for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023 # initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import float, bool from builtins import float, bool
import cv2
import numpy as np import numpy as np
from typing import Literal, Optional, Union, List, Dict from typing import Literal, Optional, Union, List, Dict
from PIL import Image, ImageFilter, ImageOps from PIL import Image, ImageFilter, ImageOps
@ -33,7 +34,7 @@ from controlnet_aux import (
# LeresDetector, # LeresDetector,
) )
from controlnet_aux.util import ade_palette from controlnet_aux.util import HWC3, ade_palette
from .image import ImageOutput, PILInvocationConfig from .image import ImageOutput, PILInvocationConfig
@ -483,6 +484,43 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
# image_resolution=self.image_resolution) # image_resolution=self.image_resolution)
# return processed_image # return processed_image
class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies tile resampling (area-interpolated downsampling) to image"""
    # fmt: off
    type: Literal["tile_image_processor"] = "tile_image_processor"
    # Inputs
    # NOTE: a per-tile `res` input is intentionally not exposed; the resampling
    # routine below accepts `res` but never reads it.
    down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
    # fmt: on

    # tile_resample copied from sd-webui-controlnet/scripts/processor.py
    def tile_resample(
        self,
        np_img: np.ndarray,
        res=512,  # unused; kept so the signature stays compatible with the copied original
        down_sampling_rate=1.0,
    ) -> np.ndarray:
        """Shrink ``np_img`` by ``down_sampling_rate`` along both spatial axes.

        Args:
            np_img: Input image array. It is first passed through ``HWC3``
                (from ``controlnet_aux.util``); the result is unpacked as a
                3-D ``(height, width, channels)`` array.
            res: Ignored. Present only for signature compatibility.
            down_sampling_rate: Divisor applied to height and width. Rates
                below 1.1 are treated as "no resampling".

        Returns:
            The ``HWC3``-converted array, resized with ``cv2.INTER_AREA`` when
            the rate is 1.1 or greater, otherwise returned as-is.
        """
        np_img = HWC3(np_img)
        # Rates this close to 1 would not change the size meaningfully; skip the resize.
        if down_sampling_rate < 1.1:
            return np_img

        height, width = np_img.shape[:2]
        rate = float(down_sampling_rate)
        # int() truncates toward zero, so a dimension smaller than the rate
        # would become 0 and cv2.resize rejects an empty target size.
        # Clamp each dimension to at least one pixel.
        new_height = max(1, int(float(height) / rate))
        new_width = max(1, int(float(width) / rate))
        # cv2.resize takes the target size as (width, height).
        return cv2.resize(np_img, (new_width, new_height), interpolation=cv2.INTER_AREA)

    def run_processor(self, img):
        """Downsample ``img`` by ``self.down_sampling_rate`` and return a PIL image.

        The input is converted to a ``uint8`` numpy array, resampled with
        ``tile_resample``, and converted back with ``Image.fromarray``.
        """
        np_img = np.array(img, dtype=np.uint8)
        processed_np_image = self.tile_resample(
            np_img,
            down_sampling_rate=self.down_sampling_rate,
        )
        return Image.fromarray(processed_np_image)
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies segment anything processing to image""" """Applies segment anything processing to image"""
# fmt: off # fmt: off
@ -492,7 +530,8 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocation
def run_processor(self, image): def run_processor(self, image):
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") segment_anything_processor = SamDetectorReproducibleColors.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
processed_image = segment_anything_processor(image) np_img = np.array(image, dtype=np.uint8)
processed_image = segment_anything_processor(np_img)
return processed_image return processed_image
class SamDetectorReproducibleColors(SamDetector): class SamDetectorReproducibleColors(SamDetector):