feat: Add Depth Anything PreProcessor (#5548)

## What type of PR is this? (check all applicable)

- [x] Feature

## Have you discussed this change with the InvokeAI team?
- [x] Yes

## Have you updated all relevant documentation?
- [x] No


## Description

- This adds the newly released Depth Anything model to InvokeAI. A new node,
`Depth Anything Processor`, has been added to generate depth maps using this
technique. https://depth-anything.github.io

- All related checkpoints will be downloaded automatically on first
boot. The `DinoV2` models will be downloaded to your torch cache directory, and
the checkpoints pertaining to Depth Anything will be downloaded to
`any/annotators/depth_anything`.

- Alternatively, you can download the checkpoints from the link below and place
them in that folder yourself:
https://huggingface.co/spaces/LiheYoung/Depth-Anything/tree/main/checkpoints

- The resulting depth map can be used with any depth ControlNet model, but the
folks at Depth Anything have also released a custom fine-tuned ControlNet
model. From my limited testing, I still prefer the original depth ControlNet,
because the fine-tuned one seems to produce odd artifacts; I'm not sure whether
that is a problem specific to Invoke or to the model itself, and I'll test more
later. Place it in your ControlNet folder like your other ControlNets. You can
get it here:
https://huggingface.co/spaces/LiheYoung/Depth-Anything/tree/main/checkpoints_controlnet

- Also available in the Linear UI.

- Depth Anything has three model sizes: `large`, `base`, and `small`. I've
defaulted the processor to `small`, but users can switch to a larger model if
they wish. Small is much faster but produces somewhat lower-quality depth maps.
(A usage sketch of the new detector follows at the end of this list.)

- Depth Anything is now the default processor for depth ControlNet models.
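
Below is a minimal sketch of how the new backend detector can be driven directly in Python, assuming InvokeAI is installed and a local RGB input image; the class, method names, and parameters are the ones introduced in the diff further down, while the file paths are hypothetical.

```python
# Minimal sketch: calling the new DepthAnythingDetector directly, outside a node graph.
# Assumes InvokeAI is installed; "input.png" and "depth.png" are hypothetical paths.
from PIL import Image

from invokeai.backend.image_util.depth_anything import DepthAnythingDetector

image = Image.open("input.png").convert("RGB")

detector = DepthAnythingDetector()
detector.load_model(model_size="small")  # "small" | "base" | "large"

# Returns a PIL depth map resized so its width matches `resolution`;
# offload=True drops the model from memory after the call.
depth_map = detector(image=image, resolution=512, offload=False)
depth_map.save("depth.png")
```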

## Screenshots


![opera_o3jHnWxVRi](https://github.com/invoke-ai/InvokeAI/assets/54517381/573c66f3-1492-45b0-b6df-25756f5e1d1a)

## Merge Plan

DO NOT MERGE YET. Please test it first; I'm also sure the model caching can be
done better, since I don't think I've really handled that at all. I would
appreciate it if @brandonrising, @lstein, or anyone else could take a look at
that part of it.
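
For reference, one possible direction for that caching, purely illustrative and not what the PR currently does, would be to keep the loaded `DPT_DINOv2` model in a small module-level cache keyed by model size so repeated invocations don't reload the checkpoint. The helper name `get_depth_anything_model` below is hypothetical.

```python
# Purely illustrative sketch of caching the loaded model between invocations.
# This is NOT what the PR does; `get_depth_anything_model` is a hypothetical helper.
from typing import Dict

from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2

_MODEL_CACHE: Dict[str, DPT_DINOv2] = {}


def get_depth_anything_model(model_size: str) -> DPT_DINOv2:
    """Load the requested Depth Anything model once and reuse it afterwards."""
    if model_size not in _MODEL_CACHE:
        detector = DepthAnythingDetector()
        # load_model returns the underlying DPT_DINOv2 module.
        _MODEL_CACHE[model_size] = detector.load_model(model_size=model_size)
    return _MODEL_CACHE[model_size]
```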
Commit 68da5c6d22 by blessedcoolant on 2024-01-24 19:14:34 +05:30, committed via GitHub
12 changed files with 1111 additions and 26 deletions

View File

@@ -30,6 +30,7 @@ from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
from ...backend.model_management import BaseModelType
from .baseinvocation import (
@@ -602,3 +603,33 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
color_map = cv2.resize(color_map, (width, height), interpolation=cv2.INTER_NEAREST)
color_map = Image.fromarray(color_map)
return color_map
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
@invocation(
"depth_anything_image_processor",
title="Depth Anything Processor",
tags=["controlnet", "depth", "depth anything"],
category="controlnet",
version="1.0.0",
)
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
"""Generates a depth map based on the Depth Anything algorithm"""
model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
default="small", description="The size of the depth model to use"
)
resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)
offload: bool = InputField(default=False)
def run_processor(self, image):
depth_anything_detector = DepthAnythingDetector()
depth_anything_detector.load_model(model_size=self.model_size)
if image.mode == "RGBA":
image = image.convert("RGB")
processed_image = depth_anything_detector(image=image, resolution=self.resolution, offload=self.offload)
return processed_image

View File

@@ -0,0 +1,109 @@
import pathlib
from typing import Literal, Union
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from einops import repeat
from PIL import Image
from torchvision.transforms import Compose
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.util import download_with_progress_bar
config = InvokeAIAppConfig.get_config()
DEPTH_ANYTHING_MODELS = {
"large": {
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
"local": "any/annotators/depth_anything/depth_anything_vitl14.pth",
},
"base": {
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
"local": "any/annotators/depth_anything/depth_anything_vitb14.pth",
},
"small": {
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
"local": "any/annotators/depth_anything/depth_anything_vits14.pth",
},
}
transform = Compose(
[
Resize(
width=518,
height=518,
resize_target=False,
keep_aspect_ratio=True,
ensure_multiple_of=14,
resize_method="lower_bound",
image_interpolation_method=cv2.INTER_CUBIC,
),
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
PrepareForNet(),
]
)
class DepthAnythingDetector:
def __init__(self) -> None:
self.model = None
self.model_size: Union[Literal["large", "base", "small"], None] = None
def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
DEPTH_ANYTHING_MODEL_PATH = pathlib.Path(config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"])
if not DEPTH_ANYTHING_MODEL_PATH.exists():
download_with_progress_bar(DEPTH_ANYTHING_MODELS[model_size]["url"], DEPTH_ANYTHING_MODEL_PATH)
if not self.model or model_size != self.model_size:
del self.model
self.model_size = model_size
match self.model_size:
case "small":
self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
case "base":
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
case "large":
self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
case _:
raise TypeError("Not a supported model")
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
self.model.eval()
self.model.to(choose_torch_device())
return self.model
def to(self, device):
self.model.to(device)
return self
def __call__(self, image, resolution=512, offload=False):
image = np.array(image, dtype=np.uint8)
image = image[:, :, ::-1] / 255.0
image_height, image_width = image.shape[:2]
image = transform({"image": image})["image"]
image = torch.from_numpy(image).unsqueeze(0).to(choose_torch_device())
with torch.no_grad():
depth = self.model(image)
depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
depth_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8)
depth_map = Image.fromarray(depth_map)
new_height = int(image_height * (resolution / image_width))
depth_map = depth_map.resize((resolution, new_height))
if offload:
del self.model
return depth_map

View File

@@ -0,0 +1,145 @@
import torch.nn as nn
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
scratch = nn.Module()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
if len(in_shape) >= 4:
out_shape4 = out_shape
if expand:
out_shape1 = out_shape
out_shape2 = out_shape * 2
out_shape3 = out_shape * 4
if len(in_shape) >= 4:
out_shape4 = out_shape * 8
scratch.layer1_rn = nn.Conv2d(
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
if len(in_shape) >= 4:
scratch.layer4_rn = nn.Conv2d(
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
return scratch
class ResidualConvUnit(nn.Module):
"""Residual convolution module."""
def __init__(self, features, activation, bn):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.bn = bn
self.groups = 1
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
if self.bn:
self.bn1 = nn.BatchNorm2d(features)
self.bn2 = nn.BatchNorm2d(features)
self.activation = activation
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.activation(x)
out = self.conv1(out)
if self.bn:
out = self.bn1(out)
out = self.activation(out)
out = self.conv2(out)
if self.bn:
out = self.bn2(out)
if self.groups > 1:
out = self.conv_merge(out)
return self.skip_add.add(out, x)
class FeatureFusionBlock(nn.Module):
"""Feature fusion block."""
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock, self).__init__()
self.deconv = deconv
self.align_corners = align_corners
self.groups = 1
self.expand = expand
out_features = features
if self.expand:
out_features = features // 2
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
self.skip_add = nn.quantized.FloatFunctional()
self.size = size
def forward(self, *xs, size=None):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
res = self.resConfUnit1(xs[1])
output = self.skip_add.add(output, res)
output = self.resConfUnit2(output)
if (size is None) and (self.size is None):
modifier = {"scale_factor": 2}
elif size is None:
modifier = {"size": self.size}
else:
modifier = {"size": size}
output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
output = self.out_conv(output)
return output

View File

@@ -0,0 +1,183 @@
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from .blocks import FeatureFusionBlock, _make_scratch
torchhub_path = Path(__file__).parent.parent / "torchhub"
def _make_fusion_block(features, use_bn, size=None):
return FeatureFusionBlock(
features,
nn.ReLU(False),
deconv=False,
bn=use_bn,
expand=False,
align_corners=True,
size=size,
)
class DPTHead(nn.Module):
def __init__(self, nclass, in_channels, features, out_channels, use_bn=False, use_clstoken=False):
super(DPTHead, self).__init__()
self.nclass = nclass
self.use_clstoken = use_clstoken
self.projects = nn.ModuleList(
[
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channel,
kernel_size=1,
stride=1,
padding=0,
)
for out_channel in out_channels
]
)
self.resize_layers = nn.ModuleList(
[
nn.ConvTranspose2d(
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
),
nn.ConvTranspose2d(
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
),
nn.Identity(),
nn.Conv2d(
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
),
]
)
if use_clstoken:
self.readout_projects = nn.ModuleList()
for _ in range(len(self.projects)):
self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))
self.scratch = _make_scratch(
out_channels,
features,
groups=1,
expand=False,
)
self.scratch.stem_transpose = None
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
head_features_1 = features
head_features_2 = 32
if nclass > 1:
self.scratch.output_conv = nn.Sequential(
nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
)
else:
self.scratch.output_conv1 = nn.Conv2d(
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
)
self.scratch.output_conv2 = nn.Sequential(
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True),
nn.Identity(),
)
def forward(self, out_features, patch_h, patch_w):
out = []
for i, x in enumerate(out_features):
if self.use_clstoken:
x, cls_token = x[0], x[1]
readout = cls_token.unsqueeze(1).expand_as(x)
x = self.readout_projects[i](torch.cat((x, readout), -1))
else:
x = x[0]
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
x = self.projects[i](x)
x = self.resize_layers[i](x)
out.append(x)
layer_1, layer_2, layer_3, layer_4 = out
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv1(path_1)
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
out = self.scratch.output_conv2(out)
return out
class DPT_DINOv2(nn.Module):
def __init__(
self,
features,
out_channels,
encoder="vitl",
use_bn=False,
use_clstoken=False,
):
super(DPT_DINOv2, self).__init__()
assert encoder in ["vits", "vitb", "vitl"]
# # in case the Internet connection is not stable, please load the DINOv2 locally
# if use_local:
# self.pretrained = torch.hub.load(
# torchhub_path / "facebookresearch_dinov2_main",
# "dinov2_{:}14".format(encoder),
# source="local",
# pretrained=False,
# )
# else:
# self.pretrained = torch.hub.load(
# "facebookresearch/dinov2",
# "dinov2_{:}14".format(encoder),
# )
self.pretrained = torch.hub.load(
"facebookresearch/dinov2",
"dinov2_{:}14".format(encoder),
)
dim = self.pretrained.blocks[0].attn.qkv.in_features
self.depth_head = DPTHead(1, dim, features, out_channels=out_channels, use_bn=use_bn, use_clstoken=use_clstoken)
def forward(self, x):
h, w = x.shape[-2:]
features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
patch_h, patch_w = h // 14, w // 14
depth = self.depth_head(features, patch_h, patch_w)
depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
depth = F.relu(depth)
return depth.squeeze(1)

View File

@@ -0,0 +1,227 @@
import math
import cv2
import numpy as np
import torch
import torch.nn.functional as F
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
Args:
sample (dict): sample
size (tuple): image size
Returns:
tuple: new size
"""
shape = list(sample["disparity"].shape)
if shape[0] >= size[0] and shape[1] >= size[1]:
return sample
scale = [0, 0]
scale[0] = size[0] / shape[0]
scale[1] = size[1] / shape[1]
scale = max(scale)
shape[0] = math.ceil(scale * shape[0])
shape[1] = math.ceil(scale * shape[1])
# resize
sample["image"] = cv2.resize(sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method)
sample["disparity"] = cv2.resize(sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST)
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
tuple(shape[::-1]),
interpolation=cv2.INTER_NEAREST,
)
sample["mask"] = sample["mask"].astype(bool)
return tuple(shape)
class Resize(object):
"""Resize sample to given size (width, height)."""
def __init__(
self,
width,
height,
resize_target=True,
keep_aspect_ratio=False,
ensure_multiple_of=1,
resize_method="lower_bound",
image_interpolation_method=cv2.INTER_AREA,
):
"""Init.
Args:
width (int): desired output width
height (int): desired output height
resize_target (bool, optional):
True: Resize the full sample (image, mask, target).
False: Resize image only.
Defaults to True.
keep_aspect_ratio (bool, optional):
True: Keep the aspect ratio of the input sample.
Output sample might not have the given width and height, and
resize behaviour depends on the parameter 'resize_method'.
Defaults to False.
ensure_multiple_of (int, optional):
Output width and height is constrained to be multiple of this parameter.
Defaults to 1.
resize_method (str, optional):
"lower_bound": Output will be at least as large as the given size.
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller
than given size.)
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
Defaults to "lower_bound".
"""
self.__width = width
self.__height = height
self.__resize_target = resize_target
self.__keep_aspect_ratio = keep_aspect_ratio
self.__multiple_of = ensure_multiple_of
self.__resize_method = resize_method
self.__image_interpolation_method = image_interpolation_method
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
if max_val is not None and y > max_val:
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
if y < min_val:
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
return y
def get_size(self, width, height):
# determine new height and width
scale_height = self.__height / height
scale_width = self.__width / width
if self.__keep_aspect_ratio:
if self.__resize_method == "lower_bound":
# scale such that output size is lower bound
if scale_width > scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "upper_bound":
# scale such that output size is upper bound
if scale_width < scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "minimal":
# scale as little as possible
if abs(1 - scale_width) < abs(1 - scale_height):
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
if self.__resize_method == "lower_bound":
new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
elif self.__resize_method == "upper_bound":
new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
elif self.__resize_method == "minimal":
new_height = self.constrain_to_multiple_of(scale_height * height)
new_width = self.constrain_to_multiple_of(scale_width * width)
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
return (new_width, new_height)
def __call__(self, sample):
width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
# resize sample
sample["image"] = cv2.resize(
sample["image"],
(width, height),
interpolation=self.__image_interpolation_method,
)
if self.__resize_target:
if "disparity" in sample:
sample["disparity"] = cv2.resize(
sample["disparity"],
(width, height),
interpolation=cv2.INTER_NEAREST,
)
if "depth" in sample:
sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
if "semseg_mask" in sample:
# sample["semseg_mask"] = cv2.resize(
# sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
# )
sample["semseg_mask"] = F.interpolate(
torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode="nearest"
).numpy()[0, 0]
if "mask" in sample:
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
(width, height),
interpolation=cv2.INTER_NEAREST,
)
# sample["mask"] = sample["mask"].astype(bool)
# print(sample['image'].shape, sample['depth'].shape)
return sample
class NormalizeImage(object):
"""Normlize image by given mean and std."""
def __init__(self, mean, std):
self.__mean = mean
self.__std = std
def __call__(self, sample):
sample["image"] = (sample["image"] - self.__mean) / self.__std
return sample
class PrepareForNet(object):
"""Prepare sample for usage as network input."""
def __init__(self):
pass
def __call__(self, sample):
image = np.transpose(sample["image"], (2, 0, 1))
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
if "mask" in sample:
sample["mask"] = sample["mask"].astype(np.float32)
sample["mask"] = np.ascontiguousarray(sample["mask"])
if "depth" in sample:
depth = sample["depth"].astype(np.float32)
sample["depth"] = np.ascontiguousarray(depth)
if "semseg_mask" in sample:
sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])
return sample

View File

@@ -224,6 +224,7 @@
"amult": "a_mult",
"autoConfigure": "Auto configure processor",
"balanced": "Balanced",
"base": "Base",
"beginEndStepPercent": "Begin / End Step Percentage",
"bgth": "bg_th",
"canny": "Canny",
@@ -237,6 +238,8 @@
"controlMode": "Control Mode",
"crop": "Crop",
"delete": "Delete",
"depthAnything": "Depth Anything",
"depthAnythingDescription": "Depth map generation using the Depth Anything technique",
"depthMidas": "Depth (Midas)",
"depthMidasDescription": "Depth map generation using Midas",
"depthZoe": "Depth (Zoe)",
@@ -256,6 +259,7 @@
"colorMapTileSize": "Tile Size",
"importImageFromCanvas": "Import Image From Canvas",
"importMaskFromCanvas": "Import Mask From Canvas",
"large": "Large",
"lineart": "Lineart",
"lineartAnime": "Lineart Anime",
"lineartAnimeDescription": "Anime-style lineart processing",
@@ -268,6 +272,7 @@
"minConfidence": "Min Confidence",
"mlsd": "M-LSD",
"mlsdDescription": "Minimalist Line Segment Detector",
"modelSize": "Model Size",
"none": "None",
"noneDescription": "No processing applied",
"normalBae": "Normal BAE",
@@ -288,6 +293,7 @@
"selectModel": "Select a model",
"setControlImageDimensions": "Set Control Image Dimensions To W/H",
"showAdvanced": "Show Advanced",
"small": "Small",
"toggleControlNet": "Toggle this ControlNet",
"w": "W",
"weight": "Weight",

View File

@@ -5,6 +5,7 @@ import { memo } from 'react';
import CannyProcessor from './processors/CannyProcessor';
import ColorMapProcessor from './processors/ColorMapProcessor';
import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
import DepthAnyThingProcessor from './processors/DepthAnyThingProcessor';
import HedProcessor from './processors/HedProcessor';
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
import LineartProcessor from './processors/LineartProcessor';
@@ -48,6 +49,16 @@ const ControlAdapterProcessorComponent = ({ id }: Props) => {
);
}
if (processorNode.type === 'depth_anything_image_processor') {
return (
<DepthAnyThingProcessor
controlNetId={id}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
if (processorNode.type === 'hed_image_processor') {
return (
<HedProcessor

View File

@@ -0,0 +1,110 @@
import type { ComboboxOnChange } from '@invoke-ai/ui';
import {
Combobox,
CompositeNumberInput,
CompositeSlider,
FormControl,
FormLabel,
} from '@invoke-ai/ui';
import { useProcessorNodeChanged } from 'features/controlAdapters/components/hooks/useProcessorNodeChanged';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import type {
DepthAnythingModelSize,
RequiredDepthAnythingImageProcessorInvocation,
} from 'features/controlAdapters/store/types';
import { isDepthAnythingModelSize } from 'features/controlAdapters/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import ProcessorWrapper from './common/ProcessorWrapper';
const DEFAULTS = CONTROLNET_PROCESSORS.depth_anything_image_processor
.default as RequiredDepthAnythingImageProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredDepthAnythingImageProcessorInvocation;
isEnabled: boolean;
};
const DepthAnythingProcessor = (props: Props) => {
const { controlNetId, processorNode, isEnabled } = props;
const { model_size, resolution } = processorNode;
const processorChanged = useProcessorNodeChanged();
const { t } = useTranslation();
const handleModelSizeChange = useCallback<ComboboxOnChange>(
(v) => {
if (!isDepthAnythingModelSize(v?.value)) {
return;
}
processorChanged(controlNetId, {
model_size: v.value,
});
},
[controlNetId, processorChanged]
);
const options: { label: string; value: DepthAnythingModelSize }[] = useMemo(
() => [
{ label: t('controlnet.small'), value: 'small' },
{ label: t('controlnet.base'), value: 'base' },
{ label: t('controlnet.large'), value: 'large' },
],
[t]
);
const value = useMemo(
() => options.filter((o) => o.value === model_size)[0],
[options, model_size]
);
const handleResolutionChange = useCallback(
(v: number) => {
processorChanged(controlNetId, { resolution: v });
},
[controlNetId, processorChanged]
);
const handleResolutionDefaultChange = useCallback(() => {
processorChanged(controlNetId, { resolution: 512 });
}, [controlNetId, processorChanged]);
return (
<ProcessorWrapper>
<FormControl isDisabled={!isEnabled}>
<FormLabel>{t('controlnet.modelSize')}</FormLabel>
<Combobox
value={value}
defaultInputValue={DEFAULTS.model_size}
options={options}
onChange={handleModelSizeChange}
/>
</FormControl>
<FormControl isDisabled={!isEnabled}>
<FormLabel>{t('controlnet.imageResolution')}</FormLabel>
<CompositeSlider
value={resolution}
onChange={handleResolutionChange}
defaultValue={DEFAULTS.resolution}
min={64}
max={4096}
step={64}
marks
onReset={handleResolutionDefaultChange}
/>
<CompositeNumberInput
value={resolution}
onChange={handleResolutionChange}
defaultValue={DEFAULTS.resolution}
min={64}
max={4096}
step={64}
/>
</FormControl>
</ProcessorWrapper>
);
};
export default memo(DepthAnythingProcessor);

View File

@@ -83,6 +83,22 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
f: 256,
},
},
depth_anything_image_processor: {
type: 'depth_anything_image_processor',
get label() {
return i18n.t('controlnet.depthAnything');
},
get description() {
return i18n.t('controlnet.depthAnythingDescription');
},
default: {
id: 'depth_anything_image_processor',
type: 'depth_anything_image_processor',
model_size: 'small',
resolution: 512,
offload: false,
},
},
hed_image_processor: {
type: 'hed_image_processor',
get label() {
@@ -245,7 +261,7 @@ export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: {
} = {
canny: 'canny_image_processor',
mlsd: 'mlsd_image_processor',
depth: 'midas_depth_image_processor',
depth: 'depth_anything_image_processor',
bae: 'normalbae_image_processor',
sketch: 'pidi_image_processor',
scribble: 'lineart_image_processor',

View File

@@ -10,6 +10,7 @@ import type {
CannyImageProcessorInvocation,
ColorMapImageProcessorInvocation,
ContentShuffleImageProcessorInvocation,
DepthAnythingImageProcessorInvocation,
HedImageProcessorInvocation,
LineartAnimeImageProcessorInvocation,
LineartImageProcessorInvocation,
@@ -31,6 +32,7 @@ export type ControlAdapterProcessorNode =
| CannyImageProcessorInvocation
| ColorMapImageProcessorInvocation
| ContentShuffleImageProcessorInvocation
| DepthAnythingImageProcessorInvocation
| HedImageProcessorInvocation
| LineartAnimeImageProcessorInvocation
| LineartImageProcessorInvocation
@@ -73,6 +75,20 @@ export type RequiredContentShuffleImageProcessorInvocation = O.Required<
'type' | 'detect_resolution' | 'image_resolution' | 'w' | 'h' | 'f'
>;
/**
* The DepthAnything processor node, with parameters flagged as required
*/
export type RequiredDepthAnythingImageProcessorInvocation = O.Required<
DepthAnythingImageProcessorInvocation,
'type' | 'model_size' | 'resolution' | 'offload'
>;
export const zDepthAnythingModelSize = z.enum(['large', 'base', 'small']);
export type DepthAnythingModelSize = z.infer<typeof zDepthAnythingModelSize>;
export const isDepthAnythingModelSize = (
v: unknown
): v is DepthAnythingModelSize => zDepthAnythingModelSize.safeParse(v).success;
/**
* The HED processor node, with parameters flagged as required
*/
@@ -161,6 +177,7 @@ export type RequiredControlAdapterProcessorNode =
| RequiredCannyImageProcessorInvocation
| RequiredColorMapImageProcessorInvocation
| RequiredContentShuffleImageProcessorInvocation
| RequiredDepthAnythingImageProcessorInvocation
| RequiredHedImageProcessorInvocation
| RequiredLineartAnimeImageProcessorInvocation
| RequiredLineartImageProcessorInvocation
@@ -219,6 +236,22 @@ export const isContentShuffleImageProcessorInvocation = (
return false;
};
/**
* Type guard for DepthAnythingImageProcessorInvocation
*/
export const isDepthAnythingImageProcessorInvocation = (
obj: unknown
): obj is DepthAnythingImageProcessorInvocation => {
if (
isObject(obj) &&
'type' in obj &&
obj.type === 'depth_anything_image_processor'
) {
return true;
}
return false;
};
/**
* Type guard for HedImageprocessorInvocation
*/

File diff suppressed because one or more lines are too long

View File

@@ -164,6 +164,8 @@ export type ColorMapImageProcessorInvocation =
s['ColorMapImageProcessorInvocation'];
export type ContentShuffleImageProcessorInvocation =
s['ContentShuffleImageProcessorInvocation'];
export type DepthAnythingImageProcessorInvocation =
s['DepthAnythingImageProcessorInvocation'];
export type HedImageProcessorInvocation = s['HedImageProcessorInvocation'];
export type LineartAnimeImageProcessorInvocation =
s['LineartAnimeImageProcessorInvocation'];