Compare commits

...

4 Commits

2 changed files with 34 additions and 2 deletions

View File

@ -6,9 +6,11 @@ from typing import Dict, List, Literal, Optional, Union
import cv2
import numpy as np
import torch
from controlnet_aux import (
CannyDetector,
ContentShuffleDetector,
DWposeDetector,
HEDdetector,
LeresDetector,
LineartAnimeDetector,
@ -125,7 +127,7 @@ class ControlNetInvocation(BaseInvocation):
@invocation(
"image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet", version="1.0.0"
"image_processor", title="Base Image Processorwp", tags=["controlnet"], category="controlnet", version="1.0.0"
)
class ImageProcessorInvocation(BaseInvocation):
"""Base class for invocations that preprocess images for ControlNet"""
@ -589,3 +591,29 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
color_map = cv2.resize(color_map, (width, height), interpolation=cv2.INTER_NEAREST)
color_map = Image.fromarray(color_map)
return color_map
@invocation(
    "dwpose_image_processor",
    title="DWPose Processor",
    tags=["controlnet", "dwpose", "pose"],
    category="controlnet",
    version="1.0.0",
)
class DWPoseImageProcessorInvocation(ImageProcessorInvocation):
    """Applies DW-Pose processing to image"""

    # Both resolution fields are forwarded unchanged to the detector call in
    # run_processor(); each defaults to 512 and must be >= 0.
    detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

    def run_processor(self, image):
        """Run a freshly constructed DWposeDetector on `image` and return its output."""
        # DWPose is pinned to the CPU for now; selecting CUDA when
        # torch.cuda.is_available() is left as a possible follow-up.
        detector = DWposeDetector(device="cpu")
        return detector(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
        )

View File

@ -37,7 +37,7 @@ dependencies = [
"click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel~=2.0.2",
"controlnet-aux>=0.0.6",
"controlnet-aux>=0.0.7",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"datasets",
"diffusers[torch]~=0.21.0",
@ -52,6 +52,10 @@ dependencies = [
"invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
"matplotlib", # needed for plotting of Penner easing functions
"mediapipe", # needed for "mediapipeface" controlnet model
"mmcv>=2.0.1",
"mmdet>=3.1.0",
"mmengine",
"mmpose>=1.1.0",
"numpy",
"npyscreen",
"omegaconf",