wip: segment anything node

This commit is contained in:
blessedcoolant 2024-07-26 00:43:26 +05:30
parent e579be0118
commit b20c70c588
29 changed files with 264 additions and 482 deletions

View File

@ -0,0 +1,76 @@
from typing import Dict, cast
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.grounding_segment_anything.gsa import GroundingSegmentAnythingDetector
from invokeai.backend.util.devices import TorchDevice
GROUNDING_SEGMENT_ANYTHING_MODELS = {
"groundingdino_swint_ogc": "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
"segment_anything_vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
}
@invocation(
"segment_anything",
title="Segment Anything",
tags=["grounding_dino", "segment", "anything"],
category="image",
version="1.0.0",
)
class SegmentAnythingInvocation(BaseInvocation):
"""Automatically generate masks from an image using GroundingDINO & Segment Anything"""
image: ImageField = InputField(description="The image to process")
prompt: str = InputField(default="", description="Keywords to segment", title="Prompt")
box_threshold: float = InputField(
default=0.5, ge=0, le=1, description="Confidence threshold for box detection", title="Box Threshold"
)
text_threshold: float = InputField(
default=0.5, ge=0, le=1, description="Confidence threshold for text matching", title="Text Threshold"
)
nms_threshold: float = InputField(
default=0.8, ge=0, le=1, description="IoU threshold for non-maximum suppression of overlapping boxes", title="NMS Threshold"
)
def invoke(self, context: InvocationContext) -> ImageOutput:
input_image = context.images.get_pil(self.image.image_name)
grounding_dino_model = context.models.load_remote_model(
GROUNDING_SEGMENT_ANYTHING_MODELS["groundingdino_swint_ogc"]
)
segment_anything_model = context.models.load_remote_model(
GROUNDING_SEGMENT_ANYTHING_MODELS["segment_anything_vit_h"]
)
with (
grounding_dino_model.model_on_device() as (_, grounding_dino_state_dict),
segment_anything_model.model_on_device() as (_, segment_anything_state_dict),
):
if not grounding_dino_state_dict or not segment_anything_state_dict:
raise RuntimeError("Unable to load segmentation models")
grounding_dino = GroundingSegmentAnythingDetector.build_grounding_dino(
cast(Dict[str, torch.Tensor], grounding_dino_state_dict)
)
segment_anything = GroundingSegmentAnythingDetector.build_segment_anything(
cast(Dict[str, torch.Tensor], segment_anything_state_dict), TorchDevice.choose_torch_device()
)
detector = GroundingSegmentAnythingDetector(grounding_dino, segment_anything)
mask = detector.predict(
input_image, self.prompt, self.box_threshold, self.text_threshold, self.nms_threshold
)
image_dto = context.images.save(mask)
"""Builds an ImageOutput and its ImageField"""
processed_image_field = ImageField(image_name=image_dto.image_name)
return ImageOutput(
image=processed_image_field,
width=input_image.width,
height=input_image.height,
)
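
For orientation, a minimal standalone sketch of the flow invoke() wraps, with a placeholder image path and prompt, assuming the two checkpoints listed in GROUNDING_SEGMENT_ANYTHING_MODELS were downloaded by hand to ./checkpoints/ (inside the node, context.models.load_remote_model and model_on_device() handle this instead):

import torch
from PIL import Image
from invokeai.backend.image_util.grounding_segment_anything.gsa import GroundingSegmentAnythingDetector
from invokeai.backend.util.devices import TorchDevice
# Assumed local checkpoint paths; the node downloads the URLs above via the model manager instead.
grounding_dino_state_dict = torch.load("./checkpoints/groundingdino_swint_ogc.pth", map_location="cpu")
segment_anything_state_dict = torch.load("./checkpoints/sam_vit_h_4b8939.pth", map_location="cpu")
grounding_dino = GroundingSegmentAnythingDetector.build_grounding_dino(grounding_dino_state_dict)
segment_anything = GroundingSegmentAnythingDetector.build_segment_anything(
    segment_anything_state_dict, TorchDevice.choose_torch_device()
)
detector = GroundingSegmentAnythingDetector(grounding_dino, segment_anything)
# predict() takes a PIL image and a comma-separated prompt and returns the combined mask as a PIL image.
mask = detector.predict(Image.open("image.png"), "cat,dog", box_threshold=0.5, text_threshold=0.5, nms_threshold=0.8)
mask.save("mask.png")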

View File

@ -10,8 +10,8 @@ import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from groundingdino.util.box_ops import box_xyxy_to_cxcywh
from groundingdino.util.misc import interpolate
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.box_ops import box_xyxy_to_cxcywh
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import interpolate
def crop(image, target, region):
@ -58,9 +58,7 @@ def crop(image, target, region):
if os.environ.get("IPDB_SHILONG_DEBUG", None) == "INFO":
# for debug and visualization only.
if "strings_positive" in target:
target["strings_positive"] = [
_i for _i, _j in zip(target["strings_positive"], keep) if _j
]
target["strings_positive"] = [_i for _i, _j in zip(target["strings_positive"], keep, strict=False) if _j]
return cropped_image, target
@ -73,9 +71,7 @@ def hflip(image, target):
target = target.copy()
if "boxes" in target:
boxes = target["boxes"]
boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor(
[w, 0, w, 0]
)
boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
target["boxes"] = boxes
if "masks" in target:
@ -119,15 +115,13 @@ def resize(image, target, size, max_size=None):
if target is None:
return rescaled_image, None
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size, strict=False))
ratio_width, ratio_height = ratios
target = target.copy()
if "boxes" in target:
boxes = target["boxes"]
scaled_boxes = boxes * torch.as_tensor(
[ratio_width, ratio_height, ratio_width, ratio_height]
)
scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
target["boxes"] = scaled_boxes
if "area" in target:
@ -139,9 +133,7 @@ def resize(image, target, size, max_size=None):
target["size"] = torch.tensor([h, w])
if "masks" in target:
target["masks"] = (
interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5
)
target["masks"] = interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5
return rescaled_image, target
@ -192,11 +184,7 @@ class RandomSizeCrop(object):
h = random.randint(self.min_size, min(img.height, self.max_size))
region = T.RandomCrop.get_params(img, [h, w])
result_img, result_target = crop(img, target, region)
if (
not self.respect_boxes
or len(result_target["boxes"]) == init_boxes
or i == max_patience - 1
):
if not self.respect_boxes or len(result_target["boxes"]) == init_boxes or i == max_patience - 1:
return result_img, result_target
return result_img, result_target

View File

@ -12,4 +12,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
from .groundingdino import build_groundingdino
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models.GroundingDINO.groundingdino import (
build_groundingdino,
)

View File

@ -24,10 +24,13 @@ import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from groundingdino.util.misc import NestedTensor, clean_state_dict, is_main_process
from .position_encoding import build_position_encoding
from .swin_transformer import build_swin_transformer
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models.GroundingDINO.backbone.position_encoding import (
build_position_encoding,
)
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models.GroundingDINO.backbone.swin_transformer import (
build_swin_transformer,
)
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import NestedTensor, is_main_process
class FrozenBatchNorm2d(torch.nn.Module):
@ -80,19 +83,12 @@ class BackboneBase(nn.Module):
):
super().__init__()
for name, parameter in backbone.named_parameters():
if (
not train_backbone
or "layer2" not in name
and "layer3" not in name
and "layer4" not in name
):
if not train_backbone or "layer2" not in name and "layer3" not in name and "layer4" not in name:
parameter.requires_grad_(False)
return_layers = {}
for idx, layer_index in enumerate(return_interm_indices):
return_layers.update(
{"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)}
)
return_layers.update({"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)})
# if len:
# if use_stage1_feature:
@ -214,8 +210,8 @@ def build_backbone(args):
model = Joiner(backbone, position_embedding)
model.num_channels = bb_num_channels
assert isinstance(
bb_num_channels, List
), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels))
assert isinstance(bb_num_channels, List), "bb_num_channels is expected to be a List but {}".format(
type(bb_num_channels)
)
# import ipdb; ipdb.set_trace()
return model

View File

@ -24,7 +24,7 @@ import math
import torch
from torch import nn
from groundingdino.util.misc import NestedTensor
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import NestedTensor
class PositionEmbeddingSine(nn.Module):
@ -65,12 +65,8 @@ class PositionEmbeddingSine(nn.Module):
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
@ -81,9 +77,7 @@ class PositionEmbeddingSineHW(nn.Module):
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(
self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None
):
def __init__(self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperatureH = temperatureH
@ -111,19 +105,15 @@ class PositionEmbeddingSineHW(nn.Module):
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats)
dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode="floor")) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_tx
dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats)
dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode="floor")) / self.num_pos_feats)
pos_y = y_embed[:, :, :, None] / dim_ty
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
# import ipdb; ipdb.set_trace()

View File

@ -18,15 +18,13 @@ import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from groundingdino.util.misc import NestedTensor
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import NestedTensor
class Mlp(nn.Module):
"""Multilayer perceptron."""
def __init__(
self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
@ -138,24 +136,16 @@ class WindowAttention(nn.Module):
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B_, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
@ -228,9 +218,7 @@ class SwinTransformerBlock(nn.Module):
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.H = None
self.W = None
@ -266,12 +254,8 @@ class SwinTransformerBlock(nn.Module):
attn_mask = None
# partition windows
x_windows = window_partition(
shifted_x, self.window_size
) # nW*B, window_size, window_size, C
x_windows = x_windows.view(
-1, self.window_size * self.window_size, C
) # nW*B, window_size*window_size, C
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
@ -433,14 +417,10 @@ class BasicLayer(nn.Module):
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(
img_mask, self.window_size
) # nW, window_size, window_size, 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
attn_mask == 0, float(0.0)
)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
for blk in self.blocks:
blk.H, blk.W = H, W
@ -589,9 +569,7 @@ class SwinTransformer(nn.Module):
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
@ -682,9 +660,7 @@ class SwinTransformer(nn.Module):
Wh, Ww = x.size(2), x.size(3)
if self.ape:
# interpolate the position embedding to the corresponding size
absolute_pos_embed = F.interpolate(
self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
)
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic")
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
else:
x = x.flatten(2).transpose(1, 2)
@ -718,9 +694,7 @@ class SwinTransformer(nn.Module):
Wh, Ww = x.size(2), x.size(3)
if self.ape:
# interpolate the position embedding to the corresponding size
absolute_pos_embed = F.interpolate(
self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
)
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic")
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
else:
x = x.flatten(2).transpose(1, 2)
@ -769,21 +743,11 @@ def build_swin_transformer(modelname, pretrain_img_size, **kw):
]
model_para_dict = {
"swin_T_224_1k": dict(
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7
),
"swin_B_224_22k": dict(
embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7
),
"swin_B_384_22k": dict(
embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12
),
"swin_L_224_22k": dict(
embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7
),
"swin_L_384_22k": dict(
embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12
),
"swin_T_224_1k": dict(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7),
"swin_B_224_22k": dict(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7),
"swin_B_384_22k": dict(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12),
"swin_L_224_22k": dict(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7),
"swin_L_384_22k": dict(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12),
}
kw_cgf = model_para_dict[modelname]
kw_cgf.update(kw)

View File

@ -21,8 +21,12 @@ import torch
import torch.nn.functional as F
from torch import nn
from groundingdino.util import get_tokenlizer
from groundingdino.util.misc import NestedTensor, inverse_sigmoid, nested_tensor_from_tensor_list
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util import get_tokenlizer
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import (
NestedTensor,
inverse_sigmoid,
nested_tensor_from_tensor_list,
)
from ..registry import MODULE_BUILD_FUNCS
from .backbone import build_backbone

View File

@ -22,14 +22,19 @@ import torch
import torch.utils.checkpoint as checkpoint
from torch import Tensor, nn
from groundingdino.util.misc import inverse_sigmoid
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import inverse_sigmoid
from .fuse_modules import BiAttentionBlock
from .ms_deform_attn import MultiScaleDeformableAttention as MSDeformAttn
from .transformer_vanilla import TransformerEncoderLayer
from .utils import (MLP, _get_activation_fn, _get_clones,
gen_encoder_output_proposals, gen_sineembed_for_position,
get_sine_pos_embed)
from .utils import (
MLP,
_get_activation_fn,
_get_clones,
gen_encoder_output_proposals,
gen_sineembed_for_position,
get_sine_pos_embed,
)
class Transformer(nn.Module):

View File

@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import Dict, List, Tuple
import cv2
import numpy as np
@ -7,11 +7,11 @@ import torch
from PIL import Image
from torchvision.ops import box_convert
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util.misc import clean_state_dict
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import get_phrases_from_posmap
import invokeai.backend.image_util.grounding_segment_anything.groundingdino.datasets.transforms as T
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models import build_model
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import clean_state_dict
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.slconfig import SLConfig
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.utils import get_phrases_from_posmap
# ----------------------------------------------------------------------------------------------------------------------
# OLD API
@ -25,12 +25,11 @@ def preprocess_caption(caption: str) -> str:
return result + "."
def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
def load_model(model_config_path: str, model_state_dict: Dict[str, torch.Tensor], device: str = "cuda"):
args = SLConfig.fromfile(model_config_path)
args.device = device
model = build_model(args)
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
model.load_state_dict(clean_state_dict(model_state_dict["model"]), strict=False)
model.eval()
return model
@ -98,9 +97,9 @@ def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor
class Model:
def __init__(self, model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
def __init__(self, model_config_path: str, model_state_dict: Dict[str, torch.Tensor], device: str = "cuda"):
self.model = load_model(
model_config_path=model_config_path, model_checkpoint_path=model_checkpoint_path, device=device
model_config_path=model_config_path, model_state_dict=model_state_dict, device=device
).to(device)
self.device = device

View File

@ -1,91 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import functools
import logging
import os
import sys
from termcolor import colored
class _ColorfulFormatter(logging.Formatter):
def __init__(self, *args, **kwargs):
self._root_name = kwargs.pop("root_name") + "."
self._abbrev_name = kwargs.pop("abbrev_name", "")
if len(self._abbrev_name):
self._abbrev_name = self._abbrev_name + "."
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
def formatMessage(self, record):
record.name = record.name.replace(self._root_name, self._abbrev_name)
log = super(_ColorfulFormatter, self).formatMessage(record)
if record.levelno == logging.WARNING:
prefix = colored("WARNING", "red", attrs=["blink"])
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
else:
return log
return prefix + " " + log
# so that calling setup_logger multiple times won't add many handlers
@functools.lru_cache()
def setup_logger(output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None):
"""
Initialize the detectron2 logger and set its verbosity level to "INFO".
Args:
output (str): a file name or a directory to save log. If None, will not save log file.
If ends with ".txt" or ".log", assumed to be a file name.
Otherwise, logs will be saved to `output/log.txt`.
name (str): the root module name of this logger
Returns:
logging.Logger: a logger
"""
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
logger.propagate = False
if abbrev_name is None:
abbrev_name = name
plain_formatter = logging.Formatter("[%(asctime)s.%(msecs)03d]: %(message)s", datefmt="%m/%d %H:%M:%S")
# stdout logging: master only
if distributed_rank == 0:
ch = logging.StreamHandler(stream=sys.stdout)
ch.setLevel(logging.DEBUG)
if color:
formatter = _ColorfulFormatter(
colored("[%(asctime)s.%(msecs)03d]: ", "green") + "%(message)s",
datefmt="%m/%d %H:%M:%S",
root_name=name,
abbrev_name=str(abbrev_name),
)
else:
formatter = plain_formatter
ch.setFormatter(formatter)
logger.addHandler(ch)
# file logging: all workers
if output is not None:
if output.endswith(".txt") or output.endswith(".log"):
filename = output
else:
filename = os.path.join(output, "log.txt")
if distributed_rank > 0:
filename = filename + f".rank{distributed_rank}"
os.makedirs(os.path.dirname(filename), exist_ok=True)
fh = logging.StreamHandler(_cached_log_stream(filename))
fh.setLevel(logging.DEBUG)
fh.setFormatter(plain_formatter)
logger.addHandler(fh)
return logger
# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
return open(filename, "a")

View File

@ -161,7 +161,7 @@ def all_gather_cpu(data):
dist.all_gather(tensor_list, tensor, group=cpu_group)
data_list = []
for size, tensor in zip(size_list, tensor_list):
for size, tensor in zip(size_list, tensor_list, strict=False):
tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
buffer = io.BytesIO(tensor.cpu().numpy())
obj = torch.load(buffer)
@ -210,7 +210,7 @@ def all_gather(data):
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list):
for size, tensor in zip(size_list, tensor_list, strict=False):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
@ -240,7 +240,7 @@ def reduce_dict(input_dict, average=True):
dist.all_reduce(values)
if average:
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values)}
reduced_dict = {k: v for k, v in zip(names, values, strict=False)}
return reduced_dict
@ -378,7 +378,7 @@ def get_sha():
def collate_fn(batch):
# import ipdb; ipdb.set_trace()
batch = list(zip(*batch))
batch = list(zip(*batch, strict=False))
batch[0] = nested_tensor_from_tensor_list(batch[0])
return tuple(batch)
@ -480,7 +480,7 @@ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
device = tensor_list[0].device
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
for img, pad_img, m in zip(tensor_list, tensor, mask):
for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
m[: img.shape[1], : img.shape[2]] = False
else:
@ -505,7 +505,7 @@ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTen
padded_imgs = []
padded_masks = []
for img in tensor_list:
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape), strict=False)]
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
padded_imgs.append(padded_img)

View File

@ -9,7 +9,7 @@ import numpy as np
import torch
from transformers import AutoTokenizer
from groundingdino.util.slconfig import SLConfig
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.slconfig import SLConfig
def slprint(x, name="x"):

View File

@ -1,44 +1,38 @@
from typing import Dict, List, Literal, Optional
import pathlib
from typing import Dict, List, Optional
import cv2
import numpy as np
import supervision as sv
import torch
import torchvision
from PIL import Image
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.inference import Model
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.build_sam import sam_model_registry
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.predictor import SamPredictor
GROUNDING_SEGMENT_ANYTHING_MODELS = {
"groundingdino_swint_ogc": "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
"segment_anything_vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
}
class GroundingSegmentAnythingDetector:
def __init__(self) -> None:
self.grounding_dino_model: Optional[Model] = None
self.segment_anything_model: Optional[SamPredictor] = None
self.grounding_dino_config: str = "./groundingdino/config/GroundingDINO_SwinT_OGC.py"
self.sam_encoder: Literal["vit_h"] = "vit_h"
self.device: torch.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def __init__(self, grounding_dino_model: Model, segment_anything_model: SamPredictor) -> None:
self.grounding_dino_model: Optional[Model] = grounding_dino_model
self.segment_anything_model: Optional[SamPredictor] = segment_anything_model
def build_grounding_dino(self):
@staticmethod
def build_grounding_dino(grounding_dino_state_dict: Dict[str, torch.Tensor]):
grounding_dino_config = pathlib.Path(
"./invokeai/backend/image_util/grounding_segment_anything/groundingdino/config/GroundingDINO_SwinT_OGC.py"
)
return Model(
model_config_path=self.grounding_dino_config,
model_checkpoint_path="./checkpoints/groundingdino_swint_ogc.pth",
model_state_dict=grounding_dino_state_dict,
model_config_path=grounding_dino_config.as_posix(),
)
def build_segment_anything(self):
sam = sam_model_registry[self.sam_encoder](checkpoint="./checkpoints/sam_vit_h_4b8939.pth")
sam.to(device=self.device)
@staticmethod
def build_segment_anything(segment_anything_state_dict: Dict[str, torch.Tensor], device: torch.device):
sam = sam_model_registry["vit_h"](checkpoint=segment_anything_state_dict)
sam.to(device=device)
return SamPredictor(sam)
def build_grounding_sam(self):
self.grounding_dino_model = self.build_grounding_dino()
self.segment_anything_model = self.build_segment_anything()
def detect_objects(
self,
image: np.ndarray,
@ -77,20 +71,18 @@ class GroundingSegmentAnythingDetector:
def predict(
self,
image: str,
image: Image.Image,
prompt: str,
box_threshold: float = 0.5,
text_threshold: float = 0.5,
nms_threshold: float = 0.8,
):
if not self.grounding_dino_model or not self.segment_anything_model:
self.build_grounding_sam()
image = cv2.imread(image)
open_cv_image = np.array(image)
open_cv_image = open_cv_image[:, :, ::-1].copy()
prompts = prompt.split(",")
detections = self.detect_objects(image, prompts, box_threshold, text_threshold, nms_threshold)
segments = self.segment_detections(image, detections, prompts)
detections = self.detect_objects(open_cv_image, prompts, box_threshold, text_threshold, nms_threshold)
segments = self.segment_detections(open_cv_image, detections, prompts)
if len(segments) > 0:
combined_mask = np.zeros_like(list(segments.values())[0])
@ -98,15 +90,6 @@ class GroundingSegmentAnythingDetector:
combined_mask = np.logical_or(combined_mask, mask)
mask_preview = (combined_mask * 255).astype(np.uint8)
else:
mask_preview = np.zeros(image.shape, np.uint8)
mask_preview = np.zeros(open_cv_image.shape, np.uint8)
cv2.imwrite("mask.png", mask_preview)
if __name__ == "__main__":
gsa = GroundingSegmentAnythingDetector()
image = "./assets/image.webp"
while True:
prompt = input("Segment: ")
gsa.predict(image, prompt, 0.5, 0.5, 0.8)
return Image.fromarray(mask_preview)
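
The interactive demo removed from the end of this file can be reproduced against the refactored API; a rough sketch, assuming a detector assembled from state dicts as in the Segment Anything node sketch above:

from PIL import Image
# "detector" is assumed to be a GroundingSegmentAnythingDetector built as in the node; predict()
# now takes a PIL image and returns the combined mask as a PIL image instead of writing it to disk.
image = Image.open("./assets/image.webp")
while True:
    prompt = input("Segment: ")
    mask = detector.predict(image, prompt, box_threshold=0.5, text_threshold=0.5, nms_threshold=0.8)
    mask.save("mask.png")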

View File

@ -4,13 +4,22 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from .automatic_mask_generator import SamAutomaticMaskGenerator
from .build_sam import build_sam, build_sam_vit_b, build_sam_vit_h, build_sam_vit_l, sam_model_registry
from .build_sam_hq import (
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.automatic_mask_generator import (
SamAutomaticMaskGenerator,
) # noqa
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.build_sam import ( # noqa
build_sam,
build_sam_vit_b,
build_sam_vit_h,
build_sam_vit_l,
sam_model_registry,
)
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.build_sam_hq import ( # noqa
build_sam_hq,
build_sam_hq_vit_b,
build_sam_hq_vit_h,
build_sam_hq_vit_l,
sam_hq_model_registry,
)
from .predictor import SamPredictor
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.predictor import SamPredictor # noqa

View File

@ -10,9 +10,9 @@ import numpy as np
import torch
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
from .modeling import Sam
from .predictor import SamPredictor
from .utils.amg import (
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling import Sam
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.predictor import SamPredictor
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.utils.amg import (
MaskData,
area_from_rle,
batch_iterator,

View File

@ -8,7 +8,13 @@ from functools import partial
import torch
from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling import (
ImageEncoderViT,
MaskDecoder,
PromptEncoder,
Sam,
TwoWayTransformer,
)
def build_sam_vit_h(checkpoint=None):
@ -101,7 +107,5 @@ def _build_sam(
)
sam.eval()
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)
sam.load_state_dict(state_dict)
sam.load_state_dict(checkpoint)
return sam

View File

@ -8,7 +8,13 @@ from functools import partial
import torch
from .modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling import (
ImageEncoderViT,
MaskDecoderHQ,
PromptEncoder,
Sam,
TwoWayTransformer,
)
def build_sam_hq_vit_h(checkpoint=None):

View File

@ -4,9 +4,17 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .mask_decoder_hq import MaskDecoderHQ
from .prompt_encoder import PromptEncoder
from .sam import Sam
from .transformer import TwoWayTransformer
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.image_encoder import (
ImageEncoderViT,
)
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.mask_decoder import MaskDecoder
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.mask_decoder_hq import (
MaskDecoderHQ,
)
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.prompt_encoder import (
PromptEncoder,
)
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.sam import Sam
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.transformer import (
TwoWayTransformer,
)

View File

@ -10,7 +10,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .common import LayerNorm2d, MLPBlock
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import (
LayerNorm2d,
MLPBlock,
)
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa

View File

@ -10,7 +10,7 @@ import torch
from torch import nn
from torch.nn import functional as F
from .common import LayerNorm2d
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import LayerNorm2d
class MaskDecoder(nn.Module):

View File

@ -11,7 +11,7 @@ import torch
from torch import nn
from torch.nn import functional as F
from .common import LayerNorm2d
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import LayerNorm2d
class MaskDecoderHQ(nn.Module):

View File

@ -10,7 +10,7 @@ import numpy as np
import torch
from torch import nn
from .common import LayerNorm2d
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import LayerNorm2d
class PromptEncoder(nn.Module):

View File

@ -10,9 +10,13 @@ import torch
from torch import nn
from torch.nn import functional as F
from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.image_encoder import (
ImageEncoderViT,
)
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.mask_decoder import MaskDecoder
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.prompt_encoder import (
PromptEncoder,
)
class Sam(nn.Module):
@ -98,7 +102,7 @@ class Sam(nn.Module):
image_embeddings = self.image_encoder(input_images)
outputs = []
for image_record, curr_embedding in zip(batched_input, image_embeddings):
for image_record, curr_embedding in zip(batched_input, image_embeddings, strict=False):
if "point_coords" in image_record:
points = (image_record["point_coords"], image_record["point_labels"])
else:

View File

@ -10,7 +10,7 @@ from typing import Tuple, Type
import torch
from torch import Tensor, nn
from .common import MLPBlock
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import MLPBlock
class TwoWayTransformer(nn.Module):

View File

@ -9,8 +9,8 @@ from typing import Optional, Tuple
import numpy as np
import torch
from .modeling import Sam
from .utils.transforms import ResizeLongestSide
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling import Sam
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.utils.transforms import ResizeLongestSide
class SamPredictor:

View File

@ -4,14 +4,14 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
import math
from copy import deepcopy
from itertools import product
from typing import Any, Dict, Generator, ItemsView, List, Tuple
import numpy as np
import torch
class MaskData:
"""
@ -153,9 +153,7 @@ def area_from_rle(rle: Dict[str, Any]) -> int:
return sum(rle["counts"][1::2])
def calculate_stability_score(
masks: torch.Tensor, mask_threshold: float, threshold_offset: float
) -> torch.Tensor:
def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
"""
Computes the stability score for a batch of masks. The stability
score is the IoU between the binary masks obtained by thresholding
@ -163,16 +161,8 @@ def calculate_stability_score(
"""
# One mask is always contained inside the other.
# Save memory by preventing unnecessary cast to torch.int64
intersections = (
(masks > (mask_threshold + threshold_offset))
.sum(-1, dtype=torch.int16)
.sum(-1, dtype=torch.int32)
)
unions = (
(masks > (mask_threshold - threshold_offset))
.sum(-1, dtype=torch.int16)
.sum(-1, dtype=torch.int32)
)
intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
return intersections / unions
@ -186,9 +176,7 @@ def build_point_grid(n_per_side: int) -> np.ndarray:
return points
def build_all_layer_point_grids(
n_per_side: int, n_layers: int, scale_per_layer: int
) -> List[np.ndarray]:
def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
"""Generates point grids for all crop layers."""
points_by_layer = []
for i in range(n_layers + 1):
@ -252,9 +240,7 @@ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
return points + offset
def uncrop_masks(
masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
) -> torch.Tensor:
def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:
x0, y0, x1, y1 = crop_box
if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
return masks
@ -264,9 +250,7 @@ def uncrop_masks(
return torch.nn.functional.pad(masks, pad, value=0)
def remove_small_regions(
mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
"""
Removes small disconnected regions and holes in a mask. Returns the
mask and an indicator of whether the mask has been modified.

View File

@ -1,144 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Tuple
from ..modeling import Sam
from .amg import calculate_stability_score
class SamOnnxModel(nn.Module):
"""
This model should not be called directly, but is used in ONNX export.
It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
with some functions modified to enable model tracing. Also supports extra
options controlling what information is returned. See the ONNX export script for details.
"""
def __init__(
self,
model: Sam,
return_single_mask: bool,
use_stability_score: bool = False,
return_extra_metrics: bool = False,
) -> None:
super().__init__()
self.mask_decoder = model.mask_decoder
self.model = model
self.img_size = model.image_encoder.img_size
self.return_single_mask = return_single_mask
self.use_stability_score = use_stability_score
self.stability_score_offset = 1.0
self.return_extra_metrics = return_extra_metrics
@staticmethod
def resize_longest_image_size(
input_image_size: torch.Tensor, longest_side: int
) -> torch.Tensor:
input_image_size = input_image_size.to(torch.float32)
scale = longest_side / torch.max(input_image_size)
transformed_size = scale * input_image_size
transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
return transformed_size
def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
point_coords = point_coords + 0.5
point_coords = point_coords / self.img_size
point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
point_embedding = point_embedding * (point_labels != -1)
point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
point_labels == -1
)
for i in range(self.model.prompt_encoder.num_point_embeddings):
point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
i
].weight * (point_labels == i)
return point_embedding
def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
mask_embedding = mask_embedding + (
1 - has_mask_input
) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
return mask_embedding
def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
masks = F.interpolate(
masks,
size=(self.img_size, self.img_size),
mode="bilinear",
align_corners=False,
)
prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size)
masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])]
orig_im_size = orig_im_size.to(torch.int64)
h, w = orig_im_size[0], orig_im_size[1]
masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
return masks
def select_masks(
self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
) -> Tuple[torch.Tensor, torch.Tensor]:
# Determine if we should return the multiclick mask or not from the number of points.
# The reweighting is used to avoid control flow.
score_reweight = torch.tensor(
[[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
).to(iou_preds.device)
score = iou_preds + (num_points - 2.5) * score_reweight
best_idx = torch.argmax(score, dim=1)
masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
return masks, iou_preds
@torch.no_grad()
def forward(
self,
image_embeddings: torch.Tensor,
point_coords: torch.Tensor,
point_labels: torch.Tensor,
mask_input: torch.Tensor,
has_mask_input: torch.Tensor,
orig_im_size: torch.Tensor,
):
sparse_embedding = self._embed_points(point_coords, point_labels)
dense_embedding = self._embed_masks(mask_input, has_mask_input)
masks, scores = self.model.mask_decoder.predict_masks(
image_embeddings=image_embeddings,
image_pe=self.model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embedding,
dense_prompt_embeddings=dense_embedding,
)
if self.use_stability_score:
scores = calculate_stability_score(
masks, self.model.mask_threshold, self.stability_score_offset
)
if self.return_single_mask:
masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
if self.return_extra_metrics:
stability_scores = calculate_stability_score(
upscaled_masks, self.model.mask_threshold, self.stability_score_offset
)
areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
return upscaled_masks, scores, stability_scores, areas, masks
return upscaled_masks, scores, masks

View File

@ -4,14 +4,14 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from copy import deepcopy
from typing import Tuple
import numpy as np
import torch
from torch.nn import functional as F
from torchvision.transforms.functional import resize, to_pil_image # type: ignore
from copy import deepcopy
from typing import Tuple
class ResizeLongestSide:
"""
@ -36,9 +36,7 @@ class ResizeLongestSide:
original image size in (H, W) format.
"""
old_h, old_w = original_size
new_h, new_w = self.get_preprocess_shape(
original_size[0], original_size[1], self.target_length
)
new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
coords = deepcopy(coords).astype(float)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
@ -60,29 +58,21 @@ class ResizeLongestSide:
"""
# Expects an image in BCHW format. May not exactly match apply_image.
target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
return F.interpolate(
image, target_size, mode="bilinear", align_corners=False, antialias=True
)
return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
def apply_coords_torch(
self, coords: torch.Tensor, original_size: Tuple[int, ...]
) -> torch.Tensor:
def apply_coords_torch(self, coords: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
"""
Expects a torch tensor with length 2 in the last dimension. Requires the
original image size in (H, W) format.
"""
old_h, old_w = original_size
new_h, new_w = self.get_preprocess_shape(
original_size[0], original_size[1], self.target_length
)
new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
coords = deepcopy(coords).to(torch.float)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords
def apply_boxes_torch(
self, boxes: torch.Tensor, original_size: Tuple[int, ...]
) -> torch.Tensor:
def apply_boxes_torch(self, boxes: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
"""
Expects a torch tensor with shape Bx4. Requires the original image
size in (H, W) format.

View File

@ -80,6 +80,7 @@ dependencies = [
"picklescan",
"pillow",
"prompt-toolkit",
"pycocotools",
"pympler~=1.0.1",
"pypatchmatch",
'pyperclip',
@ -90,6 +91,7 @@ dependencies = [
"scikit-image~=0.21.0",
"semver~=3.0.1",
"send2trash",
"supervision",
"test-tube~=0.7.5",
"windows-curses; sys_platform=='win32'",
]