wip: segment anything node

This commit is contained in:
blessedcoolant 2024-07-26 00:43:26 +05:30
parent e579be0118
commit b20c70c588
29 changed files with 264 additions and 482 deletions

View File

@ -0,0 +1,76 @@
from typing import Dict, cast
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.grounding_segment_anything.gsa import GroundingSegmentAnythingDetector
from invokeai.backend.util.devices import TorchDevice
GROUNDING_SEGMENT_ANYTHING_MODELS = {
"groundingdino_swint_ogc": "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
"segment_anything_vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
}
@invocation(
"segment_anything",
title="Segment Anything",
tags=["grounding_dino", "segment", "anything"],
category="image",
version="1.0.0",
)
class SegmentAnythingInvocation(BaseInvocation):
"""Automatically generate masks from an image using GroundingDINO & Segment Anything"""
image: ImageField = InputField(description="The image to process")
prompt: str = InputField(default="", description="Keywords to segment", title="Prompt")
box_threshold: float = InputField(
default=0.5, ge=0, le=1, description="Threshold of box detection", title="Box Threshold"
)
text_threshold: float = InputField(
default=0.5, ge=0, le=1, description="Threshold of text detection", title="Text Threshold"
)
nms_threshold: float = InputField(
default=0.8, ge=0, le=1, description="Threshold of nms detection", title="NMS Threshold"
)
def invoke(self, context: InvocationContext) -> ImageOutput:
input_image = context.images.get_pil(self.image.image_name)
grounding_dino_model = context.models.load_remote_model(
GROUNDING_SEGMENT_ANYTHING_MODELS["groundingdino_swint_ogc"]
)
segment_anything_model = context.models.load_remote_model(
GROUNDING_SEGMENT_ANYTHING_MODELS["segment_anything_vit_h"]
)
with (
grounding_dino_model.model_on_device() as (_, grounding_dino_state_dict),
segment_anything_model.model_on_device() as (_, segment_anything_state_dict),
):
if not grounding_dino_state_dict or not segment_anything_state_dict:
raise RuntimeError("Unable to load segmentation models")
grounding_dino = GroundingSegmentAnythingDetector.build_grounding_dino(
cast(Dict[str, torch.Tensor], grounding_dino_state_dict)
)
segment_anything = GroundingSegmentAnythingDetector.build_segment_anything(
cast(Dict[str, torch.Tensor], segment_anything_state_dict), TorchDevice.choose_torch_device()
)
detector = GroundingSegmentAnythingDetector(grounding_dino, segment_anything)
mask = detector.predict(
input_image, self.prompt, self.box_threshold, self.text_threshold, self.nms_threshold
)
image_dto = context.images.save(mask)
"""Builds an ImageOutput and its ImageField"""
processed_image_field = ImageField(image_name=image_dto.image_name)
return ImageOutput(
image=processed_image_field,
width=input_image.width,
height=input_image.height,
)

View File

@ -10,8 +10,8 @@ import torch
import torchvision.transforms as T import torchvision.transforms as T
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
from groundingdino.util.box_ops import box_xyxy_to_cxcywh from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.box_ops import box_xyxy_to_cxcywh
from groundingdino.util.misc import interpolate from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import interpolate
def crop(image, target, region): def crop(image, target, region):
@ -58,9 +58,7 @@ def crop(image, target, region):
if os.environ.get("IPDB_SHILONG_DEBUG", None) == "INFO": if os.environ.get("IPDB_SHILONG_DEBUG", None) == "INFO":
# for debug and visualization only. # for debug and visualization only.
if "strings_positive" in target: if "strings_positive" in target:
target["strings_positive"] = [ target["strings_positive"] = [_i for _i, _j in zip(target["strings_positive"], keep, strict=False) if _j]
_i for _i, _j in zip(target["strings_positive"], keep) if _j
]
return cropped_image, target return cropped_image, target
@ -73,9 +71,7 @@ def hflip(image, target):
target = target.copy() target = target.copy()
if "boxes" in target: if "boxes" in target:
boxes = target["boxes"] boxes = target["boxes"]
boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor( boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
[w, 0, w, 0]
)
target["boxes"] = boxes target["boxes"] = boxes
if "masks" in target: if "masks" in target:
@ -119,15 +115,13 @@ def resize(image, target, size, max_size=None):
if target is None: if target is None:
return rescaled_image, None return rescaled_image, None
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size, strict=False))
ratio_width, ratio_height = ratios ratio_width, ratio_height = ratios
target = target.copy() target = target.copy()
if "boxes" in target: if "boxes" in target:
boxes = target["boxes"] boxes = target["boxes"]
scaled_boxes = boxes * torch.as_tensor( scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
[ratio_width, ratio_height, ratio_width, ratio_height]
)
target["boxes"] = scaled_boxes target["boxes"] = scaled_boxes
if "area" in target: if "area" in target:
@ -139,9 +133,7 @@ def resize(image, target, size, max_size=None):
target["size"] = torch.tensor([h, w]) target["size"] = torch.tensor([h, w])
if "masks" in target: if "masks" in target:
target["masks"] = ( target["masks"] = interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5
interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5
)
return rescaled_image, target return rescaled_image, target
@ -192,11 +184,7 @@ class RandomSizeCrop(object):
h = random.randint(self.min_size, min(img.height, self.max_size)) h = random.randint(self.min_size, min(img.height, self.max_size))
region = T.RandomCrop.get_params(img, [h, w]) region = T.RandomCrop.get_params(img, [h, w])
result_img, result_target = crop(img, target, region) result_img, result_target = crop(img, target, region)
if ( if not self.respect_boxes or len(result_target["boxes"]) == init_boxes or i == max_patience - 1:
not self.respect_boxes
or len(result_target["boxes"]) == init_boxes
or i == max_patience - 1
):
return result_img, result_target return result_img, result_target
return result_img, result_target return result_img, result_target

View File

@ -12,4 +12,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------ # ------------------------------------------------------------------------
from .groundingdino import build_groundingdino from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models.GroundingDINO.groundingdino import (
build_groundingdino,
)

View File

@ -24,10 +24,13 @@ import torchvision
from torch import nn from torch import nn
from torchvision.models._utils import IntermediateLayerGetter from torchvision.models._utils import IntermediateLayerGetter
from groundingdino.util.misc import NestedTensor, clean_state_dict, is_main_process from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models.GroundingDINO.backbone.position_encoding import (
build_position_encoding,
from .position_encoding import build_position_encoding )
from .swin_transformer import build_swin_transformer from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models.GroundingDINO.backbone.swin_transformer import (
build_swin_transformer,
)
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import NestedTensor, is_main_process
class FrozenBatchNorm2d(torch.nn.Module): class FrozenBatchNorm2d(torch.nn.Module):
@ -80,19 +83,12 @@ class BackboneBase(nn.Module):
): ):
super().__init__() super().__init__()
for name, parameter in backbone.named_parameters(): for name, parameter in backbone.named_parameters():
if ( if not train_backbone or "layer2" not in name and "layer3" not in name and "layer4" not in name:
not train_backbone
or "layer2" not in name
and "layer3" not in name
and "layer4" not in name
):
parameter.requires_grad_(False) parameter.requires_grad_(False)
return_layers = {} return_layers = {}
for idx, layer_index in enumerate(return_interm_indices): for idx, layer_index in enumerate(return_interm_indices):
return_layers.update( return_layers.update({"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)})
{"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)}
)
# if len: # if len:
# if use_stage1_feature: # if use_stage1_feature:
@ -214,8 +210,8 @@ def build_backbone(args):
model = Joiner(backbone, position_embedding) model = Joiner(backbone, position_embedding)
model.num_channels = bb_num_channels model.num_channels = bb_num_channels
assert isinstance( assert isinstance(bb_num_channels, List), "bb_num_channels is expected to be a List but {}".format(
bb_num_channels, List type(bb_num_channels)
), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels)) )
# import ipdb; ipdb.set_trace() # import ipdb; ipdb.set_trace()
return model return model

View File

@ -24,7 +24,7 @@ import math
import torch import torch
from torch import nn from torch import nn
from groundingdino.util.misc import NestedTensor from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import NestedTensor
class PositionEmbeddingSine(nn.Module): class PositionEmbeddingSine(nn.Module):
@ -65,12 +65,8 @@ class PositionEmbeddingSine(nn.Module):
pos_x = x_embed[:, :, :, None] / dim_t pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack( pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos return pos
@ -81,9 +77,7 @@ class PositionEmbeddingSineHW(nn.Module):
used by the Attention is all you need paper, generalized to work on images. used by the Attention is all you need paper, generalized to work on images.
""" """
def __init__( def __init__(self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None):
self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None
):
super().__init__() super().__init__()
self.num_pos_feats = num_pos_feats self.num_pos_feats = num_pos_feats
self.temperatureH = temperatureH self.temperatureH = temperatureH
@ -111,19 +105,15 @@ class PositionEmbeddingSineHW(nn.Module):
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats) dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode="floor")) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_tx pos_x = x_embed[:, :, :, None] / dim_tx
dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats) dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode="floor")) / self.num_pos_feats)
pos_y = y_embed[:, :, :, None] / dim_ty pos_y = y_embed[:, :, :, None] / dim_ty
pos_x = torch.stack( pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
# import ipdb; ipdb.set_trace() # import ipdb; ipdb.set_trace()

View File

@ -18,15 +18,13 @@ import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from groundingdino.util.misc import NestedTensor from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import NestedTensor
class Mlp(nn.Module): class Mlp(nn.Module):
"""Multilayer perceptron.""" """Multilayer perceptron."""
def __init__( def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
):
super().__init__() super().__init__()
out_features = out_features or in_features out_features = out_features or in_features
hidden_features = hidden_features or in_features hidden_features = hidden_features or in_features
@ -138,24 +136,16 @@ class WindowAttention(nn.Module):
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
""" """
B_, N, C = x.shape B_, N, C = x.shape
qkv = ( qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
self.qkv(x)
.reshape(B_, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale q = q * self.scale
attn = q @ k.transpose(-2, -1) attn = q @ k.transpose(-2, -1)
relative_position_bias = self.relative_position_bias_table[ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
) # Wh*Ww,Wh*Ww,nH ) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute( relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0) attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None: if mask is not None:
@ -228,9 +218,7 @@ class SwinTransformerBlock(nn.Module):
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim) self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio) mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp( self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
)
self.H = None self.H = None
self.W = None self.W = None
@ -266,12 +254,8 @@ class SwinTransformerBlock(nn.Module):
attn_mask = None attn_mask = None
# partition windows # partition windows
x_windows = window_partition( x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
shifted_x, self.window_size x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
) # nW*B, window_size, window_size, C
x_windows = x_windows.view(
-1, self.window_size * self.window_size, C
) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA # W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
@ -433,14 +417,10 @@ class BasicLayer(nn.Module):
img_mask[:, h, w, :] = cnt img_mask[:, h, w, :] = cnt
cnt += 1 cnt += 1
mask_windows = window_partition( mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
img_mask, self.window_size
) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size) mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
attn_mask == 0, float(0.0)
)
for blk in self.blocks: for blk in self.blocks:
blk.H, blk.W = H, W blk.H, blk.W = H, W
@ -589,9 +569,7 @@ class SwinTransformer(nn.Module):
self.pos_drop = nn.Dropout(p=drop_rate) self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth # stochastic depth
dpr = [ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
# build layers # build layers
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
@ -682,9 +660,7 @@ class SwinTransformer(nn.Module):
Wh, Ww = x.size(2), x.size(3) Wh, Ww = x.size(2), x.size(3)
if self.ape: if self.ape:
# interpolate the position embedding to the corresponding size # interpolate the position embedding to the corresponding size
absolute_pos_embed = F.interpolate( absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic")
self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
)
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
else: else:
x = x.flatten(2).transpose(1, 2) x = x.flatten(2).transpose(1, 2)
@ -718,9 +694,7 @@ class SwinTransformer(nn.Module):
Wh, Ww = x.size(2), x.size(3) Wh, Ww = x.size(2), x.size(3)
if self.ape: if self.ape:
# interpolate the position embedding to the corresponding size # interpolate the position embedding to the corresponding size
absolute_pos_embed = F.interpolate( absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic")
self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
)
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
else: else:
x = x.flatten(2).transpose(1, 2) x = x.flatten(2).transpose(1, 2)
@ -769,21 +743,11 @@ def build_swin_transformer(modelname, pretrain_img_size, **kw):
] ]
model_para_dict = { model_para_dict = {
"swin_T_224_1k": dict( "swin_T_224_1k": dict(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7),
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7 "swin_B_224_22k": dict(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7),
), "swin_B_384_22k": dict(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12),
"swin_B_224_22k": dict( "swin_L_224_22k": dict(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7),
embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7 "swin_L_384_22k": dict(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12),
),
"swin_B_384_22k": dict(
embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12
),
"swin_L_224_22k": dict(
embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7
),
"swin_L_384_22k": dict(
embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12
),
} }
kw_cgf = model_para_dict[modelname] kw_cgf = model_para_dict[modelname]
kw_cgf.update(kw) kw_cgf.update(kw)

View File

@ -21,8 +21,12 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from groundingdino.util import get_tokenlizer from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util import get_tokenlizer
from groundingdino.util.misc import NestedTensor, inverse_sigmoid, nested_tensor_from_tensor_list from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import (
NestedTensor,
inverse_sigmoid,
nested_tensor_from_tensor_list,
)
from ..registry import MODULE_BUILD_FUNCS from ..registry import MODULE_BUILD_FUNCS
from .backbone import build_backbone from .backbone import build_backbone

View File

@ -22,14 +22,19 @@ import torch
import torch.utils.checkpoint as checkpoint import torch.utils.checkpoint as checkpoint
from torch import Tensor, nn from torch import Tensor, nn
from groundingdino.util.misc import inverse_sigmoid from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import inverse_sigmoid
from .fuse_modules import BiAttentionBlock from .fuse_modules import BiAttentionBlock
from .ms_deform_attn import MultiScaleDeformableAttention as MSDeformAttn from .ms_deform_attn import MultiScaleDeformableAttention as MSDeformAttn
from .transformer_vanilla import TransformerEncoderLayer from .transformer_vanilla import TransformerEncoderLayer
from .utils import (MLP, _get_activation_fn, _get_clones, from .utils import (
gen_encoder_output_proposals, gen_sineembed_for_position, MLP,
get_sine_pos_embed) _get_activation_fn,
_get_clones,
gen_encoder_output_proposals,
gen_sineembed_for_position,
get_sine_pos_embed,
)
class Transformer(nn.Module): class Transformer(nn.Module):

View File

@ -1,4 +1,4 @@
from typing import List, Tuple from typing import Dict, List, Tuple
import cv2 import cv2
import numpy as np import numpy as np
@ -7,11 +7,11 @@ import torch
from PIL import Image from PIL import Image
from torchvision.ops import box_convert from torchvision.ops import box_convert
import groundingdino.datasets.transforms as T import invokeai.backend.image_util.grounding_segment_anything.groundingdino.datasets.transforms as T
from groundingdino.models import build_model from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models import build_model
from groundingdino.util.misc import clean_state_dict from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import clean_state_dict
from groundingdino.util.slconfig import SLConfig from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import get_phrases_from_posmap from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.utils import get_phrases_from_posmap
# ---------------------------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------------------------
# OLD API # OLD API
@ -25,12 +25,11 @@ def preprocess_caption(caption: str) -> str:
return result + "." return result + "."
def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"): def load_model(model_config_path: str, model_state_dict: Dict[str, torch.Tensor], device: str = "cuda"):
args = SLConfig.fromfile(model_config_path) args = SLConfig.fromfile(model_config_path)
args.device = device args.device = device
model = build_model(args) model = build_model(args)
checkpoint = torch.load(model_checkpoint_path, map_location="cpu") model.load_state_dict(clean_state_dict(model_state_dict["model"]), strict=False)
model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
model.eval() model.eval()
return model return model
@ -98,9 +97,9 @@ def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor
class Model: class Model:
def __init__(self, model_config_path: str, model_checkpoint_path: str, device: str = "cuda"): def __init__(self, model_config_path: str, model_state_dict: Dict[str, torch.Tensor], device: str = "cuda"):
self.model = load_model( self.model = load_model(
model_config_path=model_config_path, model_checkpoint_path=model_checkpoint_path, device=device model_config_path=model_config_path, model_state_dict=model_state_dict, device=device
).to(device) ).to(device)
self.device = device self.device = device

View File

@ -1,91 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import functools
import logging
import os
import sys
from termcolor import colored
class _ColorfulFormatter(logging.Formatter):
def __init__(self, *args, **kwargs):
self._root_name = kwargs.pop("root_name") + "."
self._abbrev_name = kwargs.pop("abbrev_name", "")
if len(self._abbrev_name):
self._abbrev_name = self._abbrev_name + "."
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
def formatMessage(self, record):
record.name = record.name.replace(self._root_name, self._abbrev_name)
log = super(_ColorfulFormatter, self).formatMessage(record)
if record.levelno == logging.WARNING:
prefix = colored("WARNING", "red", attrs=["blink"])
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
else:
return log
return prefix + " " + log
# so that calling setup_logger multiple times won't add many handlers
@functools.lru_cache()
def setup_logger(output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None):
"""
Initialize the detectron2 logger and set its verbosity level to "INFO".
Args:
output (str): a file name or a directory to save log. If None, will not save log file.
If ends with ".txt" or ".log", assumed to be a file name.
Otherwise, logs will be saved to `output/log.txt`.
name (str): the root module name of this logger
Returns:
logging.Logger: a logger
"""
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
logger.propagate = False
if abbrev_name is None:
abbrev_name = name
plain_formatter = logging.Formatter("[%(asctime)s.%(msecs)03d]: %(message)s", datefmt="%m/%d %H:%M:%S")
# stdout logging: master only
if distributed_rank == 0:
ch = logging.StreamHandler(stream=sys.stdout)
ch.setLevel(logging.DEBUG)
if color:
formatter = _ColorfulFormatter(
colored("[%(asctime)s.%(msecs)03d]: ", "green") + "%(message)s",
datefmt="%m/%d %H:%M:%S",
root_name=name,
abbrev_name=str(abbrev_name),
)
else:
formatter = plain_formatter
ch.setFormatter(formatter)
logger.addHandler(ch)
# file logging: all workers
if output is not None:
if output.endswith(".txt") or output.endswith(".log"):
filename = output
else:
filename = os.path.join(output, "log.txt")
if distributed_rank > 0:
filename = filename + f".rank{distributed_rank}"
os.makedirs(os.path.dirname(filename), exist_ok=True)
fh = logging.StreamHandler(_cached_log_stream(filename))
fh.setLevel(logging.DEBUG)
fh.setFormatter(plain_formatter)
logger.addHandler(fh)
return logger
# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
return open(filename, "a")

View File

@ -161,7 +161,7 @@ def all_gather_cpu(data):
dist.all_gather(tensor_list, tensor, group=cpu_group) dist.all_gather(tensor_list, tensor, group=cpu_group)
data_list = [] data_list = []
for size, tensor in zip(size_list, tensor_list): for size, tensor in zip(size_list, tensor_list, strict=False):
tensor = torch.split(tensor, [size, max_size - size], dim=0)[0] tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
buffer = io.BytesIO(tensor.cpu().numpy()) buffer = io.BytesIO(tensor.cpu().numpy())
obj = torch.load(buffer) obj = torch.load(buffer)
@ -210,7 +210,7 @@ def all_gather(data):
dist.all_gather(tensor_list, tensor) dist.all_gather(tensor_list, tensor)
data_list = [] data_list = []
for size, tensor in zip(size_list, tensor_list): for size, tensor in zip(size_list, tensor_list, strict=False):
buffer = tensor.cpu().numpy().tobytes()[:size] buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer)) data_list.append(pickle.loads(buffer))
@ -240,7 +240,7 @@ def reduce_dict(input_dict, average=True):
dist.all_reduce(values) dist.all_reduce(values)
if average: if average:
values /= world_size values /= world_size
reduced_dict = {k: v for k, v in zip(names, values)} reduced_dict = {k: v for k, v in zip(names, values, strict=False)}
return reduced_dict return reduced_dict
@ -378,7 +378,7 @@ def get_sha():
def collate_fn(batch): def collate_fn(batch):
# import ipdb; ipdb.set_trace() # import ipdb; ipdb.set_trace()
batch = list(zip(*batch)) batch = list(zip(*batch, strict=False))
batch[0] = nested_tensor_from_tensor_list(batch[0]) batch[0] = nested_tensor_from_tensor_list(batch[0])
return tuple(batch) return tuple(batch)
@ -480,7 +480,7 @@ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
device = tensor_list[0].device device = tensor_list[0].device
tensor = torch.zeros(batch_shape, dtype=dtype, device=device) tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
mask = torch.ones((b, h, w), dtype=torch.bool, device=device) mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
for img, pad_img, m in zip(tensor_list, tensor, mask): for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
m[: img.shape[1], : img.shape[2]] = False m[: img.shape[1], : img.shape[2]] = False
else: else:
@ -505,7 +505,7 @@ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTen
padded_imgs = [] padded_imgs = []
padded_masks = [] padded_masks = []
for img in tensor_list: for img in tensor_list:
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape), strict=False)]
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
padded_imgs.append(padded_img) padded_imgs.append(padded_img)

View File

@ -9,7 +9,7 @@ import numpy as np
import torch import torch
from transformers import AutoTokenizer from transformers import AutoTokenizer
from groundingdino.util.slconfig import SLConfig from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.slconfig import SLConfig
def slprint(x, name="x"): def slprint(x, name="x"):

View File

@ -1,44 +1,38 @@
from typing import Dict, List, Literal, Optional import pathlib
from typing import Dict, List, Optional
import cv2
import numpy as np import numpy as np
import supervision as sv import supervision as sv
import torch import torch
import torchvision import torchvision
from PIL import Image
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.inference import Model from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.inference import Model
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.build_sam import sam_model_registry from invokeai.backend.image_util.grounding_segment_anything.segment_anything.build_sam import sam_model_registry
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.predictor import SamPredictor from invokeai.backend.image_util.grounding_segment_anything.segment_anything.predictor import SamPredictor
GROUNDING_SEGMENT_ANYTHING_MODELS = {
"groundingdino_swint_ogc": "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
"segment_anything_vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
}
class GroundingSegmentAnythingDetector: class GroundingSegmentAnythingDetector:
def __init__(self) -> None: def __init__(self, grounding_dino_model: Model, segment_anything_model: SamPredictor) -> None:
self.grounding_dino_model: Optional[Model] = None self.grounding_dino_model: Optional[Model] = grounding_dino_model
self.segment_anything_model: Optional[SamPredictor] = None self.segment_anything_model: Optional[SamPredictor] = segment_anything_model
self.grounding_dino_config: str = "./groundingdino/config/GroundingDINO_SwinT_OGC.py"
self.sam_encoder: Literal["vit_h"] = "vit_h"
self.device: torch.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def build_grounding_dino(self): @staticmethod
def build_grounding_dino(grounding_dino_state_dict: Dict[str, torch.Tensor]):
grounding_dino_config = pathlib.Path(
"./invokeai/backend/image_util/grounding_segment_anything/groundingdino/config/GroundingDINO_SwinT_OGC.py"
)
return Model( return Model(
model_config_path=self.grounding_dino_config, model_state_dict=grounding_dino_state_dict,
model_checkpoint_path="./checkpoints/groundingdino_swint_ogc.pth", model_config_path=grounding_dino_config.as_posix(),
) )
def build_segment_anything(self): @staticmethod
sam = sam_model_registry[self.sam_encoder](checkpoint="./checkpoints/sam_vit_h_4b8939.pth") def build_segment_anything(segment_anything_state_dict: Dict[str, torch.Tensor], device: torch.device):
sam.to(device=self.device) sam = sam_model_registry["vit_h"](checkpoint=segment_anything_state_dict)
sam.to(device=device)
return SamPredictor(sam) return SamPredictor(sam)
def build_grounding_sam(self):
self.grounding_dino_model = self.build_grounding_dino()
self.segment_anything_model = self.build_segment_anything()
def detect_objects( def detect_objects(
self, self,
image: np.ndarray, image: np.ndarray,
@ -77,20 +71,18 @@ class GroundingSegmentAnythingDetector:
def predict( def predict(
self, self,
image: str, image: Image.Image,
prompt: str, prompt: str,
box_threshold: float = 0.5, box_threshold: float = 0.5,
text_threshold: float = 0.5, text_threshold: float = 0.5,
nms_threshold: float = 0.8, nms_threshold: float = 0.8,
): ):
if not self.grounding_dino_model or not self.segment_anything_model: open_cv_image = np.array(image)
self.build_grounding_sam() open_cv_image = open_cv_image[:, :, ::-1].copy()
image = cv2.imread(image)
prompts = prompt.split(",") prompts = prompt.split(",")
detections = self.detect_objects(image, prompts, box_threshold, text_threshold, nms_threshold) detections = self.detect_objects(open_cv_image, prompts, box_threshold, text_threshold, nms_threshold)
segments = self.segment_detections(image, detections, prompts) segments = self.segment_detections(open_cv_image, detections, prompts)
if len(segments) > 0: if len(segments) > 0:
combined_mask = np.zeros_like(list(segments.values())[0]) combined_mask = np.zeros_like(list(segments.values())[0])
@ -98,15 +90,6 @@ class GroundingSegmentAnythingDetector:
combined_mask = np.logical_or(combined_mask, mask) combined_mask = np.logical_or(combined_mask, mask)
mask_preview = (combined_mask * 255).astype(np.uint8) mask_preview = (combined_mask * 255).astype(np.uint8)
else: else:
mask_preview = np.zeros(image.shape, np.uint8) mask_preview = np.zeros(open_cv_image.shape, np.uint8)
cv2.imwrite("mask.png", mask_preview) return Image.fromarray(mask_preview)
if __name__ == "__main__":
gsa = GroundingSegmentAnythingDetector()
image = "./assets/image.webp"
while True:
prompt = input("Segment: ")
gsa.predict(image, prompt, 0.5, 0.5, 0.8)

View File

@ -4,13 +4,22 @@
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from .automatic_mask_generator import SamAutomaticMaskGenerator
from .build_sam import build_sam, build_sam_vit_b, build_sam_vit_h, build_sam_vit_l, sam_model_registry from invokeai.backend.image_util.grounding_segment_anything.segment_anything.automatic_mask_generator import (
from .build_sam_hq import ( SamAutomaticMaskGenerator,
) # noqa
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.build_sam import ( # noqa
build_sam,
build_sam_vit_b,
build_sam_vit_h,
build_sam_vit_l,
sam_model_registry,
)
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.build_sam_hq import ( # noqa
build_sam_hq, build_sam_hq,
build_sam_hq_vit_b, build_sam_hq_vit_b,
build_sam_hq_vit_h, build_sam_hq_vit_h,
build_sam_hq_vit_l, build_sam_hq_vit_l,
sam_hq_model_registry, sam_hq_model_registry,
) )
from .predictor import SamPredictor from invokeai.backend.image_util.grounding_segment_anything.segment_anything.predictor import SamPredictor # noqa

View File

@ -10,9 +10,9 @@ import numpy as np
import torch import torch
from torchvision.ops.boxes import batched_nms, box_area # type: ignore from torchvision.ops.boxes import batched_nms, box_area # type: ignore
from .modeling import Sam from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling import Sam
from .predictor import SamPredictor from invokeai.backend.image_util.grounding_segment_anything.segment_anything.predictor import SamPredictor
from .utils.amg import ( from invokeai.backend.image_util.grounding_segment_anything.segment_anything.utils.amg import (
MaskData, MaskData,
area_from_rle, area_from_rle,
batch_iterator, batch_iterator,

View File

@ -8,7 +8,13 @@ from functools import partial
import torch import torch
from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling import (
ImageEncoderViT,
MaskDecoder,
PromptEncoder,
Sam,
TwoWayTransformer,
)
def build_sam_vit_h(checkpoint=None): def build_sam_vit_h(checkpoint=None):
@ -101,7 +107,5 @@ def _build_sam(
) )
sam.eval() sam.eval()
if checkpoint is not None: if checkpoint is not None:
with open(checkpoint, "rb") as f: sam.load_state_dict(checkpoint)
state_dict = torch.load(f)
sam.load_state_dict(state_dict)
return sam return sam

View File

@ -8,7 +8,13 @@ from functools import partial
import torch import torch
from .modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling import (
ImageEncoderViT,
MaskDecoderHQ,
PromptEncoder,
Sam,
TwoWayTransformer,
)
def build_sam_hq_vit_h(checkpoint=None): def build_sam_hq_vit_h(checkpoint=None):

View File

@ -4,9 +4,17 @@
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from .image_encoder import ImageEncoderViT from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.image_encoder import (
from .mask_decoder import MaskDecoder ImageEncoderViT,
from .mask_decoder_hq import MaskDecoderHQ )
from .prompt_encoder import PromptEncoder from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.mask_decoder import MaskDecoder
from .sam import Sam from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.mask_decoder_hq import (
from .transformer import TwoWayTransformer MaskDecoderHQ,
)
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.prompt_encoder import (
PromptEncoder,
)
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.sam import Sam
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.transformer import (
TwoWayTransformer,
)

View File

@ -10,7 +10,10 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .common import LayerNorm2d, MLPBlock from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import (
LayerNorm2d,
MLPBlock,
)
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa

View File

@ -10,7 +10,7 @@ import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from .common import LayerNorm2d from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import LayerNorm2d
class MaskDecoder(nn.Module): class MaskDecoder(nn.Module):

View File

@ -11,7 +11,7 @@ import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from .common import LayerNorm2d from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import LayerNorm2d
class MaskDecoderHQ(nn.Module): class MaskDecoderHQ(nn.Module):

View File

@ -10,7 +10,7 @@ import numpy as np
import torch import torch
from torch import nn from torch import nn
from .common import LayerNorm2d from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import LayerNorm2d
class PromptEncoder(nn.Module): class PromptEncoder(nn.Module):

View File

@ -10,9 +10,13 @@ import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from .image_encoder import ImageEncoderViT from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.image_encoder import (
from .mask_decoder import MaskDecoder ImageEncoderViT,
from .prompt_encoder import PromptEncoder )
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.mask_decoder import MaskDecoder
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.prompt_encoder import (
PromptEncoder,
)
class Sam(nn.Module): class Sam(nn.Module):
@ -98,7 +102,7 @@ class Sam(nn.Module):
image_embeddings = self.image_encoder(input_images) image_embeddings = self.image_encoder(input_images)
outputs = [] outputs = []
for image_record, curr_embedding in zip(batched_input, image_embeddings): for image_record, curr_embedding in zip(batched_input, image_embeddings, strict=False):
if "point_coords" in image_record: if "point_coords" in image_record:
points = (image_record["point_coords"], image_record["point_labels"]) points = (image_record["point_coords"], image_record["point_labels"])
else: else:

View File

@ -10,7 +10,7 @@ from typing import Tuple, Type
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
from .common import MLPBlock from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import MLPBlock
class TwoWayTransformer(nn.Module): class TwoWayTransformer(nn.Module):

View File

@ -9,8 +9,8 @@ from typing import Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from .modeling import Sam from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling import Sam
from .utils.transforms import ResizeLongestSide from invokeai.backend.image_util.grounding_segment_anything.segment_anything.utils.transforms import ResizeLongestSide
class SamPredictor: class SamPredictor:

View File

@ -4,14 +4,14 @@
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import numpy as np
import torch
import math import math
from copy import deepcopy from copy import deepcopy
from itertools import product from itertools import product
from typing import Any, Dict, Generator, ItemsView, List, Tuple from typing import Any, Dict, Generator, ItemsView, List, Tuple
import numpy as np
import torch
class MaskData: class MaskData:
""" """
@ -153,9 +153,7 @@ def area_from_rle(rle: Dict[str, Any]) -> int:
return sum(rle["counts"][1::2]) return sum(rle["counts"][1::2])
def calculate_stability_score( def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
masks: torch.Tensor, mask_threshold: float, threshold_offset: float
) -> torch.Tensor:
""" """
Computes the stability score for a batch of masks. The stability Computes the stability score for a batch of masks. The stability
score is the IoU between the binary masks obtained by thresholding score is the IoU between the binary masks obtained by thresholding
@ -163,16 +161,8 @@ def calculate_stability_score(
""" """
# One mask is always contained inside the other. # One mask is always contained inside the other.
# Save memory by preventing unnecesary cast to torch.int64 # Save memory by preventing unnecesary cast to torch.int64
intersections = ( intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
(masks > (mask_threshold + threshold_offset)) unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
.sum(-1, dtype=torch.int16)
.sum(-1, dtype=torch.int32)
)
unions = (
(masks > (mask_threshold - threshold_offset))
.sum(-1, dtype=torch.int16)
.sum(-1, dtype=torch.int32)
)
return intersections / unions return intersections / unions
@ -186,9 +176,7 @@ def build_point_grid(n_per_side: int) -> np.ndarray:
return points return points
def build_all_layer_point_grids( def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
n_per_side: int, n_layers: int, scale_per_layer: int
) -> List[np.ndarray]:
"""Generates point grids for all crop layers.""" """Generates point grids for all crop layers."""
points_by_layer = [] points_by_layer = []
for i in range(n_layers + 1): for i in range(n_layers + 1):
@ -252,9 +240,7 @@ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
return points + offset return points + offset
def uncrop_masks( def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:
masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
) -> torch.Tensor:
x0, y0, x1, y1 = crop_box x0, y0, x1, y1 = crop_box
if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
return masks return masks
@ -264,9 +250,7 @@ def uncrop_masks(
return torch.nn.functional.pad(masks, pad, value=0) return torch.nn.functional.pad(masks, pad, value=0)
def remove_small_regions( def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
""" """
Removes small disconnected regions and holes in a mask. Returns the Removes small disconnected regions and holes in a mask. Returns the
mask and an indicator of if the mask has been modified. mask and an indicator of if the mask has been modified.

View File

@ -1,144 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Tuple
from ..modeling import Sam
from .amg import calculate_stability_score
class SamOnnxModel(nn.Module):
"""
This model should not be called directly, but is used in ONNX export.
It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
with some functions modified to enable model tracing. Also supports extra
options controlling what information. See the ONNX export script for details.
"""
def __init__(
self,
model: Sam,
return_single_mask: bool,
use_stability_score: bool = False,
return_extra_metrics: bool = False,
) -> None:
super().__init__()
self.mask_decoder = model.mask_decoder
self.model = model
self.img_size = model.image_encoder.img_size
self.return_single_mask = return_single_mask
self.use_stability_score = use_stability_score
self.stability_score_offset = 1.0
self.return_extra_metrics = return_extra_metrics
@staticmethod
def resize_longest_image_size(
input_image_size: torch.Tensor, longest_side: int
) -> torch.Tensor:
input_image_size = input_image_size.to(torch.float32)
scale = longest_side / torch.max(input_image_size)
transformed_size = scale * input_image_size
transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
return transformed_size
def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
point_coords = point_coords + 0.5
point_coords = point_coords / self.img_size
point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
point_embedding = point_embedding * (point_labels != -1)
point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
point_labels == -1
)
for i in range(self.model.prompt_encoder.num_point_embeddings):
point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
i
].weight * (point_labels == i)
return point_embedding
def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
mask_embedding = mask_embedding + (
1 - has_mask_input
) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
return mask_embedding
def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
masks = F.interpolate(
masks,
size=(self.img_size, self.img_size),
mode="bilinear",
align_corners=False,
)
prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size)
masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])]
orig_im_size = orig_im_size.to(torch.int64)
h, w = orig_im_size[0], orig_im_size[1]
masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
return masks
def select_masks(
self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
) -> Tuple[torch.Tensor, torch.Tensor]:
# Determine if we should return the multiclick mask or not from the number of points.
# The reweighting is used to avoid control flow.
score_reweight = torch.tensor(
[[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
).to(iou_preds.device)
score = iou_preds + (num_points - 2.5) * score_reweight
best_idx = torch.argmax(score, dim=1)
masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
return masks, iou_preds
@torch.no_grad()
def forward(
self,
image_embeddings: torch.Tensor,
point_coords: torch.Tensor,
point_labels: torch.Tensor,
mask_input: torch.Tensor,
has_mask_input: torch.Tensor,
orig_im_size: torch.Tensor,
):
sparse_embedding = self._embed_points(point_coords, point_labels)
dense_embedding = self._embed_masks(mask_input, has_mask_input)
masks, scores = self.model.mask_decoder.predict_masks(
image_embeddings=image_embeddings,
image_pe=self.model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embedding,
dense_prompt_embeddings=dense_embedding,
)
if self.use_stability_score:
scores = calculate_stability_score(
masks, self.model.mask_threshold, self.stability_score_offset
)
if self.return_single_mask:
masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
if self.return_extra_metrics:
stability_scores = calculate_stability_score(
upscaled_masks, self.model.mask_threshold, self.stability_score_offset
)
areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
return upscaled_masks, scores, stability_scores, areas, masks
return upscaled_masks, scores, masks

View File

@ -4,14 +4,14 @@
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from copy import deepcopy
from typing import Tuple
import numpy as np import numpy as np
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
from torchvision.transforms.functional import resize, to_pil_image # type: ignore from torchvision.transforms.functional import resize, to_pil_image # type: ignore
from copy import deepcopy
from typing import Tuple
class ResizeLongestSide: class ResizeLongestSide:
""" """
@ -36,9 +36,7 @@ class ResizeLongestSide:
original image size in (H, W) format. original image size in (H, W) format.
""" """
old_h, old_w = original_size old_h, old_w = original_size
new_h, new_w = self.get_preprocess_shape( new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
original_size[0], original_size[1], self.target_length
)
coords = deepcopy(coords).astype(float) coords = deepcopy(coords).astype(float)
coords[..., 0] = coords[..., 0] * (new_w / old_w) coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h) coords[..., 1] = coords[..., 1] * (new_h / old_h)
@ -60,29 +58,21 @@ class ResizeLongestSide:
""" """
# Expects an image in BCHW format. May not exactly match apply_image. # Expects an image in BCHW format. May not exactly match apply_image.
target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
return F.interpolate( return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
image, target_size, mode="bilinear", align_corners=False, antialias=True
)
def apply_coords_torch( def apply_coords_torch(self, coords: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
self, coords: torch.Tensor, original_size: Tuple[int, ...]
) -> torch.Tensor:
""" """
Expects a torch tensor with length 2 in the last dimension. Requires the Expects a torch tensor with length 2 in the last dimension. Requires the
original image size in (H, W) format. original image size in (H, W) format.
""" """
old_h, old_w = original_size old_h, old_w = original_size
new_h, new_w = self.get_preprocess_shape( new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
original_size[0], original_size[1], self.target_length
)
coords = deepcopy(coords).to(torch.float) coords = deepcopy(coords).to(torch.float)
coords[..., 0] = coords[..., 0] * (new_w / old_w) coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h) coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords return coords
def apply_boxes_torch( def apply_boxes_torch(self, boxes: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
self, boxes: torch.Tensor, original_size: Tuple[int, ...]
) -> torch.Tensor:
""" """
Expects a torch tensor with shape Bx4. Requires the original image Expects a torch tensor with shape Bx4. Requires the original image
size in (H, W) format. size in (H, W) format.

View File

@ -80,6 +80,7 @@ dependencies = [
"picklescan", "picklescan",
"pillow", "pillow",
"prompt-toolkit", "prompt-toolkit",
"pycocotools",
"pympler~=1.0.1", "pympler~=1.0.1",
"pypatchmatch", "pypatchmatch",
'pyperclip', 'pyperclip',
@ -90,6 +91,7 @@ dependencies = [
"scikit-image~=0.21.0", "scikit-image~=0.21.0",
"semver~=3.0.1", "semver~=3.0.1",
"send2trash", "send2trash",
"supervision",
"test-tube~=0.7.5", "test-tube~=0.7.5",
"windows-curses; sys_platform=='win32'", "windows-curses; sys_platform=='win32'",
] ]