diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 933c32c908..83b0e90892 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -6,9 +6,11 @@ from typing import Dict, List, Literal, Optional, Union import cv2 import numpy as np +import torch from controlnet_aux import ( CannyDetector, ContentShuffleDetector, + DWposeDetector, HEDdetector, LeresDetector, LineartAnimeDetector, @@ -589,3 +591,27 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation): color_map = cv2.resize(color_map, (width, height), interpolation=cv2.INTER_NEAREST) color_map = Image.fromarray(color_map) return color_map + + +@invocation( + "dwpose_image_processor", + title="DWPose Processor", + tags=["controlnet", "dwpose", "pose"], + category="controlnet", + version="1.0.0", +) +class DWPoseImageProcessorInvocation(ImageProcessorInvocation): + """Applies DW-Pose processing to image""" + + detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) + image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) + + def run_processor(self, image): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + dwpose_processor = DWposeDetector(device=device) + processed_image = dwpose_processor( + image, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution, + ) + return processed_image