diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index f16a8e36ae..00c3fa74f6 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -30,6 +30,7 @@ from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.app.shared.fields import FieldDescriptions +from invokeai.backend.image_util.depth_anything import DepthAnythingDetector from ...backend.model_management import BaseModelType from .baseinvocation import ( @@ -602,3 +603,33 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation): color_map = cv2.resize(color_map, (width, height), interpolation=cv2.INTER_NEAREST) color_map = Image.fromarray(color_map) return color_map + + +DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"] + + +@invocation( + "depth_anything_image_processor", + title="Depth Anything Processor", + tags=["controlnet", "depth", "depth anything"], + category="controlnet", + version="1.0.0", +) +class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation): + """Generates a depth map based on the Depth Anything algorithm""" + + model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField( + default="small", description="The size of the depth model to use" + ) + resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res) + offload: bool = InputField(default=False) + + def run_processor(self, image): + depth_anything_detector = DepthAnythingDetector() + depth_anything_detector.load_model(model_size=self.model_size) + + if image.mode == "RGBA": + image = image.convert("RGB") + + processed_image = depth_anything_detector(image=image, resolution=self.resolution, offload=self.offload) + return processed_image diff --git a/invokeai/backend/image_util/depth_anything/__init__.py b/invokeai/backend/image_util/depth_anything/__init__.py new file mode 100644 index 0000000000..fcd600b99e --- /dev/null +++ b/invokeai/backend/image_util/depth_anything/__init__.py @@ -0,0 +1,109 @@ +import pathlib +from typing import Literal, Union + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from einops import repeat +from PIL import Image +from torchvision.transforms import Compose + +from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2 +from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize +from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.util import download_with_progress_bar + +config = InvokeAIAppConfig.get_config() + +DEPTH_ANYTHING_MODELS = { + "large": { + "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true", + "local": "any/annotators/depth_anything/depth_anything_vitl14.pth", + }, + "base": { + "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true", + "local": "any/annotators/depth_anything/depth_anything_vitb14.pth", + }, + "small": { + "url": 
"https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true", + "local": "any/annotators/depth_anything/depth_anything_vits14.pth", + }, +} + + +transform = Compose( + [ + Resize( + width=518, + height=518, + resize_target=False, + keep_aspect_ratio=True, + ensure_multiple_of=14, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), + ] +) + + +class DepthAnythingDetector: + def __init__(self) -> None: + self.model = None + self.model_size: Union[Literal["large", "base", "small"], None] = None + + def load_model(self, model_size=Literal["large", "base", "small"]): + DEPTH_ANYTHING_MODEL_PATH = pathlib.Path(config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]) + if not DEPTH_ANYTHING_MODEL_PATH.exists(): + download_with_progress_bar(DEPTH_ANYTHING_MODELS[model_size]["url"], DEPTH_ANYTHING_MODEL_PATH) + + if not self.model or model_size != self.model_size: + del self.model + self.model_size = model_size + + match self.model_size: + case "small": + self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384]) + case "base": + self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768]) + case "large": + self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024]) + case _: + raise TypeError("Not a supported model") + + self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu")) + self.model.eval() + + self.model.to(choose_torch_device()) + return self.model + + def to(self, device): + self.model.to(device) + return self + + def __call__(self, image, resolution=512, offload=False): + image = np.array(image, dtype=np.uint8) + image = image[:, :, ::-1] / 255.0 + + image_height, image_width = image.shape[:2] + image = transform({"image": image})["image"] + image = torch.from_numpy(image).unsqueeze(0).to(choose_torch_device()) + + with torch.no_grad(): + depth = self.model(image) + depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0] + depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 + + depth_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8) + depth_map = Image.fromarray(depth_map) + + new_height = int(image_height * (resolution / image_width)) + depth_map = depth_map.resize((resolution, new_height)) + + if offload: + del self.model + + return depth_map diff --git a/invokeai/backend/image_util/depth_anything/model/blocks.py b/invokeai/backend/image_util/depth_anything/model/blocks.py new file mode 100644 index 0000000000..4534f52237 --- /dev/null +++ b/invokeai/backend/image_util/depth_anything/model/blocks.py @@ -0,0 +1,145 @@ +import torch.nn as nn + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + 
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups = 1 + + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + if self.bn: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups = 1 + + self.expand = expand + out_features = features + if self.expand: + out_features = features // 2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size = size + + def forward(self, *xs, size=None): + """Forward pass. 
+ + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) + + output = self.out_conv(output) + + return output diff --git a/invokeai/backend/image_util/depth_anything/model/dpt.py b/invokeai/backend/image_util/depth_anything/model/dpt.py new file mode 100644 index 0000000000..e1101b3c39 --- /dev/null +++ b/invokeai/backend/image_util/depth_anything/model/dpt.py @@ -0,0 +1,183 @@ +from pathlib import Path + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .blocks import FeatureFusionBlock, _make_scratch + +torchhub_path = Path(__file__).parent.parent / "torchhub" + + +def _make_fusion_block(features, use_bn, size=None): + return FeatureFusionBlock( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class DPTHead(nn.Module): + def __init__(self, nclass, in_channels, features, out_channels, use_bn=False, use_clstoken=False): + super(DPTHead, self).__init__() + + self.nclass = nclass + self.use_clstoken = use_clstoken + + self.projects = nn.ModuleList( + [ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + stride=1, + padding=0, + ) + for out_channel in out_channels + ] + ) + + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 + ), + ] + ) + + if use_clstoken: + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU())) + + self.scratch = _make_scratch( + out_channels, + features, + groups=1, + expand=False, + ) + + self.scratch.stem_transpose = None + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + head_features_1 = features + head_features_2 = 32 + + if nclass > 1: + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0), + ) + else: + self.scratch.output_conv1 = nn.Conv2d( + head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1 + ) + + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True), + nn.Identity(), + ) + + def forward(self, out_features, patch_h, patch_w): + out = [] + for i, x in enumerate(out_features): + if self.use_clstoken: + x, cls_token = x[0], x[1] + readout = cls_token.unsqueeze(1).expand_as(x) + x = 
self.readout_projects[i](torch.cat((x, readout), -1)) + else: + x = x[0] + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[i](x) + x = self.resize_layers[i](x) + + out.append(x) + + layer_1, layer_2, layer_3, layer_4 = out + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv1(path_1) + out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True) + out = self.scratch.output_conv2(out) + + return out + + +class DPT_DINOv2(nn.Module): + def __init__( + self, + features, + out_channels, + encoder="vitl", + use_bn=False, + use_clstoken=False, + ): + super(DPT_DINOv2, self).__init__() + + assert encoder in ["vits", "vitb", "vitl"] + + # # in case the Internet connection is not stable, please load the DINOv2 locally + # if use_local: + # self.pretrained = torch.hub.load( + # torchhub_path / "facebookresearch_dinov2_main", + # "dinov2_{:}14".format(encoder), + # source="local", + # pretrained=False, + # ) + # else: + # self.pretrained = torch.hub.load( + # "facebookresearch/dinov2", + # "dinov2_{:}14".format(encoder), + # ) + + self.pretrained = torch.hub.load( + "facebookresearch/dinov2", + "dinov2_{:}14".format(encoder), + ) + + dim = self.pretrained.blocks[0].attn.qkv.in_features + + self.depth_head = DPTHead(1, dim, features, out_channels=out_channels, use_bn=use_bn, use_clstoken=use_clstoken) + + def forward(self, x): + h, w = x.shape[-2:] + + features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True) + + patch_h, patch_w = h // 14, w // 14 + + depth = self.depth_head(features, patch_h, patch_w) + depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True) + depth = F.relu(depth) + + return depth.squeeze(1) diff --git a/invokeai/backend/image_util/depth_anything/utilities/util.py b/invokeai/backend/image_util/depth_anything/utilities/util.py new file mode 100644 index 0000000000..5362ef6c3e --- /dev/null +++ b/invokeai/backend/image_util/depth_anything/utilities/util.py @@ -0,0 +1,227 @@ +import math + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Resize the sample to ensure the given size. Keeps aspect ratio. 
+ + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize(sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method) + + sample["disparity"] = cv2.resize(sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height).""" + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller + than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". 
+ """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0]) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST) + + if "semseg_mask" in sample: + # sample["semseg_mask"] = cv2.resize( + # sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST + # ) + sample["semseg_mask"] = F.interpolate( + torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode="nearest" + ).numpy()[0, 0] + + if "mask" in sample: + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + # sample["mask"] = sample["mask"].astype(bool) + + # print(sample['image'].shape, sample['depth'].shape) + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std.""" + + 
def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input.""" + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + if "semseg_mask" in sample: + sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32) + sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"]) + + return sample diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index a4c6591802..6b9dbaf7bc 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -224,6 +224,7 @@ "amult": "a_mult", "autoConfigure": "Auto configure processor", "balanced": "Balanced", + "base": "Base", "beginEndStepPercent": "Begin / End Step Percentage", "bgth": "bg_th", "canny": "Canny", @@ -237,6 +238,8 @@ "controlMode": "Control Mode", "crop": "Crop", "delete": "Delete", + "depthAnything": "Depth Anything", + "depthAnythingDescription": "Depth map generation using the Depth Anything technique", "depthMidas": "Depth (Midas)", "depthMidasDescription": "Depth map generation using Midas", "depthZoe": "Depth (Zoe)", @@ -256,6 +259,7 @@ "colorMapTileSize": "Tile Size", "importImageFromCanvas": "Import Image From Canvas", "importMaskFromCanvas": "Import Mask From Canvas", + "large": "Large", "lineart": "Lineart", "lineartAnime": "Lineart Anime", "lineartAnimeDescription": "Anime-style lineart processing", @@ -268,6 +272,7 @@ "minConfidence": "Min Confidence", "mlsd": "M-LSD", "mlsdDescription": "Minimalist Line Segment Detector", + "modelSize": "Model Size", "none": "None", "noneDescription": "No processing applied", "normalBae": "Normal BAE", @@ -288,6 +293,7 @@ "selectModel": "Select a model", "setControlImageDimensions": "Set Control Image Dimensions To W/H", "showAdvanced": "Show Advanced", + "small": "Small", "toggleControlNet": "Toggle this ControlNet", "w": "W", "weight": "Weight", diff --git a/invokeai/frontend/web/src/features/controlAdapters/components/ControlAdapterProcessorComponent.tsx b/invokeai/frontend/web/src/features/controlAdapters/components/ControlAdapterProcessorComponent.tsx index 79ef4c0a0a..45c4b7fac9 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/components/ControlAdapterProcessorComponent.tsx +++ b/invokeai/frontend/web/src/features/controlAdapters/components/ControlAdapterProcessorComponent.tsx @@ -5,6 +5,7 @@ import { memo } from 'react'; import CannyProcessor from './processors/CannyProcessor'; import ColorMapProcessor from './processors/ColorMapProcessor'; import ContentShuffleProcessor from './processors/ContentShuffleProcessor'; +import DepthAnyThingProcessor from './processors/DepthAnyThingProcessor'; import HedProcessor from './processors/HedProcessor'; import LineartAnimeProcessor from './processors/LineartAnimeProcessor'; import LineartProcessor from './processors/LineartProcessor'; @@ -48,6 +49,16 @@ const ControlAdapterProcessorComponent = ({ id }: Props) => { ); } + if (processorNode.type === 
'depth_anything_image_processor') { + return ( + + ); + } + if (processorNode.type === 'hed_image_processor') { return ( { + const { controlNetId, processorNode, isEnabled } = props; + const { model_size, resolution } = processorNode; + const processorChanged = useProcessorNodeChanged(); + + const { t } = useTranslation(); + + const handleModelSizeChange = useCallback( + (v) => { + if (!isDepthAnythingModelSize(v?.value)) { + return; + } + processorChanged(controlNetId, { + model_size: v.value, + }); + }, + [controlNetId, processorChanged] + ); + + const options: { label: string; value: DepthAnythingModelSize }[] = useMemo( + () => [ + { label: t('controlnet.small'), value: 'small' }, + { label: t('controlnet.base'), value: 'base' }, + { label: t('controlnet.large'), value: 'large' }, + ], + [t] + ); + + const value = useMemo( + () => options.filter((o) => o.value === model_size)[0], + [options, model_size] + ); + + const handleResolutionChange = useCallback( + (v: number) => { + processorChanged(controlNetId, { resolution: v }); + }, + [controlNetId, processorChanged] + ); + + const handleResolutionDefaultChange = useCallback(() => { + processorChanged(controlNetId, { resolution: 512 }); + }, [controlNetId, processorChanged]); + + return ( + + + {t('controlnet.modelSize')} + + + + {t('controlnet.imageResolution')} + + + + + ); +}; + +export default memo(DepthAnythingProcessor); diff --git a/invokeai/frontend/web/src/features/controlAdapters/store/constants.ts b/invokeai/frontend/web/src/features/controlAdapters/store/constants.ts index b7647c9f0d..ddebffcfde 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/store/constants.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/store/constants.ts @@ -83,6 +83,22 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = { f: 256, }, }, + depth_anything_image_processor: { + type: 'depth_anything_image_processor', + get label() { + return i18n.t('controlnet.depthAnything'); + }, + get description() { + return i18n.t('controlnet.depthAnythingDescription'); + }, + default: { + id: 'depth_anything_image_processor', + type: 'depth_anything_image_processor', + model_size: 'small', + resolution: 512, + offload: false, + }, + }, hed_image_processor: { type: 'hed_image_processor', get label() { @@ -245,7 +261,7 @@ export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: { } = { canny: 'canny_image_processor', mlsd: 'mlsd_image_processor', - depth: 'midas_depth_image_processor', + depth: 'depth_anything_image_processor', bae: 'normalbae_image_processor', sketch: 'pidi_image_processor', scribble: 'lineart_image_processor', diff --git a/invokeai/frontend/web/src/features/controlAdapters/store/types.ts b/invokeai/frontend/web/src/features/controlAdapters/store/types.ts index 8d391a9c08..87366a443d 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/store/types.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/store/types.ts @@ -10,6 +10,7 @@ import type { CannyImageProcessorInvocation, ColorMapImageProcessorInvocation, ContentShuffleImageProcessorInvocation, + DepthAnythingImageProcessorInvocation, HedImageProcessorInvocation, LineartAnimeImageProcessorInvocation, LineartImageProcessorInvocation, @@ -31,6 +32,7 @@ export type ControlAdapterProcessorNode = | CannyImageProcessorInvocation | ColorMapImageProcessorInvocation | ContentShuffleImageProcessorInvocation + | DepthAnythingImageProcessorInvocation | HedImageProcessorInvocation | LineartAnimeImageProcessorInvocation | LineartImageProcessorInvocation @@ -73,6 
+75,20 @@ export type RequiredContentShuffleImageProcessorInvocation = O.Required< 'type' | 'detect_resolution' | 'image_resolution' | 'w' | 'h' | 'f' >; +/** + * The DepthAnything processor node, with parameters flagged as required + */ +export type RequiredDepthAnythingImageProcessorInvocation = O.Required< + DepthAnythingImageProcessorInvocation, + 'type' | 'model_size' | 'resolution' | 'offload' +>; + +export const zDepthAnythingModelSize = z.enum(['large', 'base', 'small']); +export type DepthAnythingModelSize = z.infer; +export const isDepthAnythingModelSize = ( + v: unknown +): v is DepthAnythingModelSize => zDepthAnythingModelSize.safeParse(v).success; + /** * The HED processor node, with parameters flagged as required */ @@ -161,6 +177,7 @@ export type RequiredControlAdapterProcessorNode = | RequiredCannyImageProcessorInvocation | RequiredColorMapImageProcessorInvocation | RequiredContentShuffleImageProcessorInvocation + | RequiredDepthAnythingImageProcessorInvocation | RequiredHedImageProcessorInvocation | RequiredLineartAnimeImageProcessorInvocation | RequiredLineartImageProcessorInvocation @@ -219,6 +236,22 @@ export const isContentShuffleImageProcessorInvocation = ( return false; }; +/** + * Type guard for DepthAnythingImageProcessorInvocation + */ +export const isDepthAnythingImageProcessorInvocation = ( + obj: unknown +): obj is DepthAnythingImageProcessorInvocation => { + if ( + isObject(obj) && + 'type' in obj && + obj.type === 'depth_anything_image_processor' + ) { + return true; + } + return false; +}; + /** * Type guard for HedImageprocessorInvocation */ diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index a7f60d6cdf..a4f358640d 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -2938,7 +2938,7 @@ export type components = { /** * Fp32 * @description Whether or not to use full float32 precision - * @default true + * @default false */ fp32?: boolean; /** @@ -3199,6 +3199,57 @@ export type components = { */ type: "denoise_mask_output"; }; + /** + * Depth Anything Processor + * @description Generates a depth map based on the Depth Anything algorithm + */ + DepthAnythingImageProcessorInvocation: { + /** @description Optional metadata to be saved with the image */ + metadata?: components["schemas"]["MetadataField"] | null; + /** + * Id + * @description The id of this instance of an invocation. Must be unique among all instances of invocations. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this is an intermediate invocation. 
+ * @default false + */ + is_intermediate?: boolean; + /** + * Use Cache + * @description Whether or not to use the cache + * @default true + */ + use_cache?: boolean; + /** @description The image to process */ + image?: components["schemas"]["ImageField"]; + /** + * Model Size + * @description The size of the depth model to use + * @default small + * @enum {string} + */ + model_size?: "large" | "base" | "small"; + /** + * Resolution + * @description Pixel resolution for output image + * @default 512 + */ + resolution?: number; + /** + * Offload + * @default false + */ + offload?: boolean; + /** + * type + * @default depth_anything_image_processor + * @constant + */ + type: "depth_anything_image_processor"; + }; /** * Divide Integers * @description Divides two numbers @@ -4073,7 +4124,7 @@ export type components = { * @description The nodes in this graph */ nodes?: { - [key: string]: components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ImageChannelInvocation"] | 
components["schemas"]["ImageBlurInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ESRGANInvocation"] | 
components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["LinearUIOutputInvocation"] | components["schemas"]["HedImageProcessorInvocation"]; + [key: string]: components["schemas"]["LatentConsistencyInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["LinearUIOutputInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["FaceOffInvocation"] | 
components["schemas"]["StringInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | 
components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["LeresImageProcessorInvocation"]; }; /** * Edges @@ -4110,7 +4161,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["ColorOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["String2Output"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["DenoiseMaskOutput"]; + [key: string]: components["schemas"]["ColorCollectionOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["UNetOutput"] | 
components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["String2Output"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ImageOutput"]; }; /** * Errors @@ -4500,6 +4551,77 @@ export type components = { */ type: "ip_adapter_output"; }; + /** + * Ideal Size + * @description Calculates the ideal size for generation to avoid duplication + */ + IdealSizeInvocation: { + /** + * Id + * @description The id of this instance of an invocation. Must be unique among all instances of invocations. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this is an intermediate invocation. + * @default false + */ + is_intermediate?: boolean; + /** + * Use Cache + * @description Whether or not to use the cache + * @default true + */ + use_cache?: boolean; + /** + * Width + * @description Final image width + * @default 1024 + */ + width?: number; + /** + * Height + * @description Final image height + * @default 576 + */ + height?: number; + /** @description UNet (scheduler, LoRAs) */ + unet?: components["schemas"]["UNetField"]; + /** + * Multiplier + * @description Amount to multiply the model's dimensions by when calculating the ideal size (may result in initial generation artifacts if too large) + * @default 1 + */ + multiplier?: number; + /** + * type + * @default ideal_size + * @constant + */ + type: "ideal_size"; + }; + /** + * IdealSizeOutput + * @description Base class for invocations that output an image + */ + IdealSizeOutput: { + /** + * Width + * @description The ideal width of the image (in pixels) + */ + width: number; + /** + * Height + * @description The ideal height of the image (in pixels) + */ + height: number; + /** + * type + * @default ideal_size_output + * @constant + */ + type: "ideal_size_output"; + }; /** * Blur Image * @description Blurs an image @@ -5403,7 +5525,7 @@ export type components = { /** * Fp32 * @description Whether or not to use full float32 precision - * @default true + * @default false */ fp32?: boolean; /** @@ -5911,6 +6033,96 @@ export type components = { */ type: "infill_lama"; }; + /** + * Latent Consistency MonoNode + * @description Wrapper node around diffusers LatentConsistencyTxt2ImgPipeline + */ + LatentConsistencyInvocation: { + /** + * Id + * @description The id of this instance of an invocation. Must be unique among all instances of invocations. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this is an intermediate invocation. 
+ * @default false + */ + is_intermediate?: boolean; + /** + * Use Cache + * @description Whether or not to use the cache + * @default true + */ + use_cache?: boolean; + /** + * Prompt + * @description The prompt to use + */ + prompt?: string; + /** + * Num Inference Steps + * @description The number of inference steps to use, 4-8 recommended + * @default 8 + */ + num_inference_steps?: number; + /** + * Guidance Scale + * @description The guidance scale to use + * @default 8 + */ + guidance_scale?: number; + /** + * Batches + * @description The number of batches to use + * @default 1 + */ + batches?: number; + /** + * Images Per Batch + * @description The number of images per batch to use + * @default 1 + */ + images_per_batch?: number; + /** + * Seeds + * @description List of noise seeds to use + */ + seeds?: number[]; + /** + * Lcm Origin Steps + * @description The lcm origin steps to use + * @default 50 + */ + lcm_origin_steps?: number; + /** + * Width + * @description The width to use + * @default 512 + */ + width?: number; + /** + * Height + * @description The height to use + * @default 512 + */ + height?: number; + /** + * Precision + * @description floating point precision + * @default fp16 + * @enum {string} + */ + precision?: "fp16" | "fp32"; + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"]; + /** + * type + * @default latent_consistency_mononode + * @constant + */ + type: "latent_consistency_mononode"; + }; /** * Latents Collection Primitive * @description A collection of latents tensor primitive values @@ -6070,7 +6282,7 @@ export type components = { /** * Fp32 * @description Whether or not to use full float32 precision - * @default true + * @default false */ fp32?: boolean; /** @@ -11290,42 +11502,42 @@ export type components = { * @enum {string} */ UIType: "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_MainModel" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict"; - /** - * StableDiffusion1ModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; /** * StableDiffusionXLModelFormat * @description An enumeration. 
* @enum {string} */ StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; - /** - * T2IAdapterModelFormat - * @description An enumeration. - * @enum {string} - */ - T2IAdapterModelFormat: "diffusers"; /** * CLIPVisionModelFormat * @description An enumeration. * @enum {string} */ CLIPVisionModelFormat: "diffusers"; - /** - * IPAdapterModelFormat - * @description An enumeration. - * @enum {string} - */ - IPAdapterModelFormat: "invokeai"; /** * ControlNetModelFormat * @description An enumeration. * @enum {string} */ ControlNetModelFormat: "checkpoint" | "diffusers"; + /** + * T2IAdapterModelFormat + * @description An enumeration. + * @enum {string} + */ + T2IAdapterModelFormat: "diffusers"; + /** + * StableDiffusion1ModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; + /** + * StableDiffusion2ModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; /** * StableDiffusionOnnxModelFormat * @description An enumeration. @@ -11333,11 +11545,11 @@ export type components = { */ StableDiffusionOnnxModelFormat: "olive" | "onnx"; /** - * StableDiffusion2ModelFormat + * IPAdapterModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; + IPAdapterModelFormat: "invokeai"; }; responses: never; parameters: never; diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 03d9153361..871a7f5a2e 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -164,6 +164,8 @@ export type ColorMapImageProcessorInvocation = s['ColorMapImageProcessorInvocation']; export type ContentShuffleImageProcessorInvocation = s['ContentShuffleImageProcessorInvocation']; +export type DepthAnythingImageProcessorInvocation = + s['DepthAnythingImageProcessorInvocation']; export type HedImageProcessorInvocation = s['HedImageProcessorInvocation']; export type LineartAnimeImageProcessorInvocation = s['LineartAnimeImageProcessorInvocation'];