mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor DWOpenPose and add type hints
This commit is contained in:
parent
38df6f3702
commit
e9a20051bd
@ -39,7 +39,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
|
||||
from invokeai.backend.image_util.canny import get_canny_edges
|
||||
from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
|
||||
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
|
||||
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
|
||||
from invokeai.backend.image_util.hed import HEDProcessor
|
||||
from invokeai.backend.image_util.lineart import LineartProcessor
|
||||
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
|
||||
@ -633,7 +633,11 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
|
||||
dw_openpose = DWOpenposeDetector(context)
|
||||
mm = context.models
|
||||
onnx_det = mm.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"])
|
||||
onnx_pose = mm.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
|
||||
|
||||
dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
|
||||
processed_image = dw_openpose(
|
||||
image,
|
||||
draw_face=self.draw_face,
|
||||
|
@ -1,31 +1,53 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from controlnet_aux.util import resize_image
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.dw_openpose.utils import draw_bodypose, draw_facepose, draw_handpose
|
||||
from invokeai.backend.image_util.dw_openpose.utils import NDArrayInt, draw_bodypose, draw_facepose, draw_handpose
|
||||
from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody
|
||||
|
||||
DWPOSE_MODELS = {
|
||||
"yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
|
||||
"dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
|
||||
}
|
||||
|
||||
def draw_pose(pose, H, W, draw_face=True, draw_body=True, draw_hands=True, resolution=512):
|
||||
|
||||
def draw_pose(
|
||||
pose: Dict[str, NDArrayInt | Dict[str, NDArrayInt]],
|
||||
H: int,
|
||||
W: int,
|
||||
draw_face: bool = True,
|
||||
draw_body: bool = True,
|
||||
draw_hands: bool = True,
|
||||
resolution: int = 512,
|
||||
) -> Image.Image:
|
||||
bodies = pose["bodies"]
|
||||
faces = pose["faces"]
|
||||
hands = pose["hands"]
|
||||
|
||||
assert isinstance(bodies, dict)
|
||||
candidate = bodies["candidate"]
|
||||
|
||||
assert isinstance(bodies, dict)
|
||||
subset = bodies["subset"]
|
||||
|
||||
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
|
||||
|
||||
if draw_body:
|
||||
canvas = draw_bodypose(canvas, candidate, subset)
|
||||
|
||||
if draw_hands:
|
||||
assert isinstance(hands, np.ndarray)
|
||||
canvas = draw_handpose(canvas, hands)
|
||||
|
||||
if draw_face:
|
||||
canvas = draw_facepose(canvas, faces)
|
||||
assert isinstance(hands, np.ndarray)
|
||||
canvas = draw_facepose(canvas, faces) # type: ignore
|
||||
|
||||
dwpose_image = resize_image(
|
||||
dwpose_image: Image.Image = resize_image(
|
||||
canvas,
|
||||
resolution,
|
||||
)
|
||||
@ -40,11 +62,16 @@ class DWOpenposeDetector:
|
||||
Credits: https://github.com/IDEA-Research/DWPose
|
||||
"""
|
||||
|
||||
def __init__(self, context: InvocationContext) -> None:
|
||||
self.pose_estimation = Wholebody(context)
|
||||
def __init__(self, onnx_det: Path, onnx_pose: Path) -> None:
|
||||
self.pose_estimation = Wholebody(onnx_det=onnx_det, onnx_pose=onnx_pose)
|
||||
|
||||
def __call__(
|
||||
self, image: Image.Image, draw_face=False, draw_body=True, draw_hands=False, resolution=512
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw_face: bool = False,
|
||||
draw_body: bool = True,
|
||||
draw_hands: bool = False,
|
||||
resolution: int = 512,
|
||||
) -> Image.Image:
|
||||
np_image = np.array(image)
|
||||
H, W, C = np_image.shape
|
||||
@ -80,3 +107,6 @@ class DWOpenposeDetector:
|
||||
return draw_pose(
|
||||
pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body, resolution=resolution
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["DWPOSE_MODELS", "DWOpenposeDetector"]
|
||||
|
@ -5,11 +5,13 @@ import math
|
||||
import cv2
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
eps = 0.01
|
||||
NDArrayInt = npt.NDArray[np.uint8]
|
||||
|
||||
|
||||
def draw_bodypose(canvas, candidate, subset):
|
||||
def draw_bodypose(canvas: NDArrayInt, candidate: NDArrayInt, subset: NDArrayInt) -> NDArrayInt:
|
||||
H, W, C = canvas.shape
|
||||
candidate = np.array(candidate)
|
||||
subset = np.array(subset)
|
||||
@ -88,7 +90,7 @@ def draw_bodypose(canvas, candidate, subset):
|
||||
return canvas
|
||||
|
||||
|
||||
def draw_handpose(canvas, all_hand_peaks):
|
||||
def draw_handpose(canvas: NDArrayInt, all_hand_peaks: NDArrayInt) -> NDArrayInt:
|
||||
H, W, C = canvas.shape
|
||||
|
||||
edges = [
|
||||
@ -142,7 +144,7 @@ def draw_handpose(canvas, all_hand_peaks):
|
||||
return canvas
|
||||
|
||||
|
||||
def draw_facepose(canvas, all_lmks):
|
||||
def draw_facepose(canvas: NDArrayInt, all_lmks: NDArrayInt) -> NDArrayInt:
|
||||
H, W, C = canvas.shape
|
||||
for lmks in all_lmks:
|
||||
lmks = np.array(lmks)
|
||||
|
@ -2,33 +2,26 @@
|
||||
# Modified pathing to suit Invoke
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .onnxdet import inference_detector
|
||||
from .onnxpose import inference_pose
|
||||
|
||||
DWPOSE_MODELS = {
|
||||
"yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
|
||||
"dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
|
||||
}
|
||||
|
||||
config = get_config()
|
||||
|
||||
|
||||
class Wholebody:
|
||||
def __init__(self, context: InvocationContext):
|
||||
def __init__(self, onnx_det: Path, onnx_pose: Path):
|
||||
device = TorchDevice.choose_torch_device()
|
||||
|
||||
providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
|
||||
|
||||
onnx_det = context.models.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"])
|
||||
onnx_pose = context.models.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
|
||||
|
||||
self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
|
||||
self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user