mirror of https://github.com/invoke-ai/InvokeAI

wip: segment anything node

This commit is contained in:
parent e579be0118
commit b20c70c588

invokeai/app/invocations/segment_anything.py (new file, 76 lines)
@@ -0,0 +1,76 @@
from typing import Dict, cast

import torch

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.grounding_segment_anything.gsa import GroundingSegmentAnythingDetector
from invokeai.backend.util.devices import TorchDevice

GROUNDING_SEGMENT_ANYTHING_MODELS = {
    "groundingdino_swint_ogc": "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
    "segment_anything_vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
}


@invocation(
    "segment_anything",
    title="Segment Anything",
    tags=["grounding_dino", "segment", "anything"],
    category="image",
    version="1.0.0",
)
class SegmentAnythingInvocation(BaseInvocation):
    """Automatically generate masks from an image using GroundingDINO & Segment Anything"""

    image: ImageField = InputField(description="The image to process")
    prompt: str = InputField(default="", description="Keywords to segment", title="Prompt")
    box_threshold: float = InputField(
        default=0.5, ge=0, le=1, description="Threshold of box detection", title="Box Threshold"
    )
    text_threshold: float = InputField(
        default=0.5, ge=0, le=1, description="Threshold of text detection", title="Text Threshold"
    )
    nms_threshold: float = InputField(
        default=0.8, ge=0, le=1, description="Threshold of nms detection", title="NMS Threshold"
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        input_image = context.images.get_pil(self.image.image_name)

        grounding_dino_model = context.models.load_remote_model(
            GROUNDING_SEGMENT_ANYTHING_MODELS["groundingdino_swint_ogc"]
        )
        segment_anything_model = context.models.load_remote_model(
            GROUNDING_SEGMENT_ANYTHING_MODELS["segment_anything_vit_h"]
        )

        with (
            grounding_dino_model.model_on_device() as (_, grounding_dino_state_dict),
            segment_anything_model.model_on_device() as (_, segment_anything_state_dict),
        ):
            if not grounding_dino_state_dict or not segment_anything_state_dict:
                raise RuntimeError("Unable to load segmentation models")

            grounding_dino = GroundingSegmentAnythingDetector.build_grounding_dino(
                cast(Dict[str, torch.Tensor], grounding_dino_state_dict)
            )
            segment_anything = GroundingSegmentAnythingDetector.build_segment_anything(
                cast(Dict[str, torch.Tensor], segment_anything_state_dict), TorchDevice.choose_torch_device()
            )
            detector = GroundingSegmentAnythingDetector(grounding_dino, segment_anything)

            mask = detector.predict(
                input_image, self.prompt, self.box_threshold, self.text_threshold, self.nms_threshold
            )
            image_dto = context.images.save(mask)

            """Builds an ImageOutput and its ImageField"""
            processed_image_field = ImageField(image_name=image_dto.image_name)
            return ImageOutput(
                image=processed_image_field,
                width=input_image.width,
                height=input_image.height,
            )
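The node body above mirrors the standalone detector API reworked in gsa.py further down (build_grounding_dino, build_segment_anything, predict). A minimal sketch of the same flow outside the graph runner, assuming the two checkpoints listed in GROUNDING_SEGMENT_ANYTHING_MODELS have already been downloaded next to the script and that it is run from the repository root so the bundled GroundingDINO_SwinT_OGC.py config resolves:

    from typing import Dict, cast

    import torch
    from PIL import Image

    from invokeai.backend.image_util.grounding_segment_anything.gsa import GroundingSegmentAnythingDetector

    # Assumed local copies of the two checkpoints listed in GROUNDING_SEGMENT_ANYTHING_MODELS.
    grounding_dino_state_dict = torch.load("groundingdino_swint_ogc.pth", map_location="cpu")
    segment_anything_state_dict = torch.load("sam_vit_h_4b8939.pth", map_location="cpu")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    grounding_dino = GroundingSegmentAnythingDetector.build_grounding_dino(
        cast(Dict[str, torch.Tensor], grounding_dino_state_dict)
    )
    segment_anything = GroundingSegmentAnythingDetector.build_segment_anything(
        cast(Dict[str, torch.Tensor], segment_anything_state_dict), device
    )
    detector = GroundingSegmentAnythingDetector(grounding_dino, segment_anything)

    # predict() takes a PIL image plus a comma-separated keyword prompt and
    # returns the combined mask as a PIL image.
    mask = detector.predict(Image.open("photo.png"), "cat,dog", 0.5, 0.5, 0.8)
    mask.save("mask.png")

The three threshold arguments map one-to-one onto the node's Box/Text/NMS Threshold inputs.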
@ -10,8 +10,8 @@ import torch
|
||||
import torchvision.transforms as T
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
from groundingdino.util.box_ops import box_xyxy_to_cxcywh
|
||||
from groundingdino.util.misc import interpolate
|
||||
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.box_ops import box_xyxy_to_cxcywh
|
||||
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import interpolate
|
||||
|
||||
|
||||
def crop(image, target, region):
|
||||
@ -58,9 +58,7 @@ def crop(image, target, region):
|
||||
if os.environ.get("IPDB_SHILONG_DEBUG", None) == "INFO":
|
||||
# for debug and visualization only.
|
||||
if "strings_positive" in target:
|
||||
target["strings_positive"] = [
|
||||
_i for _i, _j in zip(target["strings_positive"], keep) if _j
|
||||
]
|
||||
target["strings_positive"] = [_i for _i, _j in zip(target["strings_positive"], keep, strict=False) if _j]
|
||||
|
||||
return cropped_image, target
|
||||
|
||||
@ -73,9 +71,7 @@ def hflip(image, target):
|
||||
target = target.copy()
|
||||
if "boxes" in target:
|
||||
boxes = target["boxes"]
|
||||
boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor(
|
||||
[w, 0, w, 0]
|
||||
)
|
||||
boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
|
||||
target["boxes"] = boxes
|
||||
|
||||
if "masks" in target:
|
||||
@ -119,15 +115,13 @@ def resize(image, target, size, max_size=None):
|
||||
if target is None:
|
||||
return rescaled_image, None
|
||||
|
||||
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
|
||||
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size, strict=False))
|
||||
ratio_width, ratio_height = ratios
|
||||
|
||||
target = target.copy()
|
||||
if "boxes" in target:
|
||||
boxes = target["boxes"]
|
||||
scaled_boxes = boxes * torch.as_tensor(
|
||||
[ratio_width, ratio_height, ratio_width, ratio_height]
|
||||
)
|
||||
scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
|
||||
target["boxes"] = scaled_boxes
|
||||
|
||||
if "area" in target:
|
||||
@ -139,9 +133,7 @@ def resize(image, target, size, max_size=None):
|
||||
target["size"] = torch.tensor([h, w])
|
||||
|
||||
if "masks" in target:
|
||||
target["masks"] = (
|
||||
interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5
|
||||
)
|
||||
target["masks"] = interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5
|
||||
|
||||
return rescaled_image, target
|
||||
|
||||
@ -192,11 +184,7 @@ class RandomSizeCrop(object):
|
||||
h = random.randint(self.min_size, min(img.height, self.max_size))
|
||||
region = T.RandomCrop.get_params(img, [h, w])
|
||||
result_img, result_target = crop(img, target, region)
|
||||
if (
|
||||
not self.respect_boxes
|
||||
or len(result_target["boxes"]) == init_boxes
|
||||
or i == max_patience - 1
|
||||
):
|
||||
if not self.respect_boxes or len(result_target["boxes"]) == init_boxes or i == max_patience - 1:
|
||||
return result_img, result_target
|
||||
return result_img, result_target
|
||||
|
||||
|
@@ -12,4 +12,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------

from .groundingdino import build_groundingdino
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models.GroundingDINO.groundingdino import (
build_groundingdino,
)
@ -24,10 +24,13 @@ import torchvision
|
||||
from torch import nn
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
|
||||
from groundingdino.util.misc import NestedTensor, clean_state_dict, is_main_process
|
||||
|
||||
from .position_encoding import build_position_encoding
|
||||
from .swin_transformer import build_swin_transformer
|
||||
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models.GroundingDINO.backbone.position_encoding import (
|
||||
build_position_encoding,
|
||||
)
|
||||
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models.GroundingDINO.backbone.swin_transformer import (
|
||||
build_swin_transformer,
|
||||
)
|
||||
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import NestedTensor, is_main_process
|
||||
|
||||
|
||||
class FrozenBatchNorm2d(torch.nn.Module):
|
||||
@ -80,19 +83,12 @@ class BackboneBase(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
for name, parameter in backbone.named_parameters():
|
||||
if (
|
||||
not train_backbone
|
||||
or "layer2" not in name
|
||||
and "layer3" not in name
|
||||
and "layer4" not in name
|
||||
):
|
||||
if not train_backbone or "layer2" not in name and "layer3" not in name and "layer4" not in name:
|
||||
parameter.requires_grad_(False)
|
||||
|
||||
return_layers = {}
|
||||
for idx, layer_index in enumerate(return_interm_indices):
|
||||
return_layers.update(
|
||||
{"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)}
|
||||
)
|
||||
return_layers.update({"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)})
|
||||
|
||||
# if len:
|
||||
# if use_stage1_feature:
|
||||
@ -214,8 +210,8 @@ def build_backbone(args):
|
||||
|
||||
model = Joiner(backbone, position_embedding)
|
||||
model.num_channels = bb_num_channels
|
||||
assert isinstance(
|
||||
bb_num_channels, List
|
||||
), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels))
|
||||
assert isinstance(bb_num_channels, List), "bb_num_channels is expected to be a List but {}".format(
|
||||
type(bb_num_channels)
|
||||
)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
return model
|
||||
|
@ -24,7 +24,7 @@ import math
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from groundingdino.util.misc import NestedTensor
|
||||
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import NestedTensor
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
@ -65,12 +65,8 @@ class PositionEmbeddingSine(nn.Module):
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
pos_x = torch.stack(
|
||||
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
||||
).flatten(3)
|
||||
pos_y = torch.stack(
|
||||
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
||||
).flatten(3)
|
||||
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||
return pos
|
||||
|
||||
@ -81,9 +77,7 @@ class PositionEmbeddingSineHW(nn.Module):
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None
|
||||
):
|
||||
def __init__(self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None):
|
||||
super().__init__()
|
||||
self.num_pos_feats = num_pos_feats
|
||||
self.temperatureH = temperatureH
|
||||
@ -111,19 +105,15 @@ class PositionEmbeddingSineHW(nn.Module):
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||
dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats)
|
||||
dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode="floor")) / self.num_pos_feats)
|
||||
pos_x = x_embed[:, :, :, None] / dim_tx
|
||||
|
||||
dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||
dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats)
|
||||
dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode="floor")) / self.num_pos_feats)
|
||||
pos_y = y_embed[:, :, :, None] / dim_ty
|
||||
|
||||
pos_x = torch.stack(
|
||||
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
||||
).flatten(3)
|
||||
pos_y = torch.stack(
|
||||
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
||||
).flatten(3)
|
||||
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
@ -18,15 +18,13 @@ import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
||||
|
||||
from groundingdino.util.misc import NestedTensor
|
||||
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import NestedTensor
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
"""Multilayer perceptron."""
|
||||
|
||||
def __init__(
|
||||
self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
|
||||
):
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
@ -138,24 +136,16 @@ class WindowAttention(nn.Module):
|
||||
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
||||
"""
|
||||
B_, N, C = x.shape
|
||||
qkv = (
|
||||
self.qkv(x)
|
||||
.reshape(B_, N, 3, self.num_heads, C // self.num_heads)
|
||||
.permute(2, 0, 3, 1, 4)
|
||||
)
|
||||
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)
|
||||
].view(
|
||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
|
||||
) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.permute(
|
||||
2, 0, 1
|
||||
).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if mask is not None:
|
||||
@ -228,9 +218,7 @@ class SwinTransformerBlock(nn.Module):
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
|
||||
)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
|
||||
self.H = None
|
||||
self.W = None
|
||||
@ -266,12 +254,8 @@ class SwinTransformerBlock(nn.Module):
|
||||
attn_mask = None
|
||||
|
||||
# partition windows
|
||||
x_windows = window_partition(
|
||||
shifted_x, self.window_size
|
||||
) # nW*B, window_size, window_size, C
|
||||
x_windows = x_windows.view(
|
||||
-1, self.window_size * self.window_size, C
|
||||
) # nW*B, window_size*window_size, C
|
||||
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
||||
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
||||
|
||||
# W-MSA/SW-MSA
|
||||
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
|
||||
@ -433,14 +417,10 @@ class BasicLayer(nn.Module):
|
||||
img_mask[:, h, w, :] = cnt
|
||||
cnt += 1
|
||||
|
||||
mask_windows = window_partition(
|
||||
img_mask, self.window_size
|
||||
) # nW, window_size, window_size, 1
|
||||
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
||||
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
|
||||
attn_mask == 0, float(0.0)
|
||||
)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
||||
|
||||
for blk in self.blocks:
|
||||
blk.H, blk.W = H, W
|
||||
@ -589,9 +569,7 @@ class SwinTransformer(nn.Module):
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
# stochastic depth
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
||||
] # stochastic depth decay rule
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
||||
|
||||
# build layers
|
||||
self.layers = nn.ModuleList()
|
||||
@ -682,9 +660,7 @@ class SwinTransformer(nn.Module):
|
||||
Wh, Ww = x.size(2), x.size(3)
|
||||
if self.ape:
|
||||
# interpolate the position embedding to the corresponding size
|
||||
absolute_pos_embed = F.interpolate(
|
||||
self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
|
||||
)
|
||||
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic")
|
||||
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
|
||||
else:
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
@ -718,9 +694,7 @@ class SwinTransformer(nn.Module):
|
||||
Wh, Ww = x.size(2), x.size(3)
|
||||
if self.ape:
|
||||
# interpolate the position embedding to the corresponding size
|
||||
absolute_pos_embed = F.interpolate(
|
||||
self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
|
||||
)
|
||||
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic")
|
||||
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
|
||||
else:
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
@ -769,21 +743,11 @@ def build_swin_transformer(modelname, pretrain_img_size, **kw):
|
||||
]
|
||||
|
||||
model_para_dict = {
|
||||
"swin_T_224_1k": dict(
|
||||
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7
|
||||
),
|
||||
"swin_B_224_22k": dict(
|
||||
embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7
|
||||
),
|
||||
"swin_B_384_22k": dict(
|
||||
embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12
|
||||
),
|
||||
"swin_L_224_22k": dict(
|
||||
embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7
|
||||
),
|
||||
"swin_L_384_22k": dict(
|
||||
embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12
|
||||
),
|
||||
"swin_T_224_1k": dict(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7),
|
||||
"swin_B_224_22k": dict(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7),
|
||||
"swin_B_384_22k": dict(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12),
|
||||
"swin_L_224_22k": dict(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7),
|
||||
"swin_L_384_22k": dict(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12),
|
||||
}
|
||||
kw_cgf = model_para_dict[modelname]
|
||||
kw_cgf.update(kw)
|
||||
|
@ -21,8 +21,12 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from groundingdino.util import get_tokenlizer
|
||||
from groundingdino.util.misc import NestedTensor, inverse_sigmoid, nested_tensor_from_tensor_list
|
||||
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util import get_tokenlizer
|
||||
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import (
|
||||
NestedTensor,
|
||||
inverse_sigmoid,
|
||||
nested_tensor_from_tensor_list,
|
||||
)
|
||||
|
||||
from ..registry import MODULE_BUILD_FUNCS
|
||||
from .backbone import build_backbone
|
||||
|
@ -22,14 +22,19 @@ import torch
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from torch import Tensor, nn
|
||||
|
||||
from groundingdino.util.misc import inverse_sigmoid
|
||||
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import inverse_sigmoid
|
||||
|
||||
from .fuse_modules import BiAttentionBlock
|
||||
from .ms_deform_attn import MultiScaleDeformableAttention as MSDeformAttn
|
||||
from .transformer_vanilla import TransformerEncoderLayer
|
||||
from .utils import (MLP, _get_activation_fn, _get_clones,
|
||||
gen_encoder_output_proposals, gen_sineembed_for_position,
|
||||
get_sine_pos_embed)
|
||||
from .utils import (
|
||||
MLP,
|
||||
_get_activation_fn,
|
||||
_get_clones,
|
||||
gen_encoder_output_proposals,
|
||||
gen_sineembed_for_position,
|
||||
get_sine_pos_embed,
|
||||
)
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
|
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import Dict, List, Tuple

import cv2
import numpy as np
@@ -7,11 +7,11 @@ import torch
from PIL import Image
from torchvision.ops import box_convert

import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util.misc import clean_state_dict
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import get_phrases_from_posmap
import invokeai.backend.image_util.grounding_segment_anything.groundingdino.datasets.transforms as T
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models import build_model
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import clean_state_dict
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.slconfig import SLConfig
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.utils import get_phrases_from_posmap

# ----------------------------------------------------------------------------------------------------------------------
# OLD API
@@ -25,12 +25,11 @@ def preprocess_caption(caption: str) -> str:
return result + "."


def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
def load_model(model_config_path: str, model_state_dict: Dict[str, torch.Tensor], device: str = "cuda"):
args = SLConfig.fromfile(model_config_path)
args.device = device
model = build_model(args)
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
model.load_state_dict(clean_state_dict(model_state_dict["model"]), strict=False)
model.eval()
return model

@@ -98,9 +97,9 @@ def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor

class Model:

def __init__(self, model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
def __init__(self, model_config_path: str, model_state_dict: Dict[str, torch.Tensor], device: str = "cuda"):
self.model = load_model(
model_config_path=model_config_path, model_checkpoint_path=model_checkpoint_path, device=device
model_config_path=model_config_path, model_state_dict=model_state_dict, device=device
).to(device)
self.device = device
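With this change the GroundingDINO wrapper no longer reads a checkpoint from disk; callers hand it an already-loaded state dict. A minimal sketch of constructing the new Model directly, assuming a locally downloaded groundingdino_swint_ogc.pth and the SwinT config bundled under the new package path:

    import torch

    from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.inference import Model

    # The official GroundingDINO checkpoint stores its weights under the "model" key,
    # which load_model() indexes before calling clean_state_dict().
    state_dict = torch.load("groundingdino_swint_ogc.pth", map_location="cpu")

    dino = Model(
        model_config_path="invokeai/backend/image_util/grounding_segment_anything/groundingdino/config/GroundingDINO_SwinT_OGC.py",
        model_state_dict=state_dict,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )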
@ -1,91 +0,0 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
|
||||
class _ColorfulFormatter(logging.Formatter):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._root_name = kwargs.pop("root_name") + "."
|
||||
self._abbrev_name = kwargs.pop("abbrev_name", "")
|
||||
if len(self._abbrev_name):
|
||||
self._abbrev_name = self._abbrev_name + "."
|
||||
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
|
||||
|
||||
def formatMessage(self, record):
|
||||
record.name = record.name.replace(self._root_name, self._abbrev_name)
|
||||
log = super(_ColorfulFormatter, self).formatMessage(record)
|
||||
if record.levelno == logging.WARNING:
|
||||
prefix = colored("WARNING", "red", attrs=["blink"])
|
||||
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
|
||||
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
|
||||
else:
|
||||
return log
|
||||
return prefix + " " + log
|
||||
|
||||
|
||||
# so that calling setup_logger multiple times won't add many handlers
|
||||
@functools.lru_cache()
|
||||
def setup_logger(output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None):
|
||||
"""
|
||||
Initialize the detectron2 logger and set its verbosity level to "INFO".
|
||||
|
||||
Args:
|
||||
output (str): a file name or a directory to save log. If None, will not save log file.
|
||||
If ends with ".txt" or ".log", assumed to be a file name.
|
||||
Otherwise, logs will be saved to `output/log.txt`.
|
||||
name (str): the root module name of this logger
|
||||
|
||||
Returns:
|
||||
logging.Logger: a logger
|
||||
"""
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.propagate = False
|
||||
|
||||
if abbrev_name is None:
|
||||
abbrev_name = name
|
||||
|
||||
plain_formatter = logging.Formatter("[%(asctime)s.%(msecs)03d]: %(message)s", datefmt="%m/%d %H:%M:%S")
|
||||
# stdout logging: master only
|
||||
if distributed_rank == 0:
|
||||
ch = logging.StreamHandler(stream=sys.stdout)
|
||||
ch.setLevel(logging.DEBUG)
|
||||
if color:
|
||||
formatter = _ColorfulFormatter(
|
||||
colored("[%(asctime)s.%(msecs)03d]: ", "green") + "%(message)s",
|
||||
datefmt="%m/%d %H:%M:%S",
|
||||
root_name=name,
|
||||
abbrev_name=str(abbrev_name),
|
||||
)
|
||||
else:
|
||||
formatter = plain_formatter
|
||||
ch.setFormatter(formatter)
|
||||
logger.addHandler(ch)
|
||||
|
||||
# file logging: all workers
|
||||
if output is not None:
|
||||
if output.endswith(".txt") or output.endswith(".log"):
|
||||
filename = output
|
||||
else:
|
||||
filename = os.path.join(output, "log.txt")
|
||||
if distributed_rank > 0:
|
||||
filename = filename + f".rank{distributed_rank}"
|
||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||
|
||||
fh = logging.StreamHandler(_cached_log_stream(filename))
|
||||
fh.setLevel(logging.DEBUG)
|
||||
fh.setFormatter(plain_formatter)
|
||||
logger.addHandler(fh)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
# cache the opened file object, so that different calls to `setup_logger`
|
||||
# with the same file name can safely write to the same file.
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _cached_log_stream(filename):
|
||||
return open(filename, "a")
|
@ -161,7 +161,7 @@ def all_gather_cpu(data):
|
||||
dist.all_gather(tensor_list, tensor, group=cpu_group)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list):
|
||||
for size, tensor in zip(size_list, tensor_list, strict=False):
|
||||
tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
|
||||
buffer = io.BytesIO(tensor.cpu().numpy())
|
||||
obj = torch.load(buffer)
|
||||
@ -210,7 +210,7 @@ def all_gather(data):
|
||||
dist.all_gather(tensor_list, tensor)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list):
|
||||
for size, tensor in zip(size_list, tensor_list, strict=False):
|
||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||
data_list.append(pickle.loads(buffer))
|
||||
|
||||
@ -240,7 +240,7 @@ def reduce_dict(input_dict, average=True):
|
||||
dist.all_reduce(values)
|
||||
if average:
|
||||
values /= world_size
|
||||
reduced_dict = {k: v for k, v in zip(names, values)}
|
||||
reduced_dict = {k: v for k, v in zip(names, values, strict=False)}
|
||||
return reduced_dict
|
||||
|
||||
|
||||
@ -378,7 +378,7 @@ def get_sha():
|
||||
|
||||
def collate_fn(batch):
|
||||
# import ipdb; ipdb.set_trace()
|
||||
batch = list(zip(*batch))
|
||||
batch = list(zip(*batch, strict=False))
|
||||
batch[0] = nested_tensor_from_tensor_list(batch[0])
|
||||
return tuple(batch)
|
||||
|
||||
@ -480,7 +480,7 @@ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
||||
device = tensor_list[0].device
|
||||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
||||
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
|
||||
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
||||
for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False):
|
||||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
m[: img.shape[1], : img.shape[2]] = False
|
||||
else:
|
||||
@ -505,7 +505,7 @@ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTen
|
||||
padded_imgs = []
|
||||
padded_masks = []
|
||||
for img in tensor_list:
|
||||
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
|
||||
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape), strict=False)]
|
||||
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
|
||||
padded_imgs.append(padded_img)
|
||||
|
||||
|
@ -9,7 +9,7 @@ import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from groundingdino.util.slconfig import SLConfig
|
||||
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.slconfig import SLConfig
|
||||
|
||||
|
||||
def slprint(x, name="x"):
|
||||
|
@@ -1,44 +1,38 @@
from typing import Dict, List, Literal, Optional
import pathlib
from typing import Dict, List, Optional

import cv2
import numpy as np
import supervision as sv
import torch
import torchvision
from PIL import Image

from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.inference import Model
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.build_sam import sam_model_registry
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.predictor import SamPredictor

GROUNDING_SEGMENT_ANYTHING_MODELS = {
"groundingdino_swint_ogc": "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
"segment_anything_vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
}


class GroundingSegmentAnythingDetector:
def __init__(self) -> None:
self.grounding_dino_model: Optional[Model] = None
self.segment_anything_model: Optional[SamPredictor] = None
self.grounding_dino_config: str = "./groundingdino/config/GroundingDINO_SwinT_OGC.py"
self.sam_encoder: Literal["vit_h"] = "vit_h"
self.device: torch.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def __init__(self, grounding_dino_model: Model, segment_anything_model: SamPredictor) -> None:
self.grounding_dino_model: Optional[Model] = grounding_dino_model
self.segment_anything_model: Optional[SamPredictor] = segment_anything_model

def build_grounding_dino(self):
@staticmethod
def build_grounding_dino(grounding_dino_state_dict: Dict[str, torch.Tensor]):
grounding_dino_config = pathlib.Path(
"./invokeai/backend/image_util/grounding_segment_anything/groundingdino/config/GroundingDINO_SwinT_OGC.py"
)
return Model(
model_config_path=self.grounding_dino_config,
model_checkpoint_path="./checkpoints/groundingdino_swint_ogc.pth",
model_state_dict=grounding_dino_state_dict,
model_config_path=grounding_dino_config.as_posix(),
)

def build_segment_anything(self):
sam = sam_model_registry[self.sam_encoder](checkpoint="./checkpoints/sam_vit_h_4b8939.pth")
sam.to(device=self.device)
@staticmethod
def build_segment_anything(segment_anything_state_dict: Dict[str, torch.Tensor], device: torch.device):
sam = sam_model_registry["vit_h"](checkpoint=segment_anything_state_dict)
sam.to(device=device)
return SamPredictor(sam)

def build_grounding_sam(self):
self.grounding_dino_model = self.build_grounding_dino()
self.segment_anything_model = self.build_segment_anything()

def detect_objects(
self,
image: np.ndarray,
@@ -77,20 +71,18 @@ class GroundingSegmentAnythingDetector:

def predict(
self,
image: str,
image: Image.Image,
prompt: str,
box_threshold: float = 0.5,
text_threshold: float = 0.5,
nms_threshold: float = 0.8,
):
if not self.grounding_dino_model or not self.segment_anything_model:
self.build_grounding_sam()

image = cv2.imread(image)
open_cv_image = np.array(image)
open_cv_image = open_cv_image[:, :, ::-1].copy()
prompts = prompt.split(",")

detections = self.detect_objects(image, prompts, box_threshold, text_threshold, nms_threshold)
segments = self.segment_detections(image, detections, prompts)
detections = self.detect_objects(open_cv_image, prompts, box_threshold, text_threshold, nms_threshold)
segments = self.segment_detections(open_cv_image, detections, prompts)

if len(segments) > 0:
combined_mask = np.zeros_like(list(segments.values())[0])
@@ -98,15 +90,6 @@ class GroundingSegmentAnythingDetector:
combined_mask = np.logical_or(combined_mask, mask)
mask_preview = (combined_mask * 255).astype(np.uint8)
else:
mask_preview = np.zeros(image.shape, np.uint8)
mask_preview = np.zeros(open_cv_image.shape, np.uint8)

cv2.imwrite("mask.png", mask_preview)


if __name__ == "__main__":
gsa = GroundingSegmentAnythingDetector()
image = "./assets/image.webp"

while True:
prompt = input("Segment: ")
gsa.predict(image, prompt, 0.5, 0.5, 0.8)
return Image.fromarray(mask_preview)
@ -4,13 +4,22 @@
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .automatic_mask_generator import SamAutomaticMaskGenerator
|
||||
from .build_sam import build_sam, build_sam_vit_b, build_sam_vit_h, build_sam_vit_l, sam_model_registry
|
||||
from .build_sam_hq import (
|
||||
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.automatic_mask_generator import (
|
||||
SamAutomaticMaskGenerator,
|
||||
) # noqa
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.build_sam import ( # noqa
|
||||
build_sam,
|
||||
build_sam_vit_b,
|
||||
build_sam_vit_h,
|
||||
build_sam_vit_l,
|
||||
sam_model_registry,
|
||||
)
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.build_sam_hq import ( # noqa
|
||||
build_sam_hq,
|
||||
build_sam_hq_vit_b,
|
||||
build_sam_hq_vit_h,
|
||||
build_sam_hq_vit_l,
|
||||
sam_hq_model_registry,
|
||||
)
|
||||
from .predictor import SamPredictor
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.predictor import SamPredictor # noqa
|
||||
|
@ -10,9 +10,9 @@ import numpy as np
|
||||
import torch
|
||||
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
|
||||
|
||||
from .modeling import Sam
|
||||
from .predictor import SamPredictor
|
||||
from .utils.amg import (
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling import Sam
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.predictor import SamPredictor
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.utils.amg import (
|
||||
MaskData,
|
||||
area_from_rle,
|
||||
batch_iterator,
|
||||
|
@ -8,7 +8,13 @@ from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling import (
|
||||
ImageEncoderViT,
|
||||
MaskDecoder,
|
||||
PromptEncoder,
|
||||
Sam,
|
||||
TwoWayTransformer,
|
||||
)
|
||||
|
||||
|
||||
def build_sam_vit_h(checkpoint=None):
|
||||
@ -101,7 +107,5 @@ def _build_sam(
|
||||
)
|
||||
sam.eval()
|
||||
if checkpoint is not None:
|
||||
with open(checkpoint, "rb") as f:
|
||||
state_dict = torch.load(f)
|
||||
sam.load_state_dict(state_dict)
|
||||
sam.load_state_dict(checkpoint)
|
||||
return sam
|
||||
|
@ -8,7 +8,13 @@ from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
from .modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling import (
|
||||
ImageEncoderViT,
|
||||
MaskDecoderHQ,
|
||||
PromptEncoder,
|
||||
Sam,
|
||||
TwoWayTransformer,
|
||||
)
|
||||
|
||||
|
||||
def build_sam_hq_vit_h(checkpoint=None):
|
||||
|
@ -4,9 +4,17 @@
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .image_encoder import ImageEncoderViT
|
||||
from .mask_decoder import MaskDecoder
|
||||
from .mask_decoder_hq import MaskDecoderHQ
|
||||
from .prompt_encoder import PromptEncoder
|
||||
from .sam import Sam
|
||||
from .transformer import TwoWayTransformer
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.image_encoder import (
|
||||
ImageEncoderViT,
|
||||
)
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.mask_decoder import MaskDecoder
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.mask_decoder_hq import (
|
||||
MaskDecoderHQ,
|
||||
)
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.prompt_encoder import (
|
||||
PromptEncoder,
|
||||
)
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.sam import Sam
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.transformer import (
|
||||
TwoWayTransformer,
|
||||
)
|
||||
|
@ -10,7 +10,10 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .common import LayerNorm2d, MLPBlock
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import (
|
||||
LayerNorm2d,
|
||||
MLPBlock,
|
||||
)
|
||||
|
||||
|
||||
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
|
||||
|
@ -10,7 +10,7 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .common import LayerNorm2d
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import LayerNorm2d
|
||||
|
||||
|
||||
class MaskDecoder(nn.Module):
|
||||
|
@ -11,7 +11,7 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .common import LayerNorm2d
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import LayerNorm2d
|
||||
|
||||
|
||||
class MaskDecoderHQ(nn.Module):
|
||||
|
@ -10,7 +10,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .common import LayerNorm2d
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import LayerNorm2d
|
||||
|
||||
|
||||
class PromptEncoder(nn.Module):
|
||||
|
@ -10,9 +10,13 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .image_encoder import ImageEncoderViT
|
||||
from .mask_decoder import MaskDecoder
|
||||
from .prompt_encoder import PromptEncoder
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.image_encoder import (
|
||||
ImageEncoderViT,
|
||||
)
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.mask_decoder import MaskDecoder
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.prompt_encoder import (
|
||||
PromptEncoder,
|
||||
)
|
||||
|
||||
|
||||
class Sam(nn.Module):
|
||||
@ -98,7 +102,7 @@ class Sam(nn.Module):
|
||||
image_embeddings = self.image_encoder(input_images)
|
||||
|
||||
outputs = []
|
||||
for image_record, curr_embedding in zip(batched_input, image_embeddings):
|
||||
for image_record, curr_embedding in zip(batched_input, image_embeddings, strict=False):
|
||||
if "point_coords" in image_record:
|
||||
points = (image_record["point_coords"], image_record["point_labels"])
|
||||
else:
|
||||
|
@ -10,7 +10,7 @@ from typing import Tuple, Type
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .common import MLPBlock
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import MLPBlock
|
||||
|
||||
|
||||
class TwoWayTransformer(nn.Module):
|
||||
|
@ -9,8 +9,8 @@ from typing import Optional, Tuple
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .modeling import Sam
|
||||
from .utils.transforms import ResizeLongestSide
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling import Sam
|
||||
from invokeai.backend.image_util.grounding_segment_anything.segment_anything.utils.transforms import ResizeLongestSide
|
||||
|
||||
|
||||
class SamPredictor:
|
||||
|
@ -4,14 +4,14 @@
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import math
|
||||
from copy import deepcopy
|
||||
from itertools import product
|
||||
from typing import Any, Dict, Generator, ItemsView, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class MaskData:
|
||||
"""
|
||||
@ -153,9 +153,7 @@ def area_from_rle(rle: Dict[str, Any]) -> int:
|
||||
return sum(rle["counts"][1::2])
|
||||
|
||||
|
||||
def calculate_stability_score(
|
||||
masks: torch.Tensor, mask_threshold: float, threshold_offset: float
|
||||
) -> torch.Tensor:
|
||||
def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
|
||||
"""
|
||||
Computes the stability score for a batch of masks. The stability
|
||||
score is the IoU between the binary masks obtained by thresholding
|
||||
@ -163,16 +161,8 @@ def calculate_stability_score(
|
||||
"""
|
||||
# One mask is always contained inside the other.
|
||||
# Save memory by preventing unnecesary cast to torch.int64
|
||||
intersections = (
|
||||
(masks > (mask_threshold + threshold_offset))
|
||||
.sum(-1, dtype=torch.int16)
|
||||
.sum(-1, dtype=torch.int32)
|
||||
)
|
||||
unions = (
|
||||
(masks > (mask_threshold - threshold_offset))
|
||||
.sum(-1, dtype=torch.int16)
|
||||
.sum(-1, dtype=torch.int32)
|
||||
)
|
||||
intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
|
||||
unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
|
||||
return intersections / unions
|
||||
|
||||
|
||||
@ -186,9 +176,7 @@ def build_point_grid(n_per_side: int) -> np.ndarray:
|
||||
return points
|
||||
|
||||
|
||||
def build_all_layer_point_grids(
|
||||
n_per_side: int, n_layers: int, scale_per_layer: int
|
||||
) -> List[np.ndarray]:
|
||||
def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
|
||||
"""Generates point grids for all crop layers."""
|
||||
points_by_layer = []
|
||||
for i in range(n_layers + 1):
|
||||
@ -252,9 +240,7 @@ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
|
||||
return points + offset
|
||||
|
||||
|
||||
def uncrop_masks(
|
||||
masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
|
||||
) -> torch.Tensor:
|
||||
def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:
|
||||
x0, y0, x1, y1 = crop_box
|
||||
if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
|
||||
return masks
|
||||
@ -264,9 +250,7 @@ def uncrop_masks(
|
||||
return torch.nn.functional.pad(masks, pad, value=0)
|
||||
|
||||
|
||||
def remove_small_regions(
|
||||
mask: np.ndarray, area_thresh: float, mode: str
|
||||
) -> Tuple[np.ndarray, bool]:
|
||||
def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
|
||||
"""
|
||||
Removes small disconnected regions and holes in a mask. Returns the
|
||||
mask and an indicator of if the mask has been modified.
|
||||
|
@ -1,144 +0,0 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from ..modeling import Sam
|
||||
from .amg import calculate_stability_score
|
||||
|
||||
|
||||
class SamOnnxModel(nn.Module):
|
||||
"""
|
||||
This model should not be called directly, but is used in ONNX export.
|
||||
It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
|
||||
with some functions modified to enable model tracing. Also supports extra
|
||||
options controlling what information. See the ONNX export script for details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Sam,
|
||||
return_single_mask: bool,
|
||||
use_stability_score: bool = False,
|
||||
return_extra_metrics: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.mask_decoder = model.mask_decoder
|
||||
self.model = model
|
||||
self.img_size = model.image_encoder.img_size
|
||||
self.return_single_mask = return_single_mask
|
||||
self.use_stability_score = use_stability_score
|
||||
self.stability_score_offset = 1.0
|
||||
self.return_extra_metrics = return_extra_metrics
|
||||
|
||||
@staticmethod
|
||||
def resize_longest_image_size(
|
||||
input_image_size: torch.Tensor, longest_side: int
|
||||
) -> torch.Tensor:
|
||||
input_image_size = input_image_size.to(torch.float32)
|
||||
scale = longest_side / torch.max(input_image_size)
|
||||
transformed_size = scale * input_image_size
|
||||
transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
|
||||
return transformed_size
|
||||
|
||||
def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
|
||||
point_coords = point_coords + 0.5
|
||||
point_coords = point_coords / self.img_size
|
||||
point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
|
||||
point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
|
||||
|
||||
point_embedding = point_embedding * (point_labels != -1)
|
||||
point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
|
||||
point_labels == -1
|
||||
)
|
||||
|
||||
for i in range(self.model.prompt_encoder.num_point_embeddings):
|
||||
point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
|
||||
i
|
||||
].weight * (point_labels == i)
|
||||
|
||||
return point_embedding
|
||||
|
||||
def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
|
||||
mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
|
||||
mask_embedding = mask_embedding + (
|
||||
1 - has_mask_input
|
||||
) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
|
||||
return mask_embedding
|
||||
|
||||
def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
|
||||
masks = F.interpolate(
|
||||
masks,
|
||||
size=(self.img_size, self.img_size),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size)
|
||||
masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])]
|
||||
|
||||
orig_im_size = orig_im_size.to(torch.int64)
|
||||
h, w = orig_im_size[0], orig_im_size[1]
|
||||
masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
|
||||
return masks
|
||||
|
||||
def select_masks(
|
||||
self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Determine if we should return the multiclick mask or not from the number of points.
|
||||
# The reweighting is used to avoid control flow.
|
||||
score_reweight = torch.tensor(
|
||||
[[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
|
||||
).to(iou_preds.device)
|
||||
score = iou_preds + (num_points - 2.5) * score_reweight
|
||||
best_idx = torch.argmax(score, dim=1)
|
||||
masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
|
||||
iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
|
||||
|
||||
return masks, iou_preds
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
image_embeddings: torch.Tensor,
|
||||
point_coords: torch.Tensor,
|
||||
point_labels: torch.Tensor,
|
||||
mask_input: torch.Tensor,
|
||||
has_mask_input: torch.Tensor,
|
||||
orig_im_size: torch.Tensor,
|
||||
):
|
||||
sparse_embedding = self._embed_points(point_coords, point_labels)
|
||||
dense_embedding = self._embed_masks(mask_input, has_mask_input)
|
||||
|
||||
masks, scores = self.model.mask_decoder.predict_masks(
|
||||
image_embeddings=image_embeddings,
|
||||
image_pe=self.model.prompt_encoder.get_dense_pe(),
|
||||
sparse_prompt_embeddings=sparse_embedding,
|
||||
dense_prompt_embeddings=dense_embedding,
|
||||
)
|
||||
|
||||
if self.use_stability_score:
|
||||
scores = calculate_stability_score(
|
||||
masks, self.model.mask_threshold, self.stability_score_offset
|
||||
)
|
||||
|
||||
if self.return_single_mask:
|
||||
masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
|
||||
|
||||
upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
|
||||
|
||||
if self.return_extra_metrics:
|
||||
stability_scores = calculate_stability_score(
|
||||
upscaled_masks, self.model.mask_threshold, self.stability_score_offset
|
||||
)
|
||||
areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
|
||||
return upscaled_masks, scores, stability_scores, areas, masks
|
||||
|
||||
return upscaled_masks, scores, masks
|
@ -4,14 +4,14 @@
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from torchvision.transforms.functional import resize, to_pil_image # type: ignore
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
class ResizeLongestSide:
|
||||
"""
|
||||
@ -36,9 +36,7 @@ class ResizeLongestSide:
|
||||
original image size in (H, W) format.
|
||||
"""
|
||||
old_h, old_w = original_size
|
||||
new_h, new_w = self.get_preprocess_shape(
|
||||
original_size[0], original_size[1], self.target_length
|
||||
)
|
||||
new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
|
||||
coords = deepcopy(coords).astype(float)
|
||||
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
||||
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
||||
@ -60,29 +58,21 @@ class ResizeLongestSide:
|
||||
"""
|
||||
# Expects an image in BCHW format. May not exactly match apply_image.
|
||||
target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
|
||||
return F.interpolate(
|
||||
image, target_size, mode="bilinear", align_corners=False, antialias=True
|
||||
)
|
||||
return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
|
||||
|
||||
def apply_coords_torch(
|
||||
self, coords: torch.Tensor, original_size: Tuple[int, ...]
|
||||
) -> torch.Tensor:
|
||||
def apply_coords_torch(self, coords: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
|
||||
"""
|
||||
Expects a torch tensor with length 2 in the last dimension. Requires the
|
||||
original image size in (H, W) format.
|
||||
"""
|
||||
old_h, old_w = original_size
|
||||
new_h, new_w = self.get_preprocess_shape(
|
||||
original_size[0], original_size[1], self.target_length
|
||||
)
|
||||
new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
|
||||
coords = deepcopy(coords).to(torch.float)
|
||||
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
||||
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
||||
return coords
|
||||
|
||||
def apply_boxes_torch(
|
||||
self, boxes: torch.Tensor, original_size: Tuple[int, ...]
|
||||
) -> torch.Tensor:
|
||||
def apply_boxes_torch(self, boxes: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
|
||||
"""
|
||||
Expects a torch tensor with shape Bx4. Requires the original image
|
||||
size in (H, W) format.
|
||||
|
@@ -80,6 +80,7 @@ dependencies = [
"picklescan",
"pillow",
"prompt-toolkit",
"pycocotools",
"pympler~=1.0.1",
"pypatchmatch",
'pyperclip',
@@ -90,6 +91,7 @@ dependencies = [
"scikit-image~=0.21.0",
"semver~=3.0.1",
"send2trash",
"supervision",
"test-tube~=0.7.5",
"windows-curses; sys_platform=='win32'",
]