port dw_openpose, depth_anything, and lama processors to new model download scheme

Lincoln Stein
2024-04-12 21:05:23 -04:00
parent 3a26c7bb9e
commit 41b909cbe3
7 changed files with 72 additions and 105 deletions
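Every file below makes the same substitution: instead of resolving a destination under config.models_path and fetching it with download_with_progress_bar, each processor now asks the invocation context's model manager to download and cache the checkpoint. A condensed before/after sketch of the call pattern, restating the diffs rather than adding any API (url and context stand in for the values each processor supplies):

    # before: resolve a local path and download explicitly if missing
    from invokeai.app.services.config.config_default import get_config
    from invokeai.app.util.download_with_progress import download_with_progress_bar

    dest = get_config().models_path / "any/annotators/dwpose/yolox_l.onnx"
    download_with_progress_bar("yolox_l.onnx", url, dest)

    # after: the context downloads once, caches, and returns the local path
    path = context.models.download_and_cache_ckpt(url)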


@@ -1,5 +1,4 @@
-import pathlib
-from typing import Literal, Union
+from typing import Literal, Optional, Union
 
 import cv2
 import numpy as np
@@ -10,7 +9,7 @@ from PIL import Image
 from torchvision.transforms import Compose
 
 from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.download_with_progress import download_with_progress_bar
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
 from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
 from invokeai.backend.util.devices import choose_torch_device
@@ -20,18 +19,9 @@ config = get_config()
 logger = InvokeAILogger.get_logger(config=config)
 
 DEPTH_ANYTHING_MODELS = {
-    "large": {
-        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
-        "local": "any/annotators/depth_anything/depth_anything_vitl14.pth",
-    },
-    "base": {
-        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
-        "local": "any/annotators/depth_anything/depth_anything_vitb14.pth",
-    },
-    "small": {
-        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
-        "local": "any/annotators/depth_anything/depth_anything_vits14.pth",
-    },
+    "large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
+    "base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
+    "small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
 }
@@ -53,18 +43,14 @@ transform = Compose(
 
 class DepthAnythingDetector:
-    def __init__(self) -> None:
-        self.model = None
+    def __init__(self, context: InvocationContext) -> None:
+        self.context = context
+        self.model: Optional[DPT_DINOv2] = None
         self.model_size: Union[Literal["large", "base", "small"], None] = None
         self.device = choose_torch_device()
 
-    def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
-        DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]
-        download_with_progress_bar(
-            pathlib.Path(DEPTH_ANYTHING_MODELS[model_size]["url"]).name,
-            DEPTH_ANYTHING_MODELS[model_size]["url"],
-            DEPTH_ANYTHING_MODEL_PATH,
-        )
+    def load_model(self, model_size: Literal["large", "base", "small"] = "small") -> DPT_DINOv2:
+        depth_anything_model_path = self.context.models.download_and_cache_ckpt(DEPTH_ANYTHING_MODELS[model_size])
 
         if not self.model or model_size != self.model_size:
             del self.model
@@ -78,7 +64,8 @@ class DepthAnythingDetector:
                 case "large":
                     self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
 
-        self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
+        assert self.model is not None
+        self.model.load_state_dict(torch.load(depth_anything_model_path.as_posix(), map_location="cpu"))
         self.model.eval()
 
         self.model.to(choose_torch_device())
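With this change DepthAnythingDetector can no longer be constructed without an InvocationContext. A hypothetical call site inside an invocation might now look like the following sketch (the image variable and the __call__ signature are assumptions, not part of this diff):

    detector = DepthAnythingDetector(context)  # context: InvocationContext
    model = detector.load_model("small")       # first call downloads and caches the ckpt
    depth_map = detector(image)                # inference itself is untouched here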


@@ -3,6 +3,7 @@ import torch
 from controlnet_aux.util import resize_image
 from PIL import Image
 
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.image_util.dw_openpose.utils import draw_bodypose, draw_facepose, draw_handpose
 from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody
@@ -39,8 +40,8 @@ class DWOpenposeDetector:
     Credits: https://github.com/IDEA-Research/DWPose
     """
 
-    def __init__(self) -> None:
-        self.pose_estimation = Wholebody()
+    def __init__(self, context: InvocationContext) -> None:
+        self.pose_estimation = Wholebody(context)
 
     def __call__(
         self, image: Image.Image, draw_face=False, draw_body=True, draw_hands=False, resolution=512

@@ -4,44 +4,31 @@
 import numpy as np
 import onnxruntime as ort
 import torch
 
-from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.download_with_progress import download_with_progress_bar
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.util.devices import choose_torch_device
 
 from .onnxdet import inference_detector
 from .onnxpose import inference_pose
 
 DWPOSE_MODELS = {
-    "yolox_l.onnx": {
-        "local": "any/annotators/dwpose/yolox_l.onnx",
-        "url": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
-    },
-    "dw-ll_ucoco_384.onnx": {
-        "local": "any/annotators/dwpose/dw-ll_ucoco_384.onnx",
-        "url": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
-    },
+    "yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
+    "dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
 }
 
-config = get_config()
-
 
 class Wholebody:
-    def __init__(self):
+    def __init__(self, context: InvocationContext):
         device = choose_torch_device()
-        providers = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
+        providers = ["CUDAExecutionProvider"] if device == torch.device("cuda") else ["CPUExecutionProvider"]
 
-        DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
-        download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)
-
-        POSE_MODEL_PATH = config.models_path / DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["local"]
-        download_with_progress_bar(
-            "dw-ll_ucoco_384.onnx", DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["url"], POSE_MODEL_PATH
-        )
-
-        onnx_det = DET_MODEL_PATH
-        onnx_pose = POSE_MODEL_PATH
+        onnx_det = context.models.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"])
+        onnx_pose = context.models.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
 
         self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
         self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
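Alongside the download port, the providers line in Wholebody is a quiet bug fix: choose_torch_device() returns a torch.device, and on the PyTorch releases in use here a torch.device does not compare equal to the plain string "cuda", so the old check always selected the CPU provider. A minimal sketch of the two comparisons (assuming a CUDA machine):

    import torch

    device = torch.device("cuda")
    device == "cuda"                 # False: the old check silently failed
    device == torch.device("cuda")   # True: the corrected check in this diff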


@@ -1,4 +1,3 @@
-import gc
 from typing import Any
 
 import numpy as np
@@ -6,9 +5,7 @@ import torch
 from PIL import Image
 
 import invokeai.backend.util.logging as logger
-from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.download_with_progress import download_with_progress_bar
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.app.services.shared.invocation_context import InvocationContext
 
 
 def norm_img(np_img):
@@ -28,18 +25,14 @@ def load_jit_model(url_or_path, device):
 class LaMA:
+    def __init__(self, context: InvocationContext):
+        self._context = context
+
     def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
-        device = choose_torch_device()
-        model_location = get_config().models_path / "core/misc/lama/lama.pt"
-
-        if not model_location.exists():
-            download_with_progress_bar(
-                name="LaMa Inpainting Model",
-                url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
-                dest_path=model_location,
-            )
-
-        model = load_jit_model(model_location, device)
+        loaded_model = self._context.models.load_ckpt_from_url(
+            source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
+            loader=lambda path: load_jit_model(path, "cpu"),
+        )
 
         image = np.asarray(input_image.convert("RGB"))
         image = norm_img(image)
@@ -48,20 +41,18 @@
         mask = np.asarray(mask)
         mask = np.invert(mask)
         mask = norm_img(mask)
         mask = (mask > 0) * 1
 
-        image = torch.from_numpy(image).unsqueeze(0).to(device)
-        mask = torch.from_numpy(mask).unsqueeze(0).to(device)
-
-        with torch.inference_mode():
-            infilled_image = model(image, mask)
+        with loaded_model as model:
+            device = next(model.buffers()).device
+            image = torch.from_numpy(image).unsqueeze(0).to(device)
+            mask = torch.from_numpy(mask).unsqueeze(0).to(device)
+
+            with torch.inference_mode():
+                infilled_image = model(image, mask)
 
         infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy()
        infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
         infilled_image = Image.fromarray(infilled_image)
 
-        del model
-        gc.collect()
-        torch.cuda.empty_cache()
-
         return infilled_image
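The LaMA hunks also change the load pattern, not just the download: load_ckpt_from_url returns a handle that is entered as a context manager, and device placement is read off the loaded model instead of being chosen up front, which is why the manual del/gc.collect()/empty_cache() teardown could be dropped. A condensed sketch of the pattern, using only names that appear in the diff above:

    loaded_model = context.models.load_ckpt_from_url(
        source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
        loader=lambda path: load_jit_model(path, "cpu"),  # load on CPU; the cache handles placement
    )
    with loaded_model as model:                # model stays locked in the cache for this block
        device = next(model.buffers()).device  # run wherever the cache put it
        # build input tensors on `device` and run inference here
    # leaving the block releases the model back to the cache; no manual gc needed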