feat(worker): add image_resolution as option for all cnet procesors

This commit is contained in:
maryhipp 2024-03-18 13:20:46 -04:00 committed by psychedelicious
parent b25850a585
commit ed0f9f7d66

View File

@ -176,6 +176,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
class CannyImageProcessorInvocation(ImageProcessorInvocation): class CannyImageProcessorInvocation(ImageProcessorInvocation):
"""Canny edge detection for ControlNet""" """Canny edge detection for ControlNet"""
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
low_threshold: int = InputField( low_threshold: int = InputField(
default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)" default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
) )
@ -189,7 +190,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
def run_processor(self, image): def run_processor(self, image):
canny_processor = CannyDetector() canny_processor = CannyDetector()
processed_image = canny_processor(image, self.low_threshold, self.high_threshold) processed_image = canny_processor(image, self.low_threshold, self.high_threshold, image_resolution=self.image_resolution,)
return processed_image return processed_image
@ -279,6 +280,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)") a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`") bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
# depth_and_normal not supported in controlnet_aux v0.0.3 # depth_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode") # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
@ -288,6 +290,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
image, image,
a=np.pi * self.a_mult, a=np.pi * self.a_mult,
bg_th=self.bg_th, bg_th=self.bg_th,
image_resolution=self.image_resolution
# dept_and_normal not supported in controlnet_aux v0.0.3 # dept_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal=self.depth_and_normal, # depth_and_normal=self.depth_and_normal,
) )
@ -401,9 +404,11 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image""" """Applies Zoe depth processing to image"""
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image): def run_processor(self, image):
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators") zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = zoe_depth_processor(image) processed_image = zoe_depth_processor(image, image_resolution=self.image_resolution)
return processed_image return processed_image
@ -419,10 +424,11 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect") max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection") min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image): def run_processor(self, image):
mediapipe_face_processor = MediapipeFaceDetector() mediapipe_face_processor = MediapipeFaceDetector()
processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence) processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence, image_resolution=self.image_resolution)
return processed_image return processed_image
@ -504,6 +510,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
) )
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation): class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
"""Applies segment anything processing to image""" """Applies segment anything processing to image"""
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image): def run_processor(self, image):
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
@ -511,7 +518,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
"ybelkada/segment-anything", subfolder="checkpoints" "ybelkada/segment-anything", subfolder="checkpoints"
) )
np_img = np.array(image, dtype=np.uint8) np_img = np.array(image, dtype=np.uint8)
processed_image = segment_anything_processor(np_img) processed_image = segment_anything_processor(np_img, image_resolution=self.image_resolution)
return processed_image return processed_image