mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

feat: Add Depth Anything PreProcessor

commit 8f5e2cbcc7 (parent 2aed6e2dba)
@@ -30,6 +30,7 @@ from invokeai.app.invocations.primitives import ImageField, ImageOutput
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
 from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
 from invokeai.app.shared.fields import FieldDescriptions
+from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
 
 from ...backend.model_management import BaseModelType
 from .baseinvocation import (
@@ -602,3 +603,32 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
         color_map = cv2.resize(color_map, (width, height), interpolation=cv2.INTER_NEAREST)
         color_map = Image.fromarray(color_map)
         return color_map
+
+
+DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
+
+
+@invocation(
+    "depth_anything_image_processor",
+    title="Depth Anything Processor",
+    tags=["controlnet", "depth", "depth anything"],
+    category="controlnet",
+    version="1.0.0",
+)
+class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
+    """Generates a depth map based on the Depth Anything algorithm"""
+
+    model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
+        default="large", description="The size of the depth model to use"
+    )
+    offload: bool = InputField(default=False)
+
+    def run_processor(self, image):
+        depth_anything_detector = DepthAnythingDetector()
+        depth_anything_detector.load_model(model_size=self.model_size)
+
+        if image.mode == "RGBA":
+            image = image.convert("RGB")
+
+        processed_image = depth_anything_detector(image=image, offload=self.offload)
+        return processed_image
invokeai/backend/image_util/depth_anything/__init__.py (new file, 107 lines)
@@ -0,0 +1,107 @@
import pathlib
from typing import Literal, Union

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from einops import repeat
from PIL import Image
from torchvision.transforms import Compose

from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.util import download_with_progress_bar

config = InvokeAIAppConfig.get_config()

DEPTH_ANYTHING_MODELS = {
    "large": {
        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
        "local": "sd-1/controlnet/annotator/depth_anything/depth_anything_vitl14.pth",
    },
    "base": {
        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
        "local": "sd-1/controlnet/annotator/depth_anything/depth_anything_vitb14.pth",
    },
    "small": {
        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
        "local": "sd-1/controlnet/annotator/depth_anything/depth_anything_vits14.pth",
    },
}


transform = Compose(
    [
        Resize(
            width=518,
            height=518,
            resize_target=False,
            keep_aspect_ratio=True,
            ensure_multiple_of=14,
            resize_method="lower_bound",
            image_interpolation_method=cv2.INTER_CUBIC,
        ),
        NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        PrepareForNet(),
    ]
)


class DepthAnythingDetector:
    def __init__(self) -> None:
        self.model = None
        self.model_size: Union[Literal["large", "base", "small"], None] = None

    def load_model(self, model_size: Literal["large", "base", "small"]):
        DEPTH_ANYTHING_MODEL_PATH = pathlib.Path(config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"])
        if not DEPTH_ANYTHING_MODEL_PATH.exists():
            download_with_progress_bar(DEPTH_ANYTHING_MODELS[model_size]["url"], DEPTH_ANYTHING_MODEL_PATH)

        if not self.model or model_size != self.model_size:
            del self.model
            self.model_size = model_size

            if self.model_size == "small":
                self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384], localhub=True)
            if self.model_size == "base":
                self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768], localhub=True)
            if self.model_size == "large":
                self.model = DPT_DINOv2(
                    encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024], localhub=True
                )

            self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
            self.model.eval()

        self.model.to(choose_torch_device())
        return self.model

    def to(self, device):
        self.model.to(device)
        return self

    def __call__(self, image, offload=False):
        image = np.array(image, dtype=np.uint8)
        original_height, original_width = image.shape[:2]  # numpy image arrays are (H, W, C)
        image = image[:, :, ::-1] / 255.0

        image_height, image_width = image.shape[:2]
        image = transform({"image": image})["image"]
        image = torch.from_numpy(image).unsqueeze(0).to(choose_torch_device())

        with torch.no_grad():
            depth = self.model(image)
            depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
            depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0

        depth_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8)
        depth_map = Image.fromarray(depth_map)
        depth_map = depth_map.resize((original_width, original_height))  # PIL resize takes (W, H)

        if offload:
            del self.model

        return depth_map
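For orientation, a minimal usage sketch of the detector outside the node graph. It assumes a configured InvokeAI install (so `config.models_path` resolves and the checkpoint can be downloaded on first use); `example.png` is a hypothetical input path.

    from PIL import Image

    from invokeai.backend.image_util.depth_anything import DepthAnythingDetector

    detector = DepthAnythingDetector()
    detector.load_model(model_size="small")  # "small" | "base" | "large"

    image = Image.open("example.png").convert("RGB")  # hypothetical input file
    depth_map = detector(image, offload=True)  # PIL image, same size as the input
    depth_map.save("depth.png")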
invokeai/backend/image_util/depth_anything/model/blocks.py (new file, 145 lines)
@@ -0,0 +1,145 @@
import torch.nn as nn


def _make_scratch(in_shape, out_shape, groups=1, expand=False):
    scratch = nn.Module()

    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    if len(in_shape) >= 4:
        out_shape4 = out_shape

    if expand:
        out_shape1 = out_shape
        out_shape2 = out_shape * 2
        out_shape3 = out_shape * 4
        if len(in_shape) >= 4:
            out_shape4 = out_shape * 8

    scratch.layer1_rn = nn.Conv2d(
        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer2_rn = nn.Conv2d(
        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer3_rn = nn.Conv2d(
        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    if len(in_shape) >= 4:
        scratch.layer4_rn = nn.Conv2d(
            in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
        )

    return scratch


class ResidualConvUnit(nn.Module):
    """Residual convolution module."""

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn

        self.groups = 1

        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)

        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)

        if self.bn:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """

        out = self.activation(x)
        out = self.conv1(out)
        if self.bn:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn:
            out = self.bn2(out)

        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)


class FeatureFusionBlock(nn.Module):
    """Feature fusion block."""

    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups = 1

        self.expand = expand
        out_features = features
        if self.expand:
            out_features = features // 2

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

        self.size = size

    def forward(self, *xs, size=None):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)

        output = self.resConfUnit2(output)

        if (size is None) and (self.size is None):
            modifier = {"scale_factor": 2}
        elif size is None:
            modifier = {"size": self.size}
        else:
            modifier = {"size": size}

        output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)

        output = self.out_conv(output)

        return output
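A small, self-contained shape check for the fusion path (an illustrative sketch only; the feature tensors are random):

    import torch
    import torch.nn as nn

    from invokeai.backend.image_util.depth_anything.model.blocks import FeatureFusionBlock

    block = FeatureFusionBlock(features=64, activation=nn.ReLU(False), bn=False)

    coarse = torch.randn(1, 64, 16, 16)  # output of the previous (coarser) refinement stage
    skip = torch.randn(1, 64, 16, 16)    # same-resolution skip feature

    out = block(coarse, skip, size=(32, 32))  # residual-fuse, then upsample to 32x32
    print(out.shape)  # torch.Size([1, 64, 32, 32])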
invokeai/backend/image_util/depth_anything/model/dpt.py (new file, 186 lines)
@@ -0,0 +1,186 @@
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F

from .blocks import FeatureFusionBlock, _make_scratch

torchhub_path = Path(__file__).parent.parent / "torchhub"


def _make_fusion_block(features, use_bn, size=None):
    return FeatureFusionBlock(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
        size=size,
    )


class DPTHead(nn.Module):
    def __init__(
        self, nclass, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False
    ):
        super(DPTHead, self).__init__()

        self.nclass = nclass
        self.use_clstoken = use_clstoken

        self.projects = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channel,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
                for out_channel in out_channels
            ]
        )

        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
                ),
            ]
        )

        if use_clstoken:
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))

        self.scratch = _make_scratch(
            out_channels,
            features,
            groups=1,
            expand=False,
        )

        self.scratch.stem_transpose = None

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        head_features_1 = features
        head_features_2 = 32

        if nclass > 1:
            self.scratch.output_conv = nn.Sequential(
                nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
                nn.ReLU(True),
                nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
            )
        else:
            self.scratch.output_conv1 = nn.Conv2d(
                head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
            )

            self.scratch.output_conv2 = nn.Sequential(
                nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
                nn.ReLU(True),
                nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
                nn.ReLU(True),
                nn.Identity(),
            )

    def forward(self, out_features, patch_h, patch_w):
        out = []
        for i, x in enumerate(out_features):
            if self.use_clstoken:
                x, cls_token = x[0], x[1]
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))
            else:
                x = x[0]

            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            x = self.projects[i](x)
            x = self.resize_layers[i](x)

            out.append(x)

        layer_1, layer_2, layer_3, layer_4 = out

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv1(path_1)
        out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
        out = self.scratch.output_conv2(out)

        return out


class DPT_DINOv2(nn.Module):
    def __init__(
        self,
        encoder="vitl",
        features=256,
        out_channels=[256, 512, 1024, 1024],
        use_bn=False,
        use_clstoken=False,
        localhub=True,
    ):
        super(DPT_DINOv2, self).__init__()

        assert encoder in ["vits", "vitb", "vitl"]

        # In case the Internet connection is not stable, please load the DINOv2 locally:
        # if localhub:
        #     self.pretrained = torch.hub.load(
        #         torchhub_path / "facebookresearch_dinov2_main",
        #         "dinov2_{:}14".format(encoder),
        #         source="local",
        #         pretrained=False,
        #     )
        # else:
        #     self.pretrained = torch.hub.load(
        #         "facebookresearch/dinov2",
        #         "dinov2_{:}14".format(encoder),
        #     )

        self.pretrained = torch.hub.load(
            "facebookresearch/dinov2",
            "dinov2_{:}14".format(encoder),
        )

        dim = self.pretrained.blocks[0].attn.qkv.in_features

        self.depth_head = DPTHead(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)

    def forward(self, x):
        h, w = x.shape[-2:]

        features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)

        patch_h, patch_w = h // 14, w // 14

        depth = self.depth_head(features, patch_h, patch_w)
        depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
        depth = F.relu(depth)

        return depth.squeeze(1)
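The head can be sanity-checked without downloading DINOv2 by feeding it dummy (patch tokens, class token) pairs shaped like ViT-S/14 output; this is a sketch under those assumed dimensions, not how the detector calls it:

    import torch

    from invokeai.backend.image_util.depth_anything.model.dpt import DPTHead

    dim = 384                      # ViT-S/14 embedding width
    patch_h = patch_w = 518 // 14  # 37x37 patch grid for a 518x518 input

    head = DPTHead(1, dim, features=64, out_channels=[48, 96, 192, 384])

    # Four pseudo intermediate layers, each a (patch tokens, class token) pair.
    feats = [(torch.randn(1, patch_h * patch_w, dim), torch.randn(1, dim)) for _ in range(4)]

    with torch.no_grad():
        depth = head(feats, patch_h, patch_w)
    print(depth.shape)  # torch.Size([1, 1, 518, 518]) -- the patch grid upsampled back by 14x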
invokeai/backend/image_util/depth_anything/utilities/util.py (new file, 227 lines)
@@ -0,0 +1,227 @@
import math

import cv2
import numpy as np
import torch
import torch.nn.functional as F


def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
    """Resize the sample to ensure the given size. Keeps aspect ratio.

    Args:
        sample (dict): sample
        size (tuple): image size

    Returns:
        tuple: new size
    """
    shape = list(sample["disparity"].shape)

    if shape[0] >= size[0] and shape[1] >= size[1]:
        return sample

    scale = [0, 0]
    scale[0] = size[0] / shape[0]
    scale[1] = size[1] / shape[1]

    scale = max(scale)

    shape[0] = math.ceil(scale * shape[0])
    shape[1] = math.ceil(scale * shape[1])

    # resize
    sample["image"] = cv2.resize(sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method)

    sample["disparity"] = cv2.resize(sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST)
    sample["mask"] = cv2.resize(
        sample["mask"].astype(np.float32),
        tuple(shape[::-1]),
        interpolation=cv2.INTER_NEAREST,
    )
    sample["mask"] = sample["mask"].astype(bool)

    return tuple(shape)


class Resize(object):
    """Resize sample to given size (width, height)."""

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height are constrained to be multiples of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at most as large as the given size. (Output size might be smaller
                    than given size.)
                "minimal": Scale as little as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if y < min_val:
            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            if self.__resize_method == "lower_bound":
                # scale such that output size is lower bound
                if scale_width > scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "upper_bound":
                # scale such that output size is upper bound
                if scale_width < scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "minimal":
                # scale as little as possible
                if abs(1 - scale_width) < abs(1 - scale_height):
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            else:
                raise ValueError(f"resize_method {self.__resize_method} not implemented")

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, sample):
        width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])

        # resize sample
        sample["image"] = cv2.resize(
            sample["image"],
            (width, height),
            interpolation=self.__image_interpolation_method,
        )

        if self.__resize_target:
            if "disparity" in sample:
                sample["disparity"] = cv2.resize(
                    sample["disparity"],
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )

            if "depth" in sample:
                sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)

            if "semseg_mask" in sample:
                # sample["semseg_mask"] = cv2.resize(
                #     sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
                # )
                sample["semseg_mask"] = F.interpolate(
                    torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode="nearest"
                ).numpy()[0, 0]

            if "mask" in sample:
                sample["mask"] = cv2.resize(
                    sample["mask"].astype(np.float32),
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )
                # sample["mask"] = sample["mask"].astype(bool)

        # print(sample['image'].shape, sample['depth'].shape)
        return sample


class NormalizeImage(object):
    """Normalize image by given mean and std."""

    def __init__(self, mean, std):
        self.__mean = mean
        self.__std = std

    def __call__(self, sample):
        sample["image"] = (sample["image"] - self.__mean) / self.__std

        return sample


class PrepareForNet(object):
    """Prepare sample for usage as network input."""

    def __init__(self):
        pass

    def __call__(self, sample):
        image = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(image).astype(np.float32)

        if "mask" in sample:
            sample["mask"] = sample["mask"].astype(np.float32)
            sample["mask"] = np.ascontiguousarray(sample["mask"])

        if "depth" in sample:
            depth = sample["depth"].astype(np.float32)
            sample["depth"] = np.ascontiguousarray(depth)

        if "semseg_mask" in sample:
            sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
            sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])

        return sample
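To see what the transform chain defined in __init__.py produces, here is a quick sketch on a random 480x640 array standing in for a real image: with resize_method="lower_bound" and ensure_multiple_of=14, the scale is chosen so both sides end up at least 518, rounded to multiples of 14.

    import numpy as np

    from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize

    resize = Resize(
        width=518,
        height=518,
        resize_target=False,
        keep_aspect_ratio=True,
        ensure_multiple_of=14,
        resize_method="lower_bound",
    )
    normalize = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    prepare = PrepareForNet()

    sample = {"image": np.random.rand(480, 640, 3)}  # H x W x C, values in [0, 1]
    sample = prepare(normalize(resize(sample)))
    print(sample["image"].shape)  # (3, 518, 686) -- CHW, both sides multiples of 14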