port dw_openpose, depth_anything, and lama processors to new model download scheme

Lincoln Stein 2024-04-12 21:05:23 -04:00 committed by Lincoln Stein
parent c140d3b1df
commit 3ead827d61
7 changed files with 72 additions and 105 deletions

View File

@@ -137,7 +137,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
image: ImageField = InputField(description="The image to process")
def run_processor(self, image: Image.Image) -> Image.Image:
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
# superclass just passes through image without processing
return image
@@ -148,7 +148,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
raw_image = self.load_image(context)
# image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image)
processed_image = self.run_processor(raw_image, context)
# currently can't see processed image in node UI without a showImage node,
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
@@ -189,7 +189,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
# Keep alpha channel for Canny processing to detect edges of transparent areas
return context.images.get_pil(self.image.image_name, "RGBA")
def run_processor(self, image: Image.Image) -> Image.Image:
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
processed_image = get_canny_edges(
image,
self.low_threshold,
@@ -216,7 +216,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
# safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def run_processor(self, image: Image.Image) -> Image.Image:
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
hed_processor = HEDProcessor()
processed_image = hed_processor.run(
image,
@@ -243,7 +243,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
def run_processor(self, image: Image.Image) -> Image.Image:
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
lineart_processor = LineartProcessor()
processed_image = lineart_processor.run(
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
@@ -264,7 +264,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
processor = LineartAnimeProcessor()
processed_image = processor.run(
image,
@@ -291,7 +291,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
# depth_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
def run_processor(self, image):
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
# TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
processed_image = midas_processor(
image,
@@ -318,9 +319,9 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image):
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = normalbae_processor(
processed_image: Image.Image = normalbae_processor(
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
)
return processed_image
@@ -337,7 +338,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
def run_processor(self, image):
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
processed_image = mlsd_processor(
image,
@@ -360,7 +361,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def run_processor(self, image):
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
processed_image = pidi_processor(
image,
@@ -388,7 +389,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
def run_processor(self, image):
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
content_shuffle_processor = ContentShuffleDetector()
processed_image = content_shuffle_processor(
image,
@@ -412,7 +413,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image"""
def run_processor(self, image):
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = zoe_depth_processor(image)
return processed_image
@@ -433,7 +434,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image):
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
mediapipe_face_processor = MediapipeFaceDetector()
processed_image = mediapipe_face_processor(
image,
@@ -461,7 +462,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image):
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
processed_image = leres_processor(
image,
@@ -503,7 +504,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
return np_img
def run_processor(self, img):
def run_processor(self, img: Image.Image, context: InvocationContext) -> Image.Image:
np_img = np.array(img, dtype=np.uint8)
processed_np_image = self.tile_resample(
np_img,
@@ -527,7 +528,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image):
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
"ybelkada/segment-anything", subfolder="checkpoints"
@@ -573,7 +574,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
color_map_tile_size: int = InputField(default=64, ge=0, description=FieldDescriptions.tile_size)
def run_processor(self, image: Image.Image):
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
np_image = np.array(image, dtype=np.uint8)
height, width = np_image.shape[:2]
@@ -608,8 +609,8 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
)
resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image):
depth_anything_detector = DepthAnythingDetector()
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
depth_anything_detector = DepthAnythingDetector(context)
depth_anything_detector.load_model(model_size=self.model_size)
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
@@ -631,8 +632,8 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
draw_hands: bool = InputField(default=False)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image):
dw_openpose = DWOpenposeDetector()
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
dw_openpose = DWOpenposeDetector(context)
processed_image = dw_openpose(
image,
draw_face=self.draw_face,

View File

@@ -38,7 +38,7 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
image: ImageField = InputField(description="The image to process")
@abstractmethod
def infill(self, image: Image.Image) -> Image.Image:
def infill(self, image: Image.Image, context: InvocationContext) -> Image.Image:
"""Infill the image with the specified method"""
pass
@@ -57,7 +57,7 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
return ImageOutput.build(context.images.get_dto(self.image.image_name))
# Perform Infill action
infilled_image = self.infill(input_image)
infilled_image = self.infill(input_image, context)
# Create ImageDTO for Infilled Image
infilled_image_dto = context.images.save(image=infilled_image)
@@ -75,7 +75,7 @@ class InfillColorInvocation(InfillImageProcessorInvocation):
description="The color to use to infill",
)
def infill(self, image: Image.Image):
def infill(self, image: Image.Image, context: InvocationContext):
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
infilled.paste(image, (0, 0), image.split()[-1])
@@ -94,7 +94,7 @@ class InfillTileInvocation(InfillImageProcessorInvocation):
description="The seed to use for tile generation (omit for random)",
)
def infill(self, image: Image.Image):
def infill(self, image: Image.Image, context: InvocationContext):
output = infill_tile(image, seed=self.seed, tile_size=self.tile_size)
return output.infilled
@@ -108,7 +108,7 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
def infill(self, image: Image.Image):
def infill(self, image: Image.Image, context: InvocationContext):
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
width = int(image.width / self.downscale)
@@ -132,8 +132,8 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
class LaMaInfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using the LaMa model"""
def infill(self, image: Image.Image):
lama = LaMA()
def infill(self, image: Image.Image, context: InvocationContext):
lama = LaMA(context)
return lama(image)
@@ -141,7 +141,7 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
class CV2InfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using OpenCV Inpainting"""
def infill(self, image: Image.Image):
def infill(self, image: Image.Image, context: InvocationContext):
return cv2_inpaint(image)
@@ -163,5 +163,5 @@ class MosaicInfillInvocation(InfillImageProcessorInvocation):
description="The max threshold for color",
)
def infill(self, image: Image.Image):
def infill(self, image: Image.Image, context: InvocationContext):
return infill_mosaic(image, (self.tile_width, self.tile_height), self.min_color.tuple(), self.max_color.tuple())

View File

@@ -534,10 +534,10 @@ class ModelsInterface(InvocationContextInterface):
loader: Optional[Callable[[Path], Dict[str | int, Any]]] = None,
) -> LoadedModel:
"""
Load and cache the model file located at the indicated URL.
Download, cache, and load the model file located at the indicated URL.
This will check the model download cache for the model designated
by the provided URL and download it if needed using download_and_cache_model().
by the provided URL and download it if needed using download_and_cache_ckpt().
It will then load the model into the RAM cache. If the optional loader
argument is provided, the loader will be invoked to load the model into
memory. Otherwise the method will call safetensors.torch.load_file() or

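For orientation, a minimal sketch of how an invocation might call this method, modeled on the LaMA port later in this commit (the URL, loader, and input are illustrative placeholders, not part of the change):

    import torch

    def example(context):
        loaded_model = context.models.load_ckpt_from_url(
            source="https://example.com/checkpoint.pt",  # hypothetical checkpoint URL
            loader=lambda path: torch.jit.load(path, map_location="cpu"),  # optional custom loader
        )
        with loaded_model as model:  # the context manager yields the loaded torch module
            return model(torch.zeros(1, 3, 64, 64))  # placeholder inference call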
View File

@@ -1,5 +1,4 @@
import pathlib
from typing import Literal, Union
from typing import Literal, Optional, Union
import cv2
import numpy as np
@@ -10,7 +9,7 @@ from PIL import Image
from torchvision.transforms import Compose
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
from invokeai.backend.util.devices import choose_torch_device
@@ -20,18 +19,9 @@ config = get_config()
logger = InvokeAILogger.get_logger(config=config)
DEPTH_ANYTHING_MODELS = {
"large": {
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
"local": "any/annotators/depth_anything/depth_anything_vitl14.pth",
},
"base": {
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
"local": "any/annotators/depth_anything/depth_anything_vitb14.pth",
},
"small": {
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
"local": "any/annotators/depth_anything/depth_anything_vits14.pth",
},
"large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
"base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
"small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
}
@@ -53,18 +43,14 @@ transform = Compose(
class DepthAnythingDetector:
def __init__(self) -> None:
self.model = None
def __init__(self, context: InvocationContext) -> None:
self.context = context
self.model: Optional[DPT_DINOv2] = None
self.model_size: Union[Literal["large", "base", "small"], None] = None
self.device = choose_torch_device()
def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]
download_with_progress_bar(
pathlib.Path(DEPTH_ANYTHING_MODELS[model_size]["url"]).name,
DEPTH_ANYTHING_MODELS[model_size]["url"],
DEPTH_ANYTHING_MODEL_PATH,
)
def load_model(self, model_size: Literal["large", "base", "small"] = "small") -> DPT_DINOv2:
depth_anything_model_path = self.context.models.download_and_cache_ckpt(DEPTH_ANYTHING_MODELS[model_size])
if not self.model or model_size != self.model_size:
del self.model
@@ -78,7 +64,8 @@ class DepthAnythingDetector:
case "large":
self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
assert self.model is not None
self.model.load_state_dict(torch.load(depth_anything_model_path.as_posix(), map_location="cpu"))
self.model.eval()
self.model.to(choose_torch_device())

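For reference, a minimal sketch of the path-returning call used above (assuming an InvocationContext named context; the URL is the small-model checkpoint from DEPTH_ANYTHING_MODELS):

    import torch

    def load_small_state_dict(context):
        # download_and_cache_ckpt() returns a local Path to the cached file,
        # downloading it first if it is not already in the model download cache.
        model_path = context.models.download_and_cache_ckpt(
            "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true"
        )
        return torch.load(model_path.as_posix(), map_location="cpu")  # raw state dict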
View File

@@ -3,6 +3,7 @@ import torch
from controlnet_aux.util import resize_image
from PIL import Image
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.dw_openpose.utils import draw_bodypose, draw_facepose, draw_handpose
from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody
@@ -39,8 +40,8 @@ class DWOpenposeDetector:
Credits: https://github.com/IDEA-Research/DWPose
"""
def __init__(self) -> None:
self.pose_estimation = Wholebody()
def __init__(self, context: InvocationContext) -> None:
self.pose_estimation = Wholebody(context)
def __call__(
self, image: Image.Image, draw_face=False, draw_body=True, draw_hands=False, resolution=512

View File

@@ -4,44 +4,31 @@
import numpy as np
import onnxruntime as ort
import torch
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.util.devices import choose_torch_device
from .onnxdet import inference_detector
from .onnxpose import inference_pose
DWPOSE_MODELS = {
"yolox_l.onnx": {
"local": "any/annotators/dwpose/yolox_l.onnx",
"url": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
},
"dw-ll_ucoco_384.onnx": {
"local": "any/annotators/dwpose/dw-ll_ucoco_384.onnx",
"url": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
},
"yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
"dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
}
config = get_config()
class Wholebody:
def __init__(self):
def __init__(self, context: InvocationContext):
device = choose_torch_device()
providers = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
providers = ["CUDAExecutionProvider"] if device == torch.device("cuda") else ["CPUExecutionProvider"]
DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)
POSE_MODEL_PATH = config.models_path / DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["local"]
download_with_progress_bar(
"dw-ll_ucoco_384.onnx", DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["url"], POSE_MODEL_PATH
)
onnx_det = DET_MODEL_PATH
onnx_pose = POSE_MODEL_PATH
onnx_det = context.models.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"])
onnx_pose = context.models.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)

View File

@@ -1,4 +1,3 @@
import gc
from typing import Any
import numpy as np
@@ -6,9 +5,7 @@ import torch
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import choose_torch_device
from invokeai.app.services.shared.invocation_context import InvocationContext
def norm_img(np_img):
@@ -28,18 +25,14 @@ def load_jit_model(url_or_path, device):
class LaMA:
def __init__(self, context: InvocationContext):
self._context = context
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
device = choose_torch_device()
model_location = get_config().models_path / "core/misc/lama/lama.pt"
if not model_location.exists():
download_with_progress_bar(
name="LaMa Inpainting Model",
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
dest_path=model_location,
)
model = load_jit_model(model_location, device)
loaded_model = self._context.models.load_ckpt_from_url(
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
loader=lambda path: load_jit_model(path, "cpu"),
)
image = np.asarray(input_image.convert("RGB"))
image = norm_img(image)
@@ -48,20 +41,18 @@ class LaMA:
mask = np.asarray(mask)
mask = np.invert(mask)
mask = norm_img(mask)
mask = (mask > 0) * 1
image = torch.from_numpy(image).unsqueeze(0).to(device)
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
with torch.inference_mode():
infilled_image = model(image, mask)
with loaded_model as model:
device = next(model.buffers()).device
image = torch.from_numpy(image).unsqueeze(0).to(device)
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
with torch.inference_mode():
infilled_image = model(image, mask)
infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy()
infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
infilled_image = Image.fromarray(infilled_image)
del model
gc.collect()
torch.cuda.empty_cache()
return infilled_image
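
Taken together, the ported helpers now receive the InvocationContext at construction time instead of reaching for global config and download utilities. A hypothetical caller, with variable names illustrative only:

    # assuming an InvocationContext `context` and a PIL image `image`
    lama = LaMA(context)                    # was LaMA(); big-lama.pt is fetched on first call
    infilled = lama(image)

    detector = DWOpenposeDetector(context)  # Wholebody(context) fetches the two ONNX models
    pose_image = detector(image, draw_face=False, draw_body=True, draw_hands=False)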