Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00

wip: segment anything node

This commit is contained in:
  parent e579be0118
  commit b20c70c588

invokeai/app/invocations/segment_anything.py (new file, 76 lines)
@@ -0,0 +1,76 @@
+from typing import Dict, cast
+
+import torch
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import ImageField, InputField
+from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.image_util.grounding_segment_anything.gsa import GroundingSegmentAnythingDetector
+from invokeai.backend.util.devices import TorchDevice
+
+GROUNDING_SEGMENT_ANYTHING_MODELS = {
+    "groundingdino_swint_ogc": "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
+    "segment_anything_vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
+}
+
+
+@invocation(
+    "segment_anything",
+    title="Segment Anything",
+    tags=["grounding_dino", "segment", "anything"],
+    category="image",
+    version="1.0.0",
+)
+class SegmentAnythingInvocation(BaseInvocation):
+    """Automatically generate masks from an image using GroundingDINO & Segment Anything"""
+
+    image: ImageField = InputField(description="The image to process")
+    prompt: str = InputField(default="", description="Keywords to segment", title="Prompt")
+    box_threshold: float = InputField(
+        default=0.5, ge=0, le=1, description="Threshold of box detection", title="Box Threshold"
+    )
+    text_threshold: float = InputField(
+        default=0.5, ge=0, le=1, description="Threshold of text detection", title="Text Threshold"
+    )
+    nms_threshold: float = InputField(
+        default=0.8, ge=0, le=1, description="Threshold of nms detection", title="NMS Threshold"
+    )
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        input_image = context.images.get_pil(self.image.image_name)
+
+        grounding_dino_model = context.models.load_remote_model(
+            GROUNDING_SEGMENT_ANYTHING_MODELS["groundingdino_swint_ogc"]
+        )
+        segment_anything_model = context.models.load_remote_model(
+            GROUNDING_SEGMENT_ANYTHING_MODELS["segment_anything_vit_h"]
+        )
+
+        with (
+            grounding_dino_model.model_on_device() as (_, grounding_dino_state_dict),
+            segment_anything_model.model_on_device() as (_, segment_anything_state_dict),
+        ):
+            if not grounding_dino_state_dict or not segment_anything_state_dict:
+                raise RuntimeError("Unable to load segmentation models")
+
+            grounding_dino = GroundingSegmentAnythingDetector.build_grounding_dino(
+                cast(Dict[str, torch.Tensor], grounding_dino_state_dict)
+            )
+            segment_anything = GroundingSegmentAnythingDetector.build_segment_anything(
+                cast(Dict[str, torch.Tensor], segment_anything_state_dict), TorchDevice.choose_torch_device()
+            )
+            detector = GroundingSegmentAnythingDetector(grounding_dino, segment_anything)
+
+            mask = detector.predict(
+                input_image, self.prompt, self.box_threshold, self.text_threshold, self.nms_threshold
+            )
+            image_dto = context.images.save(mask)
+
+            """Builds an ImageOutput and its ImageField"""
+            processed_image_field = ImageField(image_name=image_dto.image_name)
+            return ImageOutput(
+                image=processed_image_field,
+                width=input_image.width,
+                height=input_image.height,
+            )
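For background on the three thresholds the new node exposes, here is a minimal, illustrative sketch of how a box-confidence cutoff and torchvision NMS are typically combined. The `boxes` and `scores` tensors below are hypothetical stand-ins for GroundingDINO outputs, not InvokeAI code; the detector added in this commit applies the same ideas internally.

import torch
import torchvision

# Hypothetical detector outputs: one xyxy box and one confidence score per phrase match.
boxes = torch.tensor([[10.0, 10.0, 120.0, 120.0], [12.0, 11.0, 118.0, 119.0]])
scores = torch.tensor([0.72, 0.55])

box_threshold = 0.5  # node default: drop detections with confidence at or below this value
keep = scores > box_threshold
boxes, scores = boxes[keep], scores[keep]

nms_threshold = 0.8  # node default: suppress boxes whose IoU with a higher-scoring box exceeds this
keep = torchvision.ops.nms(boxes, scores, nms_threshold)
print(boxes[keep])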
@@ -10,8 +10,8 @@ import torch
 import torchvision.transforms as T
 import torchvision.transforms.functional as F

-from groundingdino.util.box_ops import box_xyxy_to_cxcywh
-from groundingdino.util.misc import interpolate
+from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.box_ops import box_xyxy_to_cxcywh
+from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import interpolate


 def crop(image, target, region):

@@ -58,9 +58,7 @@ def crop(image, target, region):
     if os.environ.get("IPDB_SHILONG_DEBUG", None) == "INFO":
         # for debug and visualization only.
         if "strings_positive" in target:
-            target["strings_positive"] = [
-                _i for _i, _j in zip(target["strings_positive"], keep) if _j
-            ]
+            target["strings_positive"] = [_i for _i, _j in zip(target["strings_positive"], keep, strict=False) if _j]

     return cropped_image, target

@@ -73,9 +71,7 @@ def hflip(image, target):
     target = target.copy()
     if "boxes" in target:
         boxes = target["boxes"]
-        boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor(
-            [w, 0, w, 0]
-        )
+        boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
         target["boxes"] = boxes

     if "masks" in target:

@@ -119,15 +115,13 @@ def resize(image, target, size, max_size=None):
     if target is None:
         return rescaled_image, None

-    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
+    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size, strict=False))
     ratio_width, ratio_height = ratios

     target = target.copy()
     if "boxes" in target:
         boxes = target["boxes"]
-        scaled_boxes = boxes * torch.as_tensor(
-            [ratio_width, ratio_height, ratio_width, ratio_height]
-        )
+        scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
         target["boxes"] = scaled_boxes

     if "area" in target:

@@ -139,9 +133,7 @@ def resize(image, target, size, max_size=None):
     target["size"] = torch.tensor([h, w])

     if "masks" in target:
-        target["masks"] = (
-            interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5
-        )
+        target["masks"] = interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5

     return rescaled_image, target

@@ -192,11 +184,7 @@ class RandomSizeCrop(object):
             h = random.randint(self.min_size, min(img.height, self.max_size))
             region = T.RandomCrop.get_params(img, [h, w])
             result_img, result_target = crop(img, target, region)
-            if (
-                not self.respect_boxes
-                or len(result_target["boxes"]) == init_boxes
-                or i == max_patience - 1
-            ):
+            if not self.respect_boxes or len(result_target["boxes"]) == init_boxes or i == max_patience - 1:
                 return result_img, result_target
         return result_img, result_target
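Several of the hunks above (and many more below) add strict=False to existing zip() calls. As a brief, generic illustration of the flag (Python 3.10+), not taken from this repository:

a = [1, 2, 3]
b = ["x", "y"]

print(list(zip(a, b)))                # [(1, 'x'), (2, 'y')]: silently truncates to the shorter input
print(list(zip(a, b, strict=False)))  # same truncating behaviour, but the choice is now explicit
try:
    list(zip(a, b, strict=True))      # raises ValueError because the iterables differ in length
except ValueError as err:
    print(err)

Passing strict=False preserves the historical behaviour while satisfying linters that require an explicit choice.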
@@ -12,4 +12,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 # ------------------------------------------------------------------------

-from .groundingdino import build_groundingdino
+from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models.GroundingDINO.groundingdino import (
+    build_groundingdino,
+)

@@ -24,10 +24,13 @@ import torchvision
 from torch import nn
 from torchvision.models._utils import IntermediateLayerGetter

-from groundingdino.util.misc import NestedTensor, clean_state_dict, is_main_process
-from .position_encoding import build_position_encoding
-from .swin_transformer import build_swin_transformer
+from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models.GroundingDINO.backbone.position_encoding import (
+    build_position_encoding,
+)
+from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models.GroundingDINO.backbone.swin_transformer import (
+    build_swin_transformer,
+)
+from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import NestedTensor, is_main_process


 class FrozenBatchNorm2d(torch.nn.Module):

@@ -80,19 +83,12 @@ class BackboneBase(nn.Module):
     ):
         super().__init__()
         for name, parameter in backbone.named_parameters():
-            if (
-                not train_backbone
-                or "layer2" not in name
-                and "layer3" not in name
-                and "layer4" not in name
-            ):
+            if not train_backbone or "layer2" not in name and "layer3" not in name and "layer4" not in name:
                 parameter.requires_grad_(False)

         return_layers = {}
         for idx, layer_index in enumerate(return_interm_indices):
-            return_layers.update(
-                {"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)}
-            )
+            return_layers.update({"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)})

         # if len:
         #     if use_stage1_feature:

@@ -214,8 +210,8 @@ def build_backbone(args):

     model = Joiner(backbone, position_embedding)
     model.num_channels = bb_num_channels
-    assert isinstance(
-        bb_num_channels, List
-    ), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels))
+    assert isinstance(bb_num_channels, List), "bb_num_channels is expected to be a List but {}".format(
+        type(bb_num_channels)
+    )
     # import ipdb; ipdb.set_trace()
     return model

@@ -24,7 +24,7 @@ import math
 import torch
 from torch import nn

-from groundingdino.util.misc import NestedTensor
+from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import NestedTensor


 class PositionEmbeddingSine(nn.Module):

@@ -65,12 +65,8 @@ class PositionEmbeddingSine(nn.Module):

         pos_x = x_embed[:, :, :, None] / dim_t
         pos_y = y_embed[:, :, :, None] / dim_t
-        pos_x = torch.stack(
-            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
-        ).flatten(3)
-        pos_y = torch.stack(
-            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
-        ).flatten(3)
+        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
         return pos

@@ -81,9 +77,7 @@ class PositionEmbeddingSineHW(nn.Module):
     used by the Attention is all you need paper, generalized to work on images.
     """

-    def __init__(
-        self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None
-    ):
+    def __init__(self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None):
         super().__init__()
         self.num_pos_feats = num_pos_feats
         self.temperatureH = temperatureH

@@ -111,19 +105,15 @@ class PositionEmbeddingSineHW(nn.Module):
         x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

         dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
-        dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats)
+        dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode="floor")) / self.num_pos_feats)
         pos_x = x_embed[:, :, :, None] / dim_tx

         dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
-        dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats)
+        dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode="floor")) / self.num_pos_feats)
         pos_y = y_embed[:, :, :, None] / dim_ty

-        pos_x = torch.stack(
-            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
-        ).flatten(3)
-        pos_y = torch.stack(
-            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
-        ).flatten(3)
+        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)

         # import ipdb; ipdb.set_trace()
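The two position-encoding hunks above only re-wrap lines; the quantity being computed is the standard sinusoidal embedding from "Attention Is All You Need", applied separately along the H and W axes here. Written out in LaTeX, with T the temperature (temperatureH/temperatureW), d = num_pos_feats and pos the normalized x or y coordinate:

PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{T^{2i/d}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{T^{2i/d}}\right)

which matches dim_t = T^(2 * (i // 2) / d) followed by the interleaved sin/cos stack in the code.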
@@ -18,15 +18,13 @@ import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
 from timm.models.layers import DropPath, to_2tuple, trunc_normal_

-from groundingdino.util.misc import NestedTensor
+from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import NestedTensor


 class Mlp(nn.Module):
     """Multilayer perceptron."""

-    def __init__(
-        self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
-    ):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features

@@ -138,24 +136,16 @@ class WindowAttention(nn.Module):
             mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
         """
         B_, N, C = x.shape
-        qkv = (
-            self.qkv(x)
-            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
-            .permute(2, 0, 3, 1, 4)
-        )
+        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

         q = q * self.scale
         attn = q @ k.transpose(-2, -1)

-        relative_position_bias = self.relative_position_bias_table[
-            self.relative_position_index.view(-1)
-        ].view(
+        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
             self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
         )  # Wh*Ww,Wh*Ww,nH
-        relative_position_bias = relative_position_bias.permute(
-            2, 0, 1
-        ).contiguous()  # nH, Wh*Ww, Wh*Ww
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
         attn = attn + relative_position_bias.unsqueeze(0)

         if mask is not None:
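The WindowAttention hunk above is likewise a pure re-wrap; the underlying computation is standard Swin window attention with a learned relative position bias B, looked up from relative_position_bias_table via relative_position_index for an M x M window:

\mathrm{Attention}(Q, K, V) = \mathrm{Softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d}} + B\right) V

where d is the per-head dimension (the code folds the 1/sqrt(d) factor into q through self.scale).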
@@ -228,9 +218,7 @@ class SwinTransformerBlock(nn.Module):
         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
-        self.mlp = Mlp(
-            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
-        )
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

         self.H = None
         self.W = None

@@ -266,12 +254,8 @@ class SwinTransformerBlock(nn.Module):
             attn_mask = None

         # partition windows
-        x_windows = window_partition(
-            shifted_x, self.window_size
-        )  # nW*B, window_size, window_size, C
-        x_windows = x_windows.view(
-            -1, self.window_size * self.window_size, C
-        )  # nW*B, window_size*window_size, C
+        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
+        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

         # W-MSA/SW-MSA
         attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

@@ -433,14 +417,10 @@ class BasicLayer(nn.Module):
                     img_mask[:, h, w, :] = cnt
                     cnt += 1

-            mask_windows = window_partition(
-                img_mask, self.window_size
-            )  # nW, window_size, window_size, 1
+            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
             mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
             attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
-            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
-                attn_mask == 0, float(0.0)
-            )
+            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

         for blk in self.blocks:
             blk.H, blk.W = H, W

@@ -589,9 +569,7 @@ class SwinTransformer(nn.Module):
         self.pos_drop = nn.Dropout(p=drop_rate)

         # stochastic depth
-        dpr = [
-            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
-        ]  # stochastic depth decay rule
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

         # build layers
         self.layers = nn.ModuleList()

@@ -682,9 +660,7 @@ class SwinTransformer(nn.Module):
         Wh, Ww = x.size(2), x.size(3)
         if self.ape:
             # interpolate the position embedding to the corresponding size
-            absolute_pos_embed = F.interpolate(
-                self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
-            )
+            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic")
             x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
         else:
             x = x.flatten(2).transpose(1, 2)

@@ -718,9 +694,7 @@ class SwinTransformer(nn.Module):
         Wh, Ww = x.size(2), x.size(3)
         if self.ape:
             # interpolate the position embedding to the corresponding size
-            absolute_pos_embed = F.interpolate(
-                self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
-            )
+            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic")
             x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
         else:
             x = x.flatten(2).transpose(1, 2)

@@ -769,21 +743,11 @@ def build_swin_transformer(modelname, pretrain_img_size, **kw):
     ]

     model_para_dict = {
-        "swin_T_224_1k": dict(
-            embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7
-        ),
-        "swin_B_224_22k": dict(
-            embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7
-        ),
-        "swin_B_384_22k": dict(
-            embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12
-        ),
-        "swin_L_224_22k": dict(
-            embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7
-        ),
-        "swin_L_384_22k": dict(
-            embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12
-        ),
+        "swin_T_224_1k": dict(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7),
+        "swin_B_224_22k": dict(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7),
+        "swin_B_384_22k": dict(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12),
+        "swin_L_224_22k": dict(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7),
+        "swin_L_384_22k": dict(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12),
     }
     kw_cgf = model_para_dict[modelname]
     kw_cgf.update(kw)
@@ -21,8 +21,12 @@ import torch
 import torch.nn.functional as F
 from torch import nn

-from groundingdino.util import get_tokenlizer
-from groundingdino.util.misc import NestedTensor, inverse_sigmoid, nested_tensor_from_tensor_list
+from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util import get_tokenlizer
+from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import (
+    NestedTensor,
+    inverse_sigmoid,
+    nested_tensor_from_tensor_list,
+)

 from ..registry import MODULE_BUILD_FUNCS
 from .backbone import build_backbone

@@ -22,14 +22,19 @@ import torch
 import torch.utils.checkpoint as checkpoint
 from torch import Tensor, nn

-from groundingdino.util.misc import inverse_sigmoid
+from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import inverse_sigmoid

 from .fuse_modules import BiAttentionBlock
 from .ms_deform_attn import MultiScaleDeformableAttention as MSDeformAttn
 from .transformer_vanilla import TransformerEncoderLayer
-from .utils import (MLP, _get_activation_fn, _get_clones,
-                    gen_encoder_output_proposals, gen_sineembed_for_position,
-                    get_sine_pos_embed)
+from .utils import (
+    MLP,
+    _get_activation_fn,
+    _get_clones,
+    gen_encoder_output_proposals,
+    gen_sineembed_for_position,
+    get_sine_pos_embed,
+)


 class Transformer(nn.Module):

@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from typing import Dict, List, Tuple

 import cv2
 import numpy as np

@@ -7,11 +7,11 @@ import torch
 from PIL import Image
 from torchvision.ops import box_convert

-import groundingdino.datasets.transforms as T
-from groundingdino.models import build_model
-from groundingdino.util.misc import clean_state_dict
-from groundingdino.util.slconfig import SLConfig
-from groundingdino.util.utils import get_phrases_from_posmap
+import invokeai.backend.image_util.grounding_segment_anything.groundingdino.datasets.transforms as T
+from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models import build_model
+from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import clean_state_dict
+from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.slconfig import SLConfig
+from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.utils import get_phrases_from_posmap

 # ----------------------------------------------------------------------------------------------------------------------
 # OLD API

@@ -25,12 +25,11 @@ def preprocess_caption(caption: str) -> str:
     return result + "."


-def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
+def load_model(model_config_path: str, model_state_dict: Dict[str, torch.Tensor], device: str = "cuda"):
     args = SLConfig.fromfile(model_config_path)
     args.device = device
     model = build_model(args)
-    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
-    model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
+    model.load_state_dict(clean_state_dict(model_state_dict["model"]), strict=False)
     model.eval()
     return model

@@ -98,9 +97,9 @@ def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor

 class Model:

-    def __init__(self, model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
+    def __init__(self, model_config_path: str, model_state_dict: Dict[str, torch.Tensor], device: str = "cuda"):
         self.model = load_model(
-            model_config_path=model_config_path, model_checkpoint_path=model_checkpoint_path, device=device
+            model_config_path=model_config_path, model_state_dict=model_state_dict, device=device
         ).to(device)
         self.device = device
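The inference.py hunks above change load_model() and Model to accept an in-memory state dict rather than a checkpoint path. A hedged sketch of the new call pattern; the local filename and the manual torch.load are illustrative only, since inside InvokeAI the state dict comes from context.models.load_remote_model:

import torch

from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.inference import load_model

# Assumes a locally downloaded GroundingDINO checkpoint whose top-level layout is {"model": ...},
# which is the key load_model() indexes above.
state_dict = torch.load("groundingdino_swint_ogc.pth", map_location="cpu")
model = load_model(
    model_config_path="./invokeai/backend/image_util/grounding_segment_anything/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    model_state_dict=state_dict,
    device="cpu",
)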
@@ -1,91 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
-import functools
-import logging
-import os
-import sys
-
-from termcolor import colored
-
-
-class _ColorfulFormatter(logging.Formatter):
-    def __init__(self, *args, **kwargs):
-        self._root_name = kwargs.pop("root_name") + "."
-        self._abbrev_name = kwargs.pop("abbrev_name", "")
-        if len(self._abbrev_name):
-            self._abbrev_name = self._abbrev_name + "."
-        super(_ColorfulFormatter, self).__init__(*args, **kwargs)
-
-    def formatMessage(self, record):
-        record.name = record.name.replace(self._root_name, self._abbrev_name)
-        log = super(_ColorfulFormatter, self).formatMessage(record)
-        if record.levelno == logging.WARNING:
-            prefix = colored("WARNING", "red", attrs=["blink"])
-        elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
-            prefix = colored("ERROR", "red", attrs=["blink", "underline"])
-        else:
-            return log
-        return prefix + " " + log
-
-
-# so that calling setup_logger multiple times won't add many handlers
-@functools.lru_cache()
-def setup_logger(output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None):
-    """
-    Initialize the detectron2 logger and set its verbosity level to "INFO".
-
-    Args:
-        output (str): a file name or a directory to save log. If None, will not save log file.
-            If ends with ".txt" or ".log", assumed to be a file name.
-            Otherwise, logs will be saved to `output/log.txt`.
-        name (str): the root module name of this logger
-
-    Returns:
-        logging.Logger: a logger
-    """
-    logger = logging.getLogger(name)
-    logger.setLevel(logging.DEBUG)
-    logger.propagate = False
-
-    if abbrev_name is None:
-        abbrev_name = name
-
-    plain_formatter = logging.Formatter("[%(asctime)s.%(msecs)03d]: %(message)s", datefmt="%m/%d %H:%M:%S")
-    # stdout logging: master only
-    if distributed_rank == 0:
-        ch = logging.StreamHandler(stream=sys.stdout)
-        ch.setLevel(logging.DEBUG)
-        if color:
-            formatter = _ColorfulFormatter(
-                colored("[%(asctime)s.%(msecs)03d]: ", "green") + "%(message)s",
-                datefmt="%m/%d %H:%M:%S",
-                root_name=name,
-                abbrev_name=str(abbrev_name),
-            )
-        else:
-            formatter = plain_formatter
-        ch.setFormatter(formatter)
-        logger.addHandler(ch)
-
-    # file logging: all workers
-    if output is not None:
-        if output.endswith(".txt") or output.endswith(".log"):
-            filename = output
-        else:
-            filename = os.path.join(output, "log.txt")
-        if distributed_rank > 0:
-            filename = filename + f".rank{distributed_rank}"
-        os.makedirs(os.path.dirname(filename), exist_ok=True)
-
-        fh = logging.StreamHandler(_cached_log_stream(filename))
-        fh.setLevel(logging.DEBUG)
-        fh.setFormatter(plain_formatter)
-        logger.addHandler(fh)
-
-    return logger
-
-
-# cache the opened file object, so that different calls to `setup_logger`
-# with the same file name can safely write to the same file.
-@functools.lru_cache(maxsize=None)
-def _cached_log_stream(filename):
-    return open(filename, "a")
@@ -161,7 +161,7 @@ def all_gather_cpu(data):
         dist.all_gather(tensor_list, tensor, group=cpu_group)

     data_list = []
-    for size, tensor in zip(size_list, tensor_list):
+    for size, tensor in zip(size_list, tensor_list, strict=False):
         tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
         buffer = io.BytesIO(tensor.cpu().numpy())
         obj = torch.load(buffer)

@@ -210,7 +210,7 @@ def all_gather(data):
     dist.all_gather(tensor_list, tensor)

     data_list = []
-    for size, tensor in zip(size_list, tensor_list):
+    for size, tensor in zip(size_list, tensor_list, strict=False):
         buffer = tensor.cpu().numpy().tobytes()[:size]
         data_list.append(pickle.loads(buffer))

@@ -240,7 +240,7 @@ def reduce_dict(input_dict, average=True):
         dist.all_reduce(values)
         if average:
             values /= world_size
-        reduced_dict = {k: v for k, v in zip(names, values)}
+        reduced_dict = {k: v for k, v in zip(names, values, strict=False)}
     return reduced_dict

@@ -378,7 +378,7 @@ def get_sha():

 def collate_fn(batch):
     # import ipdb; ipdb.set_trace()
-    batch = list(zip(*batch))
+    batch = list(zip(*batch, strict=False))
     batch[0] = nested_tensor_from_tensor_list(batch[0])
     return tuple(batch)

@@ -480,7 +480,7 @@ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
         device = tensor_list[0].device
         tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
         mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
-        for img, pad_img, m in zip(tensor_list, tensor, mask):
+        for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False):
             pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
             m[: img.shape[1], : img.shape[2]] = False
     else:

@@ -505,7 +505,7 @@ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTen
     padded_imgs = []
     padded_masks = []
     for img in tensor_list:
-        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape), strict=False)]
         padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
         padded_imgs.append(padded_img)

@@ -9,7 +9,7 @@ import numpy as np
 import torch
 from transformers import AutoTokenizer

-from groundingdino.util.slconfig import SLConfig
+from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.slconfig import SLConfig


 def slprint(x, name="x"):
@@ -1,44 +1,38 @@
-from typing import Dict, List, Literal, Optional
+import pathlib
+from typing import Dict, List, Optional

-import cv2
 import numpy as np
 import supervision as sv
 import torch
 import torchvision
+from PIL import Image

 from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.inference import Model
 from invokeai.backend.image_util.grounding_segment_anything.segment_anything.build_sam import sam_model_registry
 from invokeai.backend.image_util.grounding_segment_anything.segment_anything.predictor import SamPredictor

-GROUNDING_SEGMENT_ANYTHING_MODELS = {
-    "groundingdino_swint_ogc": "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
-    "segment_anything_vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
-}
-

 class GroundingSegmentAnythingDetector:
-    def __init__(self) -> None:
-        self.grounding_dino_model: Optional[Model] = None
-        self.segment_anything_model: Optional[SamPredictor] = None
-        self.grounding_dino_config: str = "./groundingdino/config/GroundingDINO_SwinT_OGC.py"
-        self.sam_encoder: Literal["vit_h"] = "vit_h"
-        self.device: torch.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    def __init__(self, grounding_dino_model: Model, segment_anything_model: SamPredictor) -> None:
+        self.grounding_dino_model: Optional[Model] = grounding_dino_model
+        self.segment_anything_model: Optional[SamPredictor] = segment_anything_model

-    def build_grounding_dino(self):
+    @staticmethod
+    def build_grounding_dino(grounding_dino_state_dict: Dict[str, torch.Tensor]):
+        grounding_dino_config = pathlib.Path(
+            "./invokeai/backend/image_util/grounding_segment_anything/groundingdino/config/GroundingDINO_SwinT_OGC.py"
+        )
         return Model(
-            model_config_path=self.grounding_dino_config,
-            model_checkpoint_path="./checkpoints/groundingdino_swint_ogc.pth",
+            model_state_dict=grounding_dino_state_dict,
+            model_config_path=grounding_dino_config.as_posix(),
         )

-    def build_segment_anything(self):
-        sam = sam_model_registry[self.sam_encoder](checkpoint="./checkpoints/sam_vit_h_4b8939.pth")
-        sam.to(device=self.device)
+    @staticmethod
+    def build_segment_anything(segment_anything_state_dict: Dict[str, torch.Tensor], device: torch.device):
+        sam = sam_model_registry["vit_h"](checkpoint=segment_anything_state_dict)
+        sam.to(device=device)
         return SamPredictor(sam)

-    def build_grounding_sam(self):
-        self.grounding_dino_model = self.build_grounding_dino()
-        self.segment_anything_model = self.build_segment_anything()
-
     def detect_objects(
         self,
         image: np.ndarray,

@@ -77,20 +71,18 @@ class GroundingSegmentAnythingDetector:

     def predict(
         self,
-        image: str,
+        image: Image.Image,
         prompt: str,
         box_threshold: float = 0.5,
         text_threshold: float = 0.5,
         nms_threshold: float = 0.8,
     ):
-        if not self.grounding_dino_model or not self.segment_anything_model:
-            self.build_grounding_sam()
-
-        image = cv2.imread(image)
+        open_cv_image = np.array(image)
+        open_cv_image = open_cv_image[:, :, ::-1].copy()

         prompts = prompt.split(",")

-        detections = self.detect_objects(image, prompts, box_threshold, text_threshold, nms_threshold)
-        segments = self.segment_detections(image, detections, prompts)
+        detections = self.detect_objects(open_cv_image, prompts, box_threshold, text_threshold, nms_threshold)
+        segments = self.segment_detections(open_cv_image, detections, prompts)

         if len(segments) > 0:
             combined_mask = np.zeros_like(list(segments.values())[0])

@@ -98,15 +90,6 @@ class GroundingSegmentAnythingDetector:
                 combined_mask = np.logical_or(combined_mask, mask)
             mask_preview = (combined_mask * 255).astype(np.uint8)
         else:
-            mask_preview = np.zeros(image.shape, np.uint8)
-
-        cv2.imwrite("mask.png", mask_preview)
-
-
-if __name__ == "__main__":
-    gsa = GroundingSegmentAnythingDetector()
-    image = "./assets/image.webp"
-
-    while True:
-        prompt = input("Segment: ")
-        gsa.predict(image, prompt, 0.5, 0.5, 0.8)
+            mask_preview = np.zeros(open_cv_image.shape, np.uint8)
+
+        return Image.fromarray(mask_preview)
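Taken together with the inference.py changes earlier, the detector is now built from in-memory state dicts and works on PIL images end to end. A minimal sketch of driving it outside InvokeAI follows; the checkpoint filenames, manual torch.load calls, prompt and image path are illustrative, and the node itself obtains the state dicts through context.models.load_remote_model / model_on_device instead:

import torch
from PIL import Image

from invokeai.backend.image_util.grounding_segment_anything.gsa import GroundingSegmentAnythingDetector

# Illustrative: state dicts downloaded from the GROUNDING_SEGMENT_ANYTHING_MODELS URLs.
grounding_dino_sd = torch.load("groundingdino_swint_ogc.pth", map_location="cpu")
segment_anything_sd = torch.load("sam_vit_h_4b8939.pth", map_location="cpu")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
grounding_dino = GroundingSegmentAnythingDetector.build_grounding_dino(grounding_dino_sd)
segment_anything = GroundingSegmentAnythingDetector.build_segment_anything(segment_anything_sd, device)

detector = GroundingSegmentAnythingDetector(grounding_dino, segment_anything)
mask = detector.predict(Image.open("photo.png"), "dog,cat", box_threshold=0.5, text_threshold=0.5, nms_threshold=0.8)
mask.save("mask.png")  # predict() now returns a PIL image instead of writing mask.png itself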
@@ -4,13 +4,22 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from .automatic_mask_generator import SamAutomaticMaskGenerator
-from .build_sam import build_sam, build_sam_vit_b, build_sam_vit_h, build_sam_vit_l, sam_model_registry
-from .build_sam_hq import (
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.automatic_mask_generator import (
+    SamAutomaticMaskGenerator,
+)  # noqa
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.build_sam import (  # noqa
+    build_sam,
+    build_sam_vit_b,
+    build_sam_vit_h,
+    build_sam_vit_l,
+    sam_model_registry,
+)
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.build_sam_hq import (  # noqa
     build_sam_hq,
     build_sam_hq_vit_b,
     build_sam_hq_vit_h,
     build_sam_hq_vit_l,
     sam_hq_model_registry,
 )
-from .predictor import SamPredictor
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.predictor import SamPredictor  # noqa

@@ -10,9 +10,9 @@ import numpy as np
 import torch
 from torchvision.ops.boxes import batched_nms, box_area  # type: ignore

-from .modeling import Sam
-from .predictor import SamPredictor
-from .utils.amg import (
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling import Sam
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.predictor import SamPredictor
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.utils.amg import (
     MaskData,
     area_from_rle,
     batch_iterator,

@@ -8,7 +8,13 @@ from functools import partial

 import torch

-from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling import (
+    ImageEncoderViT,
+    MaskDecoder,
+    PromptEncoder,
+    Sam,
+    TwoWayTransformer,
+)


 def build_sam_vit_h(checkpoint=None):

@@ -101,7 +107,5 @@ def _build_sam(
     )
     sam.eval()
     if checkpoint is not None:
-        with open(checkpoint, "rb") as f:
-            state_dict = torch.load(f)
-        sam.load_state_dict(state_dict)
+        sam.load_state_dict(checkpoint)
     return sam

@@ -8,7 +8,13 @@ from functools import partial

 import torch

-from .modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling import (
+    ImageEncoderViT,
+    MaskDecoderHQ,
+    PromptEncoder,
+    Sam,
+    TwoWayTransformer,
+)


 def build_sam_hq_vit_h(checkpoint=None):

@@ -4,9 +4,17 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from .image_encoder import ImageEncoderViT
-from .mask_decoder import MaskDecoder
-from .mask_decoder_hq import MaskDecoderHQ
-from .prompt_encoder import PromptEncoder
-from .sam import Sam
-from .transformer import TwoWayTransformer
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.image_encoder import (
+    ImageEncoderViT,
+)
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.mask_decoder import MaskDecoder
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.mask_decoder_hq import (
+    MaskDecoderHQ,
+)
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.prompt_encoder import (
+    PromptEncoder,
+)
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.sam import Sam
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.transformer import (
+    TwoWayTransformer,
+)

@@ -10,7 +10,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from .common import LayerNorm2d, MLPBlock
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import (
+    LayerNorm2d,
+    MLPBlock,
+)


 # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa

@@ -10,7 +10,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F

-from .common import LayerNorm2d
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import LayerNorm2d


 class MaskDecoder(nn.Module):

@@ -11,7 +11,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F

-from .common import LayerNorm2d
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import LayerNorm2d


 class MaskDecoderHQ(nn.Module):

@@ -10,7 +10,7 @@ import numpy as np
 import torch
 from torch import nn

-from .common import LayerNorm2d
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import LayerNorm2d


 class PromptEncoder(nn.Module):

@@ -10,9 +10,13 @@ import torch
 from torch import nn
 from torch.nn import functional as F

-from .image_encoder import ImageEncoderViT
-from .mask_decoder import MaskDecoder
-from .prompt_encoder import PromptEncoder
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.image_encoder import (
+    ImageEncoderViT,
+)
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.mask_decoder import MaskDecoder
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.prompt_encoder import (
+    PromptEncoder,
+)


 class Sam(nn.Module):

@@ -98,7 +102,7 @@ class Sam(nn.Module):
         image_embeddings = self.image_encoder(input_images)

         outputs = []
-        for image_record, curr_embedding in zip(batched_input, image_embeddings):
+        for image_record, curr_embedding in zip(batched_input, image_embeddings, strict=False):
             if "point_coords" in image_record:
                 points = (image_record["point_coords"], image_record["point_labels"])
             else:

@@ -10,7 +10,7 @@ from typing import Tuple, Type
 import torch
 from torch import Tensor, nn

-from .common import MLPBlock
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import MLPBlock


 class TwoWayTransformer(nn.Module):

@@ -9,8 +9,8 @@ from typing import Optional, Tuple
 import numpy as np
 import torch

-from .modeling import Sam
-from .utils.transforms import ResizeLongestSide
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling import Sam
+from invokeai.backend.image_util.grounding_segment_anything.segment_anything.utils.transforms import ResizeLongestSide


 class SamPredictor:

@@ -4,14 +4,14 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-import numpy as np
-import torch
-
 import math
 from copy import deepcopy
 from itertools import product
 from typing import Any, Dict, Generator, ItemsView, List, Tuple

+import numpy as np
+import torch
+

 class MaskData:
     """
@@ -153,9 +153,7 @@ def area_from_rle(rle: Dict[str, Any]) -> int:
     return sum(rle["counts"][1::2])


-def calculate_stability_score(
-    masks: torch.Tensor, mask_threshold: float, threshold_offset: float
-) -> torch.Tensor:
+def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
     """
     Computes the stability score for a batch of masks. The stability
     score is the IoU between the binary masks obtained by thresholding

@@ -163,16 +161,8 @@ def calculate_stability_score(
     """
     # One mask is always contained inside the other.
     # Save memory by preventing unnecesary cast to torch.int64
-    intersections = (
-        (masks > (mask_threshold + threshold_offset))
-        .sum(-1, dtype=torch.int16)
-        .sum(-1, dtype=torch.int32)
-    )
-    unions = (
-        (masks > (mask_threshold - threshold_offset))
-        .sum(-1, dtype=torch.int16)
-        .sum(-1, dtype=torch.int32)
-    )
+    intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
+    unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
     return intersections / unions

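For reference, the calculate_stability_score() hunks above only re-wrap the expression. With tau = mask_threshold and delta = threshold_offset, the score is

\mathrm{stability}(m) = \frac{\left|\{\, m > \tau + \delta \,\}\right|}{\left|\{\, m > \tau - \delta \,\}\right|}

i.e. the IoU of the two thresholded masks, which reduces to this ratio because the higher-threshold mask is contained in the lower-threshold one (as the code comment notes).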
@ -186,9 +176,7 @@ def build_point_grid(n_per_side: int) -> np.ndarray:
|
|||||||
return points
|
return points
|
||||||
|
|
||||||
|
|
||||||
def build_all_layer_point_grids(
|
def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
|
||||||
n_per_side: int, n_layers: int, scale_per_layer: int
|
|
||||||
) -> List[np.ndarray]:
|
|
||||||
"""Generates point grids for all crop layers."""
|
"""Generates point grids for all crop layers."""
|
||||||
points_by_layer = []
|
points_by_layer = []
|
||||||
for i in range(n_layers + 1):
|
for i in range(n_layers + 1):
|
||||||
@@ -252,9 +240,7 @@ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
     return points + offset
 
 
-def uncrop_masks(
-    masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
-) -> torch.Tensor:
+def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:
     x0, y0, x1, y1 = crop_box
     if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
         return masks
@@ -264,9 +250,7 @@ def uncrop_masks(
     return torch.nn.functional.pad(masks, pad, value=0)
 
 
-def remove_small_regions(
-    mask: np.ndarray, area_thresh: float, mode: str
-) -> Tuple[np.ndarray, bool]:
+def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
     """
     Removes small disconnected regions and holes in a mask. Returns the
     mask and an indicator of if the mask has been modified.
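Both hunks above are formatting-only, but the padding they surround is the interesting part: a mask predicted inside a crop box [x0, y0, x1, y1] is placed back on the full canvas by zero-padding on the left, right, top, and bottom, which is what the torch.nn.functional.pad call in the uncrop_masks hunk does. A minimal sketch of that idea, with an invented crop box:

import torch
import torch.nn.functional as F

orig_h, orig_w = 8, 10                       # invented full-image size
x0, y0, x1, y1 = 2, 1, 6, 4                  # invented crop box
crop_mask = torch.ones(1, y1 - y0, x1 - x0)  # mask predicted inside the crop

# F.pad takes (left, right, top, bottom) for the last two dimensions.
pad = (x0, orig_w - x1, y0, orig_h - y1)
full_mask = F.pad(crop_mask, pad, value=0)
print(full_mask.shape)  # torch.Size([1, 8, 10])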
@@ -1,144 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-
-from typing import Tuple
-
-from ..modeling import Sam
-from .amg import calculate_stability_score
-
-
-class SamOnnxModel(nn.Module):
-    """
-    This model should not be called directly, but is used in ONNX export.
-    It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
-    with some functions modified to enable model tracing. Also supports extra
-    options controlling what information. See the ONNX export script for details.
-    """
-
-    def __init__(
-        self,
-        model: Sam,
-        return_single_mask: bool,
-        use_stability_score: bool = False,
-        return_extra_metrics: bool = False,
-    ) -> None:
-        super().__init__()
-        self.mask_decoder = model.mask_decoder
-        self.model = model
-        self.img_size = model.image_encoder.img_size
-        self.return_single_mask = return_single_mask
-        self.use_stability_score = use_stability_score
-        self.stability_score_offset = 1.0
-        self.return_extra_metrics = return_extra_metrics
-
-    @staticmethod
-    def resize_longest_image_size(
-        input_image_size: torch.Tensor, longest_side: int
-    ) -> torch.Tensor:
-        input_image_size = input_image_size.to(torch.float32)
-        scale = longest_side / torch.max(input_image_size)
-        transformed_size = scale * input_image_size
-        transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
-        return transformed_size
-
-    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
-        point_coords = point_coords + 0.5
-        point_coords = point_coords / self.img_size
-        point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
-        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
-
-        point_embedding = point_embedding * (point_labels != -1)
-        point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
-            point_labels == -1
-        )
-
-        for i in range(self.model.prompt_encoder.num_point_embeddings):
-            point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
-                i
-            ].weight * (point_labels == i)
-
-        return point_embedding
-
-    def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
-        mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
-        mask_embedding = mask_embedding + (
-            1 - has_mask_input
-        ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
-        return mask_embedding
-
-    def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
-        masks = F.interpolate(
-            masks,
-            size=(self.img_size, self.img_size),
-            mode="bilinear",
-            align_corners=False,
-        )
-
-        prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size)
-        masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])]
-
-        orig_im_size = orig_im_size.to(torch.int64)
-        h, w = orig_im_size[0], orig_im_size[1]
-        masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
-        return masks
-
-    def select_masks(
-        self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # Determine if we should return the multiclick mask or not from the number of points.
-        # The reweighting is used to avoid control flow.
-        score_reweight = torch.tensor(
-            [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
-        ).to(iou_preds.device)
-        score = iou_preds + (num_points - 2.5) * score_reweight
-        best_idx = torch.argmax(score, dim=1)
-        masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
-        iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
-
-        return masks, iou_preds
-
-    @torch.no_grad()
-    def forward(
-        self,
-        image_embeddings: torch.Tensor,
-        point_coords: torch.Tensor,
-        point_labels: torch.Tensor,
-        mask_input: torch.Tensor,
-        has_mask_input: torch.Tensor,
-        orig_im_size: torch.Tensor,
-    ):
-        sparse_embedding = self._embed_points(point_coords, point_labels)
-        dense_embedding = self._embed_masks(mask_input, has_mask_input)
-
-        masks, scores = self.model.mask_decoder.predict_masks(
-            image_embeddings=image_embeddings,
-            image_pe=self.model.prompt_encoder.get_dense_pe(),
-            sparse_prompt_embeddings=sparse_embedding,
-            dense_prompt_embeddings=dense_embedding,
-        )
-
-        if self.use_stability_score:
-            scores = calculate_stability_score(
-                masks, self.model.mask_threshold, self.stability_score_offset
-            )
-
-        if self.return_single_mask:
-            masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
-
-        upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
-
-        if self.return_extra_metrics:
-            stability_scores = calculate_stability_score(
-                upscaled_masks, self.model.mask_threshold, self.stability_score_offset
-            )
-            areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
-            return upscaled_masks, scores, stability_scores, areas, masks
-
-        return upscaled_masks, scores, masks
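The file removed above is the SamOnnxModel wrapper, which exists so the prompt encoder, mask decoder, and mask postprocessing can be traced for ONNX export; dropping it suggests this port only needs the eager-mode predictor. Its least obvious piece is select_masks, which replaces an if/else on the number of prompt points with a score reweighting so tracing never sees data-dependent control flow. A standalone sketch of that trick (shapes and scores are invented for illustration):

import torch

def select_branchless(masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int) -> torch.Tensor:
    # The sign of (num_points - 2.5) turns the first token's 1000-point weight into a
    # large bonus or a large penalty, so argmax stands in for a branch.
    score_reweight = torch.tensor([[1000.0] + [0.0] * (iou_preds.shape[1] - 1)])
    score = iou_preds + (num_points - 2.5) * score_reweight
    best_idx = torch.argmax(score, dim=1)
    return masks[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)

masks = torch.rand(1, 4, 8, 8)                        # 4 candidate masks per image
iou_preds = torch.tensor([[0.91, 0.84, 0.88, 0.79]])  # invented predicted IoUs
print(select_branchless(masks, iou_preds, num_points=2).shape)  # torch.Size([1, 1, 8, 8])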
@@ -4,14 +4,14 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+from copy import deepcopy
+from typing import Tuple
+
 import numpy as np
 import torch
 from torch.nn import functional as F
 from torchvision.transforms.functional import resize, to_pil_image  # type: ignore
 
-from copy import deepcopy
-from typing import Tuple
-
 
 class ResizeLongestSide:
     """
@@ -36,9 +36,7 @@ class ResizeLongestSide:
         original image size in (H, W) format.
         """
         old_h, old_w = original_size
-        new_h, new_w = self.get_preprocess_shape(
-            original_size[0], original_size[1], self.target_length
-        )
+        new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
         coords = deepcopy(coords).astype(float)
         coords[..., 0] = coords[..., 0] * (new_w / old_w)
         coords[..., 1] = coords[..., 1] * (new_h / old_h)
@@ -60,29 +58,21 @@ class ResizeLongestSide:
         """
         # Expects an image in BCHW format. May not exactly match apply_image.
         target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
-        return F.interpolate(
-            image, target_size, mode="bilinear", align_corners=False, antialias=True
-        )
+        return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
 
-    def apply_coords_torch(
-        self, coords: torch.Tensor, original_size: Tuple[int, ...]
-    ) -> torch.Tensor:
+    def apply_coords_torch(self, coords: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
         """
         Expects a torch tensor with length 2 in the last dimension. Requires the
         original image size in (H, W) format.
         """
         old_h, old_w = original_size
-        new_h, new_w = self.get_preprocess_shape(
-            original_size[0], original_size[1], self.target_length
-        )
+        new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
         coords = deepcopy(coords).to(torch.float)
         coords[..., 0] = coords[..., 0] * (new_w / old_w)
         coords[..., 1] = coords[..., 1] * (new_h / old_h)
         return coords
 
-    def apply_boxes_torch(
-        self, boxes: torch.Tensor, original_size: Tuple[int, ...]
-    ) -> torch.Tensor:
+    def apply_boxes_torch(self, boxes: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
         """
         Expects a torch tensor with shape Bx4. Requires the original image
         size in (H, W) format.
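The collapsed signatures above do not change ResizeLongestSide's behaviour: prompt coordinates and boxes are rescaled by the same factor used to resize the image so its longest side equals target_length, with get_preprocess_shape rounding to the nearest pixel. A self-contained sketch of that scaling math, using an invented image size and box:

from typing import Tuple

def preprocess_shape(old_h: int, old_w: int, long_side: int) -> Tuple[int, int]:
    # Same rounding scheme as ResizeLongestSide.get_preprocess_shape.
    scale = long_side / max(old_h, old_w)
    return int(old_h * scale + 0.5), int(old_w * scale + 0.5)

old_h, old_w = 600, 800                     # invented original image size
new_h, new_w = preprocess_shape(old_h, old_w, long_side=1024)

# A box prompt in original pixel coordinates, scaled the way apply_coords/apply_boxes scale it.
x0, y0, x1, y1 = 100, 50, 400, 300
scaled = (x0 * new_w / old_w, y0 * new_h / old_h, x1 * new_w / old_w, y1 * new_h / old_h)
print((new_h, new_w), scaled)  # (768, 1024) (128.0, 64.0, 512.0, 384.0)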
@@ -74,12 +74,13 @@ dependencies = [
   "easing-functions",
   "einops",
   "facexlib",
   "matplotlib",  # needed for plotting of Penner easing functions
   "npyscreen",
   "omegaconf",
   "picklescan",
   "pillow",
   "prompt-toolkit",
+  "pycocotools",
   "pympler~=1.0.1",
   "pypatchmatch",
   'pyperclip',
@@ -90,6 +91,7 @@ dependencies = [
   "scikit-image~=0.21.0",
   "semver~=3.0.1",
   "send2trash",
+  "supervision",
   "test-tube~=0.7.5",
   "windows-curses; sys_platform=='win32'",
 ]
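The two added dependencies presumably support the vendored GroundingDINO / Segment Anything code: pycocotools for COCO-style run-length-encoded (RLE) masks and supervision for detection post-processing such as non-maximum suppression. The exact call sites are not shown in this part of the diff, so the following is only a small sketch of the pycocotools side:

import numpy as np
from pycocotools import mask as mask_utils

# Encode a binary mask as COCO RLE and read it back; the mask contents are invented.
binary_mask = np.zeros((64, 64), dtype=np.uint8, order="F")  # encode() expects Fortran-ordered uint8
binary_mask[16:48, 16:48] = 1

rle = mask_utils.encode(binary_mask)
print(mask_utils.area(rle))        # 1024, the 32x32 square
decoded = mask_utils.decode(rle)   # back to an HxW uint8 mask
assert (decoded == binary_mask).all()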