refactor DWOpenPose and add type hints

This commit is contained in:
Lincoln Stein 2024-05-03 18:08:53 -04:00
parent 38df6f3702
commit e9a20051bd
4 changed files with 52 additions and 23 deletions

View File

@ -39,7 +39,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
from invokeai.backend.image_util.canny import get_canny_edges
from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
from invokeai.backend.image_util.hed import HEDProcessor
from invokeai.backend.image_util.lineart import LineartProcessor
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
@ -633,7 +633,11 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
dw_openpose = DWOpenposeDetector(context)
mm = context.models
onnx_det = mm.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"])
onnx_pose = mm.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
processed_image = dw_openpose(
image,
draw_face=self.draw_face,

View File

@ -1,31 +1,53 @@
from pathlib import Path
from typing import Dict
import numpy as np
import torch
from controlnet_aux.util import resize_image
from PIL import Image
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.dw_openpose.utils import draw_bodypose, draw_facepose, draw_handpose
from invokeai.backend.image_util.dw_openpose.utils import NDArrayInt, draw_bodypose, draw_facepose, draw_handpose
from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody
# Download URLs for the ONNX files DWPose needs, keyed by filename.
# Both files live under the same Hugging Face repository path, so each URL is
# derived from its filename rather than spelled out twice.
_DWPOSE_BASE_URL = "https://huggingface.co/yzd-v/DWPose/resolve/main"
DWPOSE_MODELS = {
    filename: f"{_DWPOSE_BASE_URL}/{filename}?download=true"
    for filename in ("yolox_l.onnx", "dw-ll_ucoco_384.onnx")
}
def draw_pose(pose, H, W, draw_face=True, draw_body=True, draw_hands=True, resolution=512):
def draw_pose(
pose: Dict[str, NDArrayInt | Dict[str, NDArrayInt]],
H: int,
W: int,
draw_face: bool = True,
draw_body: bool = True,
draw_hands: bool = True,
resolution: int = 512,
) -> Image.Image:
bodies = pose["bodies"]
faces = pose["faces"]
hands = pose["hands"]
assert isinstance(bodies, dict)
candidate = bodies["candidate"]
assert isinstance(bodies, dict)
subset = bodies["subset"]
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
if draw_body:
canvas = draw_bodypose(canvas, candidate, subset)
if draw_hands:
assert isinstance(hands, np.ndarray)
canvas = draw_handpose(canvas, hands)
if draw_face:
canvas = draw_facepose(canvas, faces)
assert isinstance(hands, np.ndarray)
canvas = draw_facepose(canvas, faces) # type: ignore
dwpose_image = resize_image(
dwpose_image: Image.Image = resize_image(
canvas,
resolution,
)
@ -40,11 +62,16 @@ class DWOpenposeDetector:
Credits: https://github.com/IDEA-Research/DWPose
"""
def __init__(self, context: InvocationContext) -> None:
self.pose_estimation = Wholebody(context)
def __init__(self, onnx_det: Path, onnx_pose: Path) -> None:
    """Build the detector from two already-resolved ONNX model files.

    Args:
        onnx_det: Path to the ONNX file passed to ``Wholebody`` as ``onnx_det``.
        onnx_pose: Path to the ONNX file passed to ``Wholebody`` as ``onnx_pose``.
    """
    # The caller resolves the paths; this class only wires them into the estimator.
    estimator = Wholebody(onnx_det=onnx_det, onnx_pose=onnx_pose)
    self.pose_estimation = estimator
def __call__(
self, image: Image.Image, draw_face=False, draw_body=True, draw_hands=False, resolution=512
self,
image: Image.Image,
draw_face: bool = False,
draw_body: bool = True,
draw_hands: bool = False,
resolution: int = 512,
) -> Image.Image:
np_image = np.array(image)
H, W, C = np_image.shape
@ -80,3 +107,6 @@ class DWOpenposeDetector:
return draw_pose(
pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body, resolution=resolution
)
__all__ = ["DWPOSE_MODELS", "DWOpenposeDetector"]

View File

@ -5,11 +5,13 @@ import math
import cv2
import matplotlib
import numpy as np
import numpy.typing as npt
eps = 0.01
# Alias for NumPy arrays with dtype uint8 (narrower than the "Int" in its name
# suggests); used to annotate the canvas and keypoint arguments of the draw_* helpers.
NDArrayInt = npt.NDArray[np.uint8]
def draw_bodypose(canvas, candidate, subset):
def draw_bodypose(canvas: NDArrayInt, candidate: NDArrayInt, subset: NDArrayInt) -> NDArrayInt:
H, W, C = canvas.shape
candidate = np.array(candidate)
subset = np.array(subset)
@ -88,7 +90,7 @@ def draw_bodypose(canvas, candidate, subset):
return canvas
def draw_handpose(canvas, all_hand_peaks):
def draw_handpose(canvas: NDArrayInt, all_hand_peaks: NDArrayInt) -> NDArrayInt:
H, W, C = canvas.shape
edges = [
@ -142,7 +144,7 @@ def draw_handpose(canvas, all_hand_peaks):
return canvas
def draw_facepose(canvas, all_lmks):
def draw_facepose(canvas: NDArrayInt, all_lmks: NDArrayInt) -> NDArrayInt:
H, W, C = canvas.shape
for lmks in all_lmks:
lmks = np.array(lmks)

View File

@ -2,33 +2,26 @@
# Modified pathing to suit Invoke
from pathlib import Path
import numpy as np
import onnxruntime as ort
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.util.devices import TorchDevice
from .onnxdet import inference_detector
from .onnxpose import inference_pose
DWPOSE_MODELS = {
"yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
"dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
}
config = get_config()
class Wholebody:
def __init__(self, context: InvocationContext):
def __init__(self, onnx_det: Path, onnx_pose: Path):
device = TorchDevice.choose_torch_device()
providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
onnx_det = context.models.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"])
onnx_pose = context.models.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)