diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py
index f16a8e36ae..00c3fa74f6 100644
--- a/invokeai/app/invocations/controlnet_image_processors.py
+++ b/invokeai/app/invocations/controlnet_image_processors.py
@@ -30,6 +30,7 @@ from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.shared.fields import FieldDescriptions
+from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
from ...backend.model_management import BaseModelType
from .baseinvocation import (
@@ -602,3 +603,33 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
color_map = cv2.resize(color_map, (width, height), interpolation=cv2.INTER_NEAREST)
color_map = Image.fromarray(color_map)
return color_map
+
+
+DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
+
+
+@invocation(
+ "depth_anything_image_processor",
+ title="Depth Anything Processor",
+ tags=["controlnet", "depth", "depth anything"],
+ category="controlnet",
+ version="1.0.0",
+)
+class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
+ """Generates a depth map based on the Depth Anything algorithm"""
+
+ model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
+ default="small", description="The size of the depth model to use"
+ )
+ resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)
+ offload: bool = InputField(default=False)
+
+ def run_processor(self, image):
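+        # Load the requested Depth Anything checkpoint (downloaded on first use) and run it on the image.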
+ depth_anything_detector = DepthAnythingDetector()
+ depth_anything_detector.load_model(model_size=self.model_size)
+
+ if image.mode == "RGBA":
+ image = image.convert("RGB")
+
+ processed_image = depth_anything_detector(image=image, resolution=self.resolution, offload=self.offload)
+ return processed_image
diff --git a/invokeai/backend/image_util/depth_anything/__init__.py b/invokeai/backend/image_util/depth_anything/__init__.py
new file mode 100644
index 0000000000..fcd600b99e
--- /dev/null
+++ b/invokeai/backend/image_util/depth_anything/__init__.py
@@ -0,0 +1,109 @@
+import pathlib
+from typing import Literal, Union
+
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+from einops import repeat
+from PIL import Image
+from torchvision.transforms import Compose
+
+from invokeai.app.services.config.config_default import InvokeAIAppConfig
+from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
+from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
+from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.util import download_with_progress_bar
+
+config = InvokeAIAppConfig.get_config()
+
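+# Checkpoint URLs and install locations (relative to the InvokeAI models directory) for each model size.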
+DEPTH_ANYTHING_MODELS = {
+ "large": {
+ "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
+ "local": "any/annotators/depth_anything/depth_anything_vitl14.pth",
+ },
+ "base": {
+ "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
+ "local": "any/annotators/depth_anything/depth_anything_vitb14.pth",
+ },
+ "small": {
+ "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
+ "local": "any/annotators/depth_anything/depth_anything_vits14.pth",
+ },
+}
+
+
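+# Preprocessing pipeline used by Depth Anything: scale the image so its short side is roughly 518 px
+# (snapped to a multiple of 14, the ViT patch size), normalize with ImageNet statistics, and convert HWC floats to CHW.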
+transform = Compose(
+ [
+ Resize(
+ width=518,
+ height=518,
+ resize_target=False,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=14,
+ resize_method="lower_bound",
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ PrepareForNet(),
+ ]
+)
+
+
+class DepthAnythingDetector:
+ def __init__(self) -> None:
+ self.model = None
+ self.model_size: Union[Literal["large", "base", "small"], None] = None
+
+    def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
+ DEPTH_ANYTHING_MODEL_PATH = pathlib.Path(config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"])
+ if not DEPTH_ANYTHING_MODEL_PATH.exists():
+ download_with_progress_bar(DEPTH_ANYTHING_MODELS[model_size]["url"], DEPTH_ANYTHING_MODEL_PATH)
+
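+        # Rebuild the network only when nothing is loaded yet or a different size was requested.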
+ if not self.model or model_size != self.model_size:
+ del self.model
+ self.model_size = model_size
+
+ match self.model_size:
+ case "small":
+ self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
+ case "base":
+ self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
+ case "large":
+ self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
+ case _:
+ raise TypeError("Not a supported model")
+
+ self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
+ self.model.eval()
+
+ self.model.to(choose_torch_device())
+ return self.model
+
+ def to(self, device):
+ self.model.to(device)
+ return self
+
+ def __call__(self, image, resolution=512, offload=False):
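+        """Return a 3-channel PIL depth map for `image`, resized so its width matches `resolution`."""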
+ image = np.array(image, dtype=np.uint8)
+ image = image[:, :, ::-1] / 255.0
+
+ image_height, image_width = image.shape[:2]
+ image = transform({"image": image})["image"]
+ image = torch.from_numpy(image).unsqueeze(0).to(choose_torch_device())
+
+ with torch.no_grad():
+ depth = self.model(image)
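+            # Upsample the prediction back to the input size and rescale it to the 0-255 range.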
+ depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
+ depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
+
+ depth_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8)
+ depth_map = Image.fromarray(depth_map)
+
+ new_height = int(image_height * (resolution / image_width))
+ depth_map = depth_map.resize((resolution, new_height))
+
+        if offload:
+            del self.model
+            self.model = None
+
+ return depth_map
diff --git a/invokeai/backend/image_util/depth_anything/model/blocks.py b/invokeai/backend/image_util/depth_anything/model/blocks.py
new file mode 100644
index 0000000000..4534f52237
--- /dev/null
+++ b/invokeai/backend/image_util/depth_anything/model/blocks.py
@@ -0,0 +1,145 @@
+import torch.nn as nn
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
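+    # Build one 3x3 projection conv per backbone stage, mapping its channel count to the decoder
+    # width (optionally doubling the width per stage when `expand` is set).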
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape
+
+ if expand:
+ out_shape1 = out_shape
+ out_shape2 = out_shape * 2
+ out_shape3 = out_shape * 4
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape * 8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ if len(in_shape) >= 4:
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+
+ return scratch
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module."""
+
+ def __init__(self, features, activation, bn):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+ self.groups = 1
+
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+
+ if self.bn:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn:
+ out = self.bn1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn:
+ out = self.bn2(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block."""
+
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+ self.groups = 1
+
+ self.expand = expand
+ out_features = features
+ if self.expand:
+ out_features = features // 2
+
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ self.size = size
+
+ def forward(self, *xs, size=None):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+
+ output = self.resConfUnit2(output)
+
+ if (size is None) and (self.size is None):
+ modifier = {"scale_factor": 2}
+ elif size is None:
+ modifier = {"size": self.size}
+ else:
+ modifier = {"size": size}
+
+ output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
+
+ output = self.out_conv(output)
+
+ return output
diff --git a/invokeai/backend/image_util/depth_anything/model/dpt.py b/invokeai/backend/image_util/depth_anything/model/dpt.py
new file mode 100644
index 0000000000..e1101b3c39
--- /dev/null
+++ b/invokeai/backend/image_util/depth_anything/model/dpt.py
@@ -0,0 +1,183 @@
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .blocks import FeatureFusionBlock, _make_scratch
+
+torchhub_path = Path(__file__).parent.parent / "torchhub"
+
+
+def _make_fusion_block(features, use_bn, size=None):
+ return FeatureFusionBlock(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ size=size,
+ )
+
+
+class DPTHead(nn.Module):
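+    """DPT-style decoder head: projects, resizes and progressively fuses the four DINOv2 feature maps into a single depth prediction."""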
+ def __init__(self, nclass, in_channels, features, out_channels, use_bn=False, use_clstoken=False):
+ super(DPTHead, self).__init__()
+
+ self.nclass = nclass
+ self.use_clstoken = use_clstoken
+
+ self.projects = nn.ModuleList(
+ [
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ for out_channel in out_channels
+ ]
+ )
+
+ self.resize_layers = nn.ModuleList(
+ [
+ nn.ConvTranspose2d(
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
+ ),
+ nn.ConvTranspose2d(
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
+ ),
+ nn.Identity(),
+ nn.Conv2d(
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
+ ),
+ ]
+ )
+
+ if use_clstoken:
+ self.readout_projects = nn.ModuleList()
+ for _ in range(len(self.projects)):
+ self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))
+
+ self.scratch = _make_scratch(
+ out_channels,
+ features,
+ groups=1,
+ expand=False,
+ )
+
+ self.scratch.stem_transpose = None
+
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+ head_features_1 = features
+ head_features_2 = 32
+
+ if nclass > 1:
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
+ )
+ else:
+ self.scratch.output_conv1 = nn.Conv2d(
+ head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
+ )
+
+ self.scratch.output_conv2 = nn.Sequential(
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True),
+ nn.Identity(),
+ )
+
+ def forward(self, out_features, patch_h, patch_w):
+ out = []
+ for i, x in enumerate(out_features):
+ if self.use_clstoken:
+ x, cls_token = x[0], x[1]
+ readout = cls_token.unsqueeze(1).expand_as(x)
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
+ else:
+ x = x[0]
+
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+
+ x = self.projects[i](x)
+ x = self.resize_layers[i](x)
+
+ out.append(x)
+
+ layer_1, layer_2, layer_3, layer_4 = out
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv1(path_1)
+ out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
+ out = self.scratch.output_conv2(out)
+
+ return out
+
+
+class DPT_DINOv2(nn.Module):
+ def __init__(
+ self,
+ features,
+ out_channels,
+ encoder="vitl",
+ use_bn=False,
+ use_clstoken=False,
+ ):
+ super(DPT_DINOv2, self).__init__()
+
+ assert encoder in ["vits", "vitb", "vitl"]
+
+ # # in case the Internet connection is not stable, please load the DINOv2 locally
+ # if use_local:
+ # self.pretrained = torch.hub.load(
+ # torchhub_path / "facebookresearch_dinov2_main",
+ # "dinov2_{:}14".format(encoder),
+ # source="local",
+ # pretrained=False,
+ # )
+ # else:
+ # self.pretrained = torch.hub.load(
+ # "facebookresearch/dinov2",
+ # "dinov2_{:}14".format(encoder),
+ # )
+
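+        # The DINOv2 backbone is fetched via torch.hub, so constructing the model needs network
+        # access on first use (see the commented-out local-loading variant above).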
+ self.pretrained = torch.hub.load(
+ "facebookresearch/dinov2",
+ "dinov2_{:}14".format(encoder),
+ )
+
+ dim = self.pretrained.blocks[0].attn.qkv.in_features
+
+ self.depth_head = DPTHead(1, dim, features, out_channels=out_channels, use_bn=use_bn, use_clstoken=use_clstoken)
+
+ def forward(self, x):
+ h, w = x.shape[-2:]
+
+ features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
+
+ patch_h, patch_w = h // 14, w // 14
+
+ depth = self.depth_head(features, patch_h, patch_w)
+ depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
+ depth = F.relu(depth)
+
+ return depth.squeeze(1)
diff --git a/invokeai/backend/image_util/depth_anything/utilities/util.py b/invokeai/backend/image_util/depth_anything/utilities/util.py
new file mode 100644
index 0000000000..5362ef6c3e
--- /dev/null
+++ b/invokeai/backend/image_util/depth_anything/utilities/util.py
@@ -0,0 +1,227 @@
+import math
+
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
+
+ Args:
+ sample (dict): sample
+ size (tuple): image size
+
+ Returns:
+ tuple: new size
+ """
+ shape = list(sample["disparity"].shape)
+
+ if shape[0] >= size[0] and shape[1] >= size[1]:
+ return sample
+
+ scale = [0, 0]
+ scale[0] = size[0] / shape[0]
+ scale[1] = size[1] / shape[1]
+
+ scale = max(scale)
+
+ shape[0] = math.ceil(scale * shape[0])
+ shape[1] = math.ceil(scale * shape[1])
+
+ # resize
+ sample["image"] = cv2.resize(sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method)
+
+ sample["disparity"] = cv2.resize(sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST)
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ tuple(shape[::-1]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return tuple(shape)
+
+
+class Resize(object):
+ """Resize sample to given size (width, height)."""
+
+ def __init__(
+ self,
+ width,
+ height,
+ resize_target=True,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=1,
+ resize_method="lower_bound",
+ image_interpolation_method=cv2.INTER_AREA,
+ ):
+ """Init.
+
+ Args:
+ width (int): desired output width
+ height (int): desired output height
+ resize_target (bool, optional):
+ True: Resize the full sample (image, mask, target).
+ False: Resize image only.
+ Defaults to True.
+ keep_aspect_ratio (bool, optional):
+ True: Keep the aspect ratio of the input sample.
+ Output sample might not have the given width and height, and
+ resize behaviour depends on the parameter 'resize_method'.
+ Defaults to False.
+ ensure_multiple_of (int, optional):
+ Output width and height is constrained to be multiple of this parameter.
+ Defaults to 1.
+ resize_method (str, optional):
+ "lower_bound": Output will be at least as large as the given size.
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller
+ than given size.)
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
+ Defaults to "lower_bound".
+ """
+ self.__width = width
+ self.__height = height
+
+ self.__resize_target = resize_target
+ self.__keep_aspect_ratio = keep_aspect_ratio
+ self.__multiple_of = ensure_multiple_of
+ self.__resize_method = resize_method
+ self.__image_interpolation_method = image_interpolation_method
+
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
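+        # Round to the nearest multiple of `ensure_multiple_of`, falling back to floor/ceil so the
+        # result stays within [min_val, max_val].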
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if max_val is not None and y > max_val:
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if y < min_val:
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ return y
+
+ def get_size(self, width, height):
+ # determine new height and width
+ scale_height = self.__height / height
+ scale_width = self.__width / width
+
+ if self.__keep_aspect_ratio:
+ if self.__resize_method == "lower_bound":
+ # scale such that output size is lower bound
+ if scale_width > scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "upper_bound":
+ # scale such that output size is upper bound
+ if scale_width < scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "minimal":
+                # scale as little as possible
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ if self.__resize_method == "lower_bound":
+ new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
+ new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
+ elif self.__resize_method == "upper_bound":
+ new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
+ new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
+ elif self.__resize_method == "minimal":
+ new_height = self.constrain_to_multiple_of(scale_height * height)
+ new_width = self.constrain_to_multiple_of(scale_width * width)
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ return (new_width, new_height)
+
+ def __call__(self, sample):
+ width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
+
+ # resize sample
+ sample["image"] = cv2.resize(
+ sample["image"],
+ (width, height),
+ interpolation=self.__image_interpolation_method,
+ )
+
+ if self.__resize_target:
+ if "disparity" in sample:
+ sample["disparity"] = cv2.resize(
+ sample["disparity"],
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ if "depth" in sample:
+ sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
+
+ if "semseg_mask" in sample:
+ # sample["semseg_mask"] = cv2.resize(
+ # sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
+ # )
+ sample["semseg_mask"] = F.interpolate(
+ torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode="nearest"
+ ).numpy()[0, 0]
+
+ if "mask" in sample:
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ # sample["mask"] = sample["mask"].astype(bool)
+
+ # print(sample['image'].shape, sample['depth'].shape)
+ return sample
+
+
+class NormalizeImage(object):
+ """Normlize image by given mean and std."""
+
+ def __init__(self, mean, std):
+ self.__mean = mean
+ self.__std = std
+
+ def __call__(self, sample):
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+ return sample
+
+
+class PrepareForNet(object):
+ """Prepare sample for usage as network input."""
+
+ def __init__(self):
+ pass
+
+ def __call__(self, sample):
+ image = np.transpose(sample["image"], (2, 0, 1))
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+ if "mask" in sample:
+ sample["mask"] = sample["mask"].astype(np.float32)
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+ if "depth" in sample:
+ depth = sample["depth"].astype(np.float32)
+ sample["depth"] = np.ascontiguousarray(depth)
+
+ if "semseg_mask" in sample:
+ sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
+ sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])
+
+ return sample
diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json
index a4c6591802..6b9dbaf7bc 100644
--- a/invokeai/frontend/web/public/locales/en.json
+++ b/invokeai/frontend/web/public/locales/en.json
@@ -224,6 +224,7 @@
"amult": "a_mult",
"autoConfigure": "Auto configure processor",
"balanced": "Balanced",
+ "base": "Base",
"beginEndStepPercent": "Begin / End Step Percentage",
"bgth": "bg_th",
"canny": "Canny",
@@ -237,6 +238,8 @@
"controlMode": "Control Mode",
"crop": "Crop",
"delete": "Delete",
+ "depthAnything": "Depth Anything",
+ "depthAnythingDescription": "Depth map generation using the Depth Anything technique",
"depthMidas": "Depth (Midas)",
"depthMidasDescription": "Depth map generation using Midas",
"depthZoe": "Depth (Zoe)",
@@ -256,6 +259,7 @@
"colorMapTileSize": "Tile Size",
"importImageFromCanvas": "Import Image From Canvas",
"importMaskFromCanvas": "Import Mask From Canvas",
+ "large": "Large",
"lineart": "Lineart",
"lineartAnime": "Lineart Anime",
"lineartAnimeDescription": "Anime-style lineart processing",
@@ -268,6 +272,7 @@
"minConfidence": "Min Confidence",
"mlsd": "M-LSD",
"mlsdDescription": "Minimalist Line Segment Detector",
+ "modelSize": "Model Size",
"none": "None",
"noneDescription": "No processing applied",
"normalBae": "Normal BAE",
@@ -288,6 +293,7 @@
"selectModel": "Select a model",
"setControlImageDimensions": "Set Control Image Dimensions To W/H",
"showAdvanced": "Show Advanced",
+ "small": "Small",
"toggleControlNet": "Toggle this ControlNet",
"w": "W",
"weight": "Weight",
diff --git a/invokeai/frontend/web/src/features/controlAdapters/components/ControlAdapterProcessorComponent.tsx b/invokeai/frontend/web/src/features/controlAdapters/components/ControlAdapterProcessorComponent.tsx
index 79ef4c0a0a..45c4b7fac9 100644
--- a/invokeai/frontend/web/src/features/controlAdapters/components/ControlAdapterProcessorComponent.tsx
+++ b/invokeai/frontend/web/src/features/controlAdapters/components/ControlAdapterProcessorComponent.tsx
@@ -5,6 +5,7 @@ import { memo } from 'react';
import CannyProcessor from './processors/CannyProcessor';
import ColorMapProcessor from './processors/ColorMapProcessor';
import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
+import DepthAnyThingProcessor from './processors/DepthAnyThingProcessor';
import HedProcessor from './processors/HedProcessor';
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
import LineartProcessor from './processors/LineartProcessor';
@@ -48,6 +49,16 @@ const ControlAdapterProcessorComponent = ({ id }: Props) => {
);
}
+  if (processorNode.type === 'depth_anything_image_processor') {
+    return (
+      <DepthAnyThingProcessor
+        controlNetId={id}
+        processorNode={processorNode}
+        isEnabled={isEnabled}
+      />
+    );
+  }
+
   if (processorNode.type === 'hed_image_processor') {
     return (
       <HedProcessor
diff --git a/invokeai/frontend/web/src/features/controlAdapters/components/processors/DepthAnyThingProcessor.tsx b/invokeai/frontend/web/src/features/controlAdapters/components/processors/DepthAnyThingProcessor.tsx
new file mode 100644
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlAdapters/components/processors/DepthAnyThingProcessor.tsx
+type Props = {
+  controlNetId: string;
+  processorNode: RequiredDepthAnythingImageProcessorInvocation;
+  isEnabled: boolean;
+};
+
+const DepthAnythingProcessor = (props: Props) => {
+ const { controlNetId, processorNode, isEnabled } = props;
+ const { model_size, resolution } = processorNode;
+ const processorChanged = useProcessorNodeChanged();
+
+ const { t } = useTranslation();
+
+ const handleModelSizeChange = useCallback(
+ (v) => {
+ if (!isDepthAnythingModelSize(v?.value)) {
+ return;
+ }
+ processorChanged(controlNetId, {
+ model_size: v.value,
+ });
+ },
+ [controlNetId, processorChanged]
+ );
+
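+  // Model-size options shown in the dropdown; labels are localized, values match the backend literal.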
+ const options: { label: string; value: DepthAnythingModelSize }[] = useMemo(
+ () => [
+ { label: t('controlnet.small'), value: 'small' },
+ { label: t('controlnet.base'), value: 'base' },
+ { label: t('controlnet.large'), value: 'large' },
+ ],
+ [t]
+ );
+
+ const value = useMemo(
+ () => options.filter((o) => o.value === model_size)[0],
+ [options, model_size]
+ );
+
+ const handleResolutionChange = useCallback(
+ (v: number) => {
+ processorChanged(controlNetId, { resolution: v });
+ },
+ [controlNetId, processorChanged]
+ );
+
+ const handleResolutionDefaultChange = useCallback(() => {
+ processorChanged(controlNetId, { resolution: 512 });
+ }, [controlNetId, processorChanged]);
+
+  return (
+    <ProcessorWrapper>
+      <FormControl isDisabled={!isEnabled}>
+        <FormLabel>{t('controlnet.modelSize')}</FormLabel>
+        <Combobox value={value} options={options} onChange={handleModelSizeChange} />
+      </FormControl>
+      <FormControl isDisabled={!isEnabled}>
+        <FormLabel>{t('controlnet.imageResolution')}</FormLabel>
+        <CompositeSlider
+          value={resolution}
+          defaultValue={512}
+          min={64}
+          step={64}
+          onChange={handleResolutionChange}
+          onReset={handleResolutionDefaultChange}
+        />
+      </FormControl>
+    </ProcessorWrapper>
+  );
+};
+
+export default memo(DepthAnythingProcessor);
diff --git a/invokeai/frontend/web/src/features/controlAdapters/store/constants.ts b/invokeai/frontend/web/src/features/controlAdapters/store/constants.ts
index b7647c9f0d..ddebffcfde 100644
--- a/invokeai/frontend/web/src/features/controlAdapters/store/constants.ts
+++ b/invokeai/frontend/web/src/features/controlAdapters/store/constants.ts
@@ -83,6 +83,22 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
f: 256,
},
},
+ depth_anything_image_processor: {
+ type: 'depth_anything_image_processor',
+ get label() {
+ return i18n.t('controlnet.depthAnything');
+ },
+ get description() {
+ return i18n.t('controlnet.depthAnythingDescription');
+ },
+ default: {
+ id: 'depth_anything_image_processor',
+ type: 'depth_anything_image_processor',
+ model_size: 'small',
+ resolution: 512,
+ offload: false,
+ },
+ },
hed_image_processor: {
type: 'hed_image_processor',
get label() {
@@ -245,7 +261,7 @@ export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: {
} = {
canny: 'canny_image_processor',
mlsd: 'mlsd_image_processor',
- depth: 'midas_depth_image_processor',
+ depth: 'depth_anything_image_processor',
bae: 'normalbae_image_processor',
sketch: 'pidi_image_processor',
scribble: 'lineart_image_processor',
diff --git a/invokeai/frontend/web/src/features/controlAdapters/store/types.ts b/invokeai/frontend/web/src/features/controlAdapters/store/types.ts
index 8d391a9c08..87366a443d 100644
--- a/invokeai/frontend/web/src/features/controlAdapters/store/types.ts
+++ b/invokeai/frontend/web/src/features/controlAdapters/store/types.ts
@@ -10,6 +10,7 @@ import type {
CannyImageProcessorInvocation,
ColorMapImageProcessorInvocation,
ContentShuffleImageProcessorInvocation,
+ DepthAnythingImageProcessorInvocation,
HedImageProcessorInvocation,
LineartAnimeImageProcessorInvocation,
LineartImageProcessorInvocation,
@@ -31,6 +32,7 @@ export type ControlAdapterProcessorNode =
| CannyImageProcessorInvocation
| ColorMapImageProcessorInvocation
| ContentShuffleImageProcessorInvocation
+ | DepthAnythingImageProcessorInvocation
| HedImageProcessorInvocation
| LineartAnimeImageProcessorInvocation
| LineartImageProcessorInvocation
@@ -73,6 +75,20 @@ export type RequiredContentShuffleImageProcessorInvocation = O.Required<
'type' | 'detect_resolution' | 'image_resolution' | 'w' | 'h' | 'f'
>;
+/**
+ * The DepthAnything processor node, with parameters flagged as required
+ */
+export type RequiredDepthAnythingImageProcessorInvocation = O.Required<
+ DepthAnythingImageProcessorInvocation,
+ 'type' | 'model_size' | 'resolution' | 'offload'
+>;
+
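+// Runtime validator mirroring the backend's DEPTH_ANYTHING_MODEL_SIZES literal.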
+export const zDepthAnythingModelSize = z.enum(['large', 'base', 'small']);
+export type DepthAnythingModelSize = z.infer<typeof zDepthAnythingModelSize>;
+export const isDepthAnythingModelSize = (
+ v: unknown
+): v is DepthAnythingModelSize => zDepthAnythingModelSize.safeParse(v).success;
+
/**
* The HED processor node, with parameters flagged as required
*/
@@ -161,6 +177,7 @@ export type RequiredControlAdapterProcessorNode =
| RequiredCannyImageProcessorInvocation
| RequiredColorMapImageProcessorInvocation
| RequiredContentShuffleImageProcessorInvocation
+ | RequiredDepthAnythingImageProcessorInvocation
| RequiredHedImageProcessorInvocation
| RequiredLineartAnimeImageProcessorInvocation
| RequiredLineartImageProcessorInvocation
@@ -219,6 +236,22 @@ export const isContentShuffleImageProcessorInvocation = (
return false;
};
+/**
+ * Type guard for DepthAnythingImageProcessorInvocation
+ */
+export const isDepthAnythingImageProcessorInvocation = (
+ obj: unknown
+): obj is DepthAnythingImageProcessorInvocation => {
+ if (
+ isObject(obj) &&
+ 'type' in obj &&
+ obj.type === 'depth_anything_image_processor'
+ ) {
+ return true;
+ }
+ return false;
+};
+
/**
* Type guard for HedImageprocessorInvocation
*/
diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts
index a7f60d6cdf..a4f358640d 100644
--- a/invokeai/frontend/web/src/services/api/schema.ts
+++ b/invokeai/frontend/web/src/services/api/schema.ts
@@ -2938,7 +2938,7 @@ export type components = {
/**
* Fp32
* @description Whether or not to use full float32 precision
- * @default true
+ * @default false
*/
fp32?: boolean;
/**
@@ -3199,6 +3199,57 @@ export type components = {
*/
type: "denoise_mask_output";
};
+ /**
+ * Depth Anything Processor
+ * @description Generates a depth map based on the Depth Anything algorithm
+ */
+ DepthAnythingImageProcessorInvocation: {
+ /** @description Optional metadata to be saved with the image */
+ metadata?: components["schemas"]["MetadataField"] | null;
+ /**
+ * Id
+ * @description The id of this instance of an invocation. Must be unique among all instances of invocations.
+ */
+ id: string;
+ /**
+ * Is Intermediate
+ * @description Whether or not this is an intermediate invocation.
+ * @default false
+ */
+ is_intermediate?: boolean;
+ /**
+ * Use Cache
+ * @description Whether or not to use the cache
+ * @default true
+ */
+ use_cache?: boolean;
+ /** @description The image to process */
+ image?: components["schemas"]["ImageField"];
+ /**
+ * Model Size
+ * @description The size of the depth model to use
+ * @default small
+ * @enum {string}
+ */
+ model_size?: "large" | "base" | "small";
+ /**
+ * Resolution
+ * @description Pixel resolution for output image
+ * @default 512
+ */
+ resolution?: number;
+ /**
+ * Offload
+ * @default false
+ */
+ offload?: boolean;
+ /**
+ * type
+ * @default depth_anything_image_processor
+ * @constant
+ */
+ type: "depth_anything_image_processor";
+ };
/**
* Divide Integers
* @description Divides two numbers
@@ -4073,7 +4124,7 @@ export type components = {
* @description The nodes in this graph
*/
nodes?: {
- [key: string]: components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | 
components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["LinearUIOutputInvocation"] | components["schemas"]["HedImageProcessorInvocation"];
+ [key: string]: components["schemas"]["LatentConsistencyInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["LinearUIOutputInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | 
components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["LeresImageProcessorInvocation"];
};
/**
* Edges
@@ -4110,7 +4161,7 @@ export type components = {
* @description The results of node executions
*/
results: {
- [key: string]: components["schemas"]["ColorOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["String2Output"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["DenoiseMaskOutput"];
+ [key: string]: components["schemas"]["ColorCollectionOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["String2Output"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ImageOutput"];
};
/**
* Errors
@@ -4500,6 +4551,77 @@ export type components = {
*/
type: "ip_adapter_output";
};
+ /**
+ * Ideal Size
+ * @description Calculates the ideal size for generation to avoid duplication
+ */
+ IdealSizeInvocation: {
+ /**
+ * Id
+ * @description The id of this instance of an invocation. Must be unique among all instances of invocations.
+ */
+ id: string;
+ /**
+ * Is Intermediate
+ * @description Whether or not this is an intermediate invocation.
+ * @default false
+ */
+ is_intermediate?: boolean;
+ /**
+ * Use Cache
+ * @description Whether or not to use the cache
+ * @default true
+ */
+ use_cache?: boolean;
+ /**
+ * Width
+ * @description Final image width
+ * @default 1024
+ */
+ width?: number;
+ /**
+ * Height
+ * @description Final image height
+ * @default 576
+ */
+ height?: number;
+ /** @description UNet (scheduler, LoRAs) */
+ unet?: components["schemas"]["UNetField"];
+ /**
+ * Multiplier
+ * @description Amount to multiply the model's dimensions by when calculating the ideal size (may result in initial generation artifacts if too large)
+ * @default 1
+ */
+ multiplier?: number;
+ /**
+ * type
+ * @default ideal_size
+ * @constant
+ */
+ type: "ideal_size";
+ };
+ /**
+ * IdealSizeOutput
+ * @description Base class for invocations that output an image
+ */
+ IdealSizeOutput: {
+ /**
+ * Width
+ * @description The ideal width of the image (in pixels)
+ */
+ width: number;
+ /**
+ * Height
+ * @description The ideal height of the image (in pixels)
+ */
+ height: number;
+ /**
+ * type
+ * @default ideal_size_output
+ * @constant
+ */
+ type: "ideal_size_output";
+ };
/**
* Blur Image
* @description Blurs an image
@@ -5403,7 +5525,7 @@ export type components = {
/**
* Fp32
* @description Whether or not to use full float32 precision
- * @default true
+ * @default false
*/
fp32?: boolean;
/**
@@ -5911,6 +6033,96 @@ export type components = {
*/
type: "infill_lama";
};
+ /**
+ * Latent Consistency MonoNode
+ * @description Wrapper node around diffusers LatentConsistencyTxt2ImgPipeline
+ */
+ LatentConsistencyInvocation: {
+ /**
+ * Id
+ * @description The id of this instance of an invocation. Must be unique among all instances of invocations.
+ */
+ id: string;
+ /**
+ * Is Intermediate
+ * @description Whether or not this is an intermediate invocation.
+ * @default false
+ */
+ is_intermediate?: boolean;
+ /**
+ * Use Cache
+ * @description Whether or not to use the cache
+ * @default true
+ */
+ use_cache?: boolean;
+ /**
+ * Prompt
+ * @description The prompt to use
+ */
+ prompt?: string;
+ /**
+ * Num Inference Steps
+ * @description The number of inference steps to use, 4-8 recommended
+ * @default 8
+ */
+ num_inference_steps?: number;
+ /**
+ * Guidance Scale
+ * @description The guidance scale to use
+ * @default 8
+ */
+ guidance_scale?: number;
+ /**
+ * Batches
+ * @description The number of batches to use
+ * @default 1
+ */
+ batches?: number;
+ /**
+ * Images Per Batch
+ * @description The number of images per batch to use
+ * @default 1
+ */
+ images_per_batch?: number;
+ /**
+ * Seeds
+ * @description List of noise seeds to use
+ */
+ seeds?: number[];
+ /**
+ * Lcm Origin Steps
+ * @description The lcm origin steps to use
+ * @default 50
+ */
+ lcm_origin_steps?: number;
+ /**
+ * Width
+ * @description The width to use
+ * @default 512
+ */
+ width?: number;
+ /**
+ * Height
+ * @description The height to use
+ * @default 512
+ */
+ height?: number;
+ /**
+ * Precision
+ * @description floating point precision
+ * @default fp16
+ * @enum {string}
+ */
+ precision?: "fp16" | "fp32";
+ /** @description The board to save the image to */
+ board?: components["schemas"]["BoardField"];
+ /**
+ * type
+ * @default latent_consistency_mononode
+ * @constant
+ */
+ type: "latent_consistency_mononode";
+ };
/**
* Latents Collection Primitive
* @description A collection of latents tensor primitive values
@@ -6070,7 +6282,7 @@ export type components = {
/**
* Fp32
* @description Whether or not to use full float32 precision
- * @default true
+ * @default false
*/
fp32?: boolean;
/**
@@ -11290,42 +11502,42 @@ export type components = {
* @enum {string}
*/
UIType: "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_MainModel" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
- /**
- * StableDiffusion1ModelFormat
- * @description An enumeration.
- * @enum {string}
- */
- StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionXLModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
- /**
- * T2IAdapterModelFormat
- * @description An enumeration.
- * @enum {string}
- */
- T2IAdapterModelFormat: "diffusers";
/**
* CLIPVisionModelFormat
* @description An enumeration.
* @enum {string}
*/
CLIPVisionModelFormat: "diffusers";
- /**
- * IPAdapterModelFormat
- * @description An enumeration.
- * @enum {string}
- */
- IPAdapterModelFormat: "invokeai";
/**
* ControlNetModelFormat
* @description An enumeration.
* @enum {string}
*/
ControlNetModelFormat: "checkpoint" | "diffusers";
+ /**
+ * T2IAdapterModelFormat
+ * @description An enumeration.
+ * @enum {string}
+ */
+ T2IAdapterModelFormat: "diffusers";
+ /**
+ * StableDiffusion1ModelFormat
+ * @description An enumeration.
+ * @enum {string}
+ */
+ StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
+ /**
+ * StableDiffusion2ModelFormat
+ * @description An enumeration.
+ * @enum {string}
+ */
+ StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionOnnxModelFormat
* @description An enumeration.
@@ -11333,11 +11545,11 @@ export type components = {
*/
StableDiffusionOnnxModelFormat: "olive" | "onnx";
/**
- * StableDiffusion2ModelFormat
+ * IPAdapterModelFormat
* @description An enumeration.
* @enum {string}
*/
- StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
+ IPAdapterModelFormat: "invokeai";
};
responses: never;
parameters: never;
diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts
index 03d9153361..871a7f5a2e 100644
--- a/invokeai/frontend/web/src/services/api/types.ts
+++ b/invokeai/frontend/web/src/services/api/types.ts
@@ -164,6 +164,8 @@ export type ColorMapImageProcessorInvocation =
s['ColorMapImageProcessorInvocation'];
export type ContentShuffleImageProcessorInvocation =
s['ContentShuffleImageProcessorInvocation'];
+export type DepthAnythingImageProcessorInvocation =
+ s['DepthAnythingImageProcessorInvocation'];
export type HedImageProcessorInvocation = s['HedImageProcessorInvocation'];
export type LineartAnimeImageProcessorInvocation =
s['LineartAnimeImageProcessorInvocation'];