diff --git a/invokeai/app/invocations/segment_anything.py b/invokeai/app/invocations/segment_anything.py
new file mode 100644
index 0000000000..3516f22687
--- /dev/null
+++ b/invokeai/app/invocations/segment_anything.py
@@ -0,0 +1,76 @@
+from typing import Dict, cast
+
+import torch
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import ImageField, InputField
+from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.image_util.grounding_segment_anything.gsa import GroundingSegmentAnythingDetector
+from invokeai.backend.util.devices import TorchDevice
+
+GROUNDING_SEGMENT_ANYTHING_MODELS = {
+    "groundingdino_swint_ogc": "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
+    "segment_anything_vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
+}
+
+
+@invocation(
+    "segment_anything",
+    title="Segment Anything",
+    tags=["grounding_dino", "segment", "anything"],
+    category="image",
+    version="1.0.0",
+)
+class SegmentAnythingInvocation(BaseInvocation):
+    """Automatically generate masks from an image using GroundingDINO & Segment Anything"""
+
+    image: ImageField = InputField(description="The image to process")
+    prompt: str = InputField(default="", description="Keywords to segment", title="Prompt")
+    box_threshold: float = InputField(
+        default=0.5, ge=0, le=1, description="Confidence threshold for box detection", title="Box Threshold"
+    )
+    text_threshold: float = InputField(
+        default=0.5, ge=0, le=1, description="Confidence threshold for text detection", title="Text Threshold"
+    )
+    nms_threshold: float = InputField(
+        default=0.8, ge=0, le=1, description="IoU threshold for non-maximum suppression of detected boxes", title="NMS Threshold"
+    )
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        input_image = context.images.get_pil(self.image.image_name)
+
+        grounding_dino_model = context.models.load_remote_model(
+            GROUNDING_SEGMENT_ANYTHING_MODELS["groundingdino_swint_ogc"]
+        )
+        segment_anything_model = context.models.load_remote_model(
+            GROUNDING_SEGMENT_ANYTHING_MODELS["segment_anything_vit_h"]
+        )
+
+        with (
+            grounding_dino_model.model_on_device() as (_, grounding_dino_state_dict),
+            segment_anything_model.model_on_device() as (_, segment_anything_state_dict),
+        ):
+            if not grounding_dino_state_dict or not segment_anything_state_dict:
+                raise RuntimeError("Unable to load segmentation models")
+
+            grounding_dino = GroundingSegmentAnythingDetector.build_grounding_dino(
+                cast(Dict[str, torch.Tensor], grounding_dino_state_dict)
+            )
+            segment_anything = GroundingSegmentAnythingDetector.build_segment_anything(
+                cast(Dict[str, torch.Tensor], segment_anything_state_dict), TorchDevice.choose_torch_device()
+            )
+            detector = GroundingSegmentAnythingDetector(grounding_dino, segment_anything)
+
+            mask = detector.predict(
+                input_image, self.prompt, self.box_threshold, self.text_threshold, self.nms_threshold
+            )
+            image_dto = context.images.save(mask)
+
+        # Build the ImageOutput and its ImageField from the saved mask
+        processed_image_field = ImageField(image_name=image_dto.image_name)
+        return ImageOutput(
+            image=processed_image_field,
+            width=input_image.width,
+            height=input_image.height,
+        )
diff --git a/invokeai/backend/image_util/grounding_segment_anything/groundingdino/datasets/transforms.py b/invokeai/backend/image_util/grounding_segment_anything/groundingdino/datasets/transforms.py
index 91cf9269e4..1e7e8a51f2 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/groundingdino/datasets/transforms.py +++ b/invokeai/backend/image_util/grounding_segment_anything/groundingdino/datasets/transforms.py @@ -10,8 +10,8 @@ import torch import torchvision.transforms as T import torchvision.transforms.functional as F -from groundingdino.util.box_ops import box_xyxy_to_cxcywh -from groundingdino.util.misc import interpolate +from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.box_ops import box_xyxy_to_cxcywh +from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import interpolate def crop(image, target, region): @@ -58,9 +58,7 @@ def crop(image, target, region): if os.environ.get("IPDB_SHILONG_DEBUG", None) == "INFO": # for debug and visualization only. if "strings_positive" in target: - target["strings_positive"] = [ - _i for _i, _j in zip(target["strings_positive"], keep) if _j - ] + target["strings_positive"] = [_i for _i, _j in zip(target["strings_positive"], keep, strict=False) if _j] return cropped_image, target @@ -73,9 +71,7 @@ def hflip(image, target): target = target.copy() if "boxes" in target: boxes = target["boxes"] - boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor( - [w, 0, w, 0] - ) + boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) target["boxes"] = boxes if "masks" in target: @@ -119,15 +115,13 @@ def resize(image, target, size, max_size=None): if target is None: return rescaled_image, None - ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size, strict=False)) ratio_width, ratio_height = ratios target = target.copy() if "boxes" in target: boxes = target["boxes"] - scaled_boxes = boxes * torch.as_tensor( - [ratio_width, ratio_height, ratio_width, ratio_height] - ) + scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) target["boxes"] = scaled_boxes if "area" in target: @@ -139,9 +133,7 @@ def resize(image, target, size, max_size=None): target["size"] = torch.tensor([h, w]) if "masks" in target: - target["masks"] = ( - interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5 - ) + target["masks"] = interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5 return rescaled_image, target @@ -192,11 +184,7 @@ class RandomSizeCrop(object): h = random.randint(self.min_size, min(img.height, self.max_size)) region = T.RandomCrop.get_params(img, [h, w]) result_img, result_target = crop(img, target, region) - if ( - not self.respect_boxes - or len(result_target["boxes"]) == init_boxes - or i == max_patience - 1 - ): + if not self.respect_boxes or len(result_target["boxes"]) == init_boxes or i == max_patience - 1: return result_img, result_target return result_img, result_target diff --git a/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/__init__.py b/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/__init__.py index 2af819d61d..cbcddaac07 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/__init__.py +++ b/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/__init__.py @@ -12,4 +12,6 @@ # Copyright (c) Facebook, Inc. 
and its affiliates. All Rights Reserved. # ------------------------------------------------------------------------ -from .groundingdino import build_groundingdino +from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models.GroundingDINO.groundingdino import ( + build_groundingdino, +) diff --git a/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/backbone/backbone.py b/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/backbone/backbone.py index c8340c723f..ee25f20e35 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/backbone/backbone.py +++ b/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/backbone/backbone.py @@ -24,10 +24,13 @@ import torchvision from torch import nn from torchvision.models._utils import IntermediateLayerGetter -from groundingdino.util.misc import NestedTensor, clean_state_dict, is_main_process - -from .position_encoding import build_position_encoding -from .swin_transformer import build_swin_transformer +from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models.GroundingDINO.backbone.position_encoding import ( + build_position_encoding, +) +from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models.GroundingDINO.backbone.swin_transformer import ( + build_swin_transformer, +) +from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import NestedTensor, is_main_process class FrozenBatchNorm2d(torch.nn.Module): @@ -80,19 +83,12 @@ class BackboneBase(nn.Module): ): super().__init__() for name, parameter in backbone.named_parameters(): - if ( - not train_backbone - or "layer2" not in name - and "layer3" not in name - and "layer4" not in name - ): + if not train_backbone or "layer2" not in name and "layer3" not in name and "layer4" not in name: parameter.requires_grad_(False) return_layers = {} for idx, layer_index in enumerate(return_interm_indices): - return_layers.update( - {"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)} - ) + return_layers.update({"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)}) # if len: # if use_stage1_feature: @@ -214,8 +210,8 @@ def build_backbone(args): model = Joiner(backbone, position_embedding) model.num_channels = bb_num_channels - assert isinstance( - bb_num_channels, List - ), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels)) + assert isinstance(bb_num_channels, List), "bb_num_channels is expected to be a List but {}".format( + type(bb_num_channels) + ) # import ipdb; ipdb.set_trace() return model diff --git a/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/backbone/position_encoding.py b/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/backbone/position_encoding.py index eac7e896bb..6828d7c3cb 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/backbone/position_encoding.py +++ b/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/backbone/position_encoding.py @@ -24,7 +24,7 @@ import math import torch from torch import nn -from groundingdino.util.misc import NestedTensor +from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import NestedTensor class 
PositionEmbeddingSine(nn.Module): @@ -65,12 +65,8 @@ class PositionEmbeddingSine(nn.Module): pos_x = x_embed[:, :, :, None] / dim_t pos_y = y_embed[:, :, :, None] / dim_t - pos_x = torch.stack( - (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 - ).flatten(3) - pos_y = torch.stack( - (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 - ).flatten(3) + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) return pos @@ -81,9 +77,7 @@ class PositionEmbeddingSineHW(nn.Module): used by the Attention is all you need paper, generalized to work on images. """ - def __init__( - self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None - ): + def __init__(self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None): super().__init__() self.num_pos_feats = num_pos_feats self.temperatureH = temperatureH @@ -111,19 +105,15 @@ class PositionEmbeddingSineHW(nn.Module): x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) - dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats) + dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode="floor")) / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_tx dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) - dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats) + dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode="floor")) / self.num_pos_feats) pos_y = y_embed[:, :, :, None] / dim_ty - pos_x = torch.stack( - (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 - ).flatten(3) - pos_y = torch.stack( - (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 - ).flatten(3) + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) # import ipdb; ipdb.set_trace() diff --git a/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/backbone/swin_transformer.py b/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/backbone/swin_transformer.py index fa8837e400..141c6568b9 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/backbone/swin_transformer.py +++ b/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/backbone/swin_transformer.py @@ -18,15 +18,13 @@ import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.models.layers import DropPath, to_2tuple, trunc_normal_ -from groundingdino.util.misc import NestedTensor +from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import NestedTensor class Mlp(nn.Module): """Multilayer perceptron.""" - def __init__( - self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0 - ): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): super().__init__() out_features = out_features or 
in_features hidden_features = hidden_features or in_features @@ -138,24 +136,16 @@ class WindowAttention(nn.Module): mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ B_, N, C = x.shape - qkv = ( - self.qkv(x) - .reshape(B_, N, 3, self.num_heads, C // self.num_heads) - .permute(2, 0, 3, 1, 4) - ) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) q = q * self.scale attn = q @ k.transpose(-2, -1) - relative_position_bias = self.relative_position_bias_table[ - self.relative_position_index.view(-1) - ].view( + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 ) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute( - 2, 0, 1 - ).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: @@ -228,9 +218,7 @@ class SwinTransformerBlock(nn.Module): self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp( - in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop - ) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) self.H = None self.W = None @@ -266,12 +254,8 @@ class SwinTransformerBlock(nn.Module): attn_mask = None # partition windows - x_windows = window_partition( - shifted_x, self.window_size - ) # nW*B, window_size, window_size, C - x_windows = x_windows.view( - -1, self.window_size * self.window_size, C - ) # nW*B, window_size*window_size, C + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C # W-MSA/SW-MSA attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C @@ -433,14 +417,10 @@ class BasicLayer(nn.Module): img_mask[:, h, w, :] = cnt cnt += 1 - mask_windows = window_partition( - img_mask, self.window_size - ) # nW, window_size, window_size, 1 + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( - attn_mask == 0, float(0.0) - ) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) for blk in self.blocks: blk.H, blk.W = H, W @@ -589,9 +569,7 @@ class SwinTransformer(nn.Module): self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth - dpr = [ - x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) - ] # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() @@ -682,9 +660,7 @@ class SwinTransformer(nn.Module): Wh, Ww = x.size(2), x.size(3) if self.ape: # interpolate the position embedding to the corresponding size - absolute_pos_embed = F.interpolate( - self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic" - ) + absolute_pos_embed = 
F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic") x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C else: x = x.flatten(2).transpose(1, 2) @@ -718,9 +694,7 @@ class SwinTransformer(nn.Module): Wh, Ww = x.size(2), x.size(3) if self.ape: # interpolate the position embedding to the corresponding size - absolute_pos_embed = F.interpolate( - self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic" - ) + absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic") x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C else: x = x.flatten(2).transpose(1, 2) @@ -769,21 +743,11 @@ def build_swin_transformer(modelname, pretrain_img_size, **kw): ] model_para_dict = { - "swin_T_224_1k": dict( - embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7 - ), - "swin_B_224_22k": dict( - embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7 - ), - "swin_B_384_22k": dict( - embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12 - ), - "swin_L_224_22k": dict( - embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7 - ), - "swin_L_384_22k": dict( - embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12 - ), + "swin_T_224_1k": dict(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7), + "swin_B_224_22k": dict(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7), + "swin_B_384_22k": dict(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12), + "swin_L_224_22k": dict(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7), + "swin_L_384_22k": dict(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12), } kw_cgf = model_para_dict[modelname] kw_cgf.update(kw) diff --git a/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/groundingdino.py b/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/groundingdino.py index b76df92e13..00ac21bea2 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/groundingdino.py +++ b/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/groundingdino.py @@ -21,8 +21,12 @@ import torch import torch.nn.functional as F from torch import nn -from groundingdino.util import get_tokenlizer -from groundingdino.util.misc import NestedTensor, inverse_sigmoid, nested_tensor_from_tensor_list +from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util import get_tokenlizer +from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import ( + NestedTensor, + inverse_sigmoid, + nested_tensor_from_tensor_list, +) from ..registry import MODULE_BUILD_FUNCS from .backbone import build_backbone diff --git a/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/transformer.py b/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/transformer.py index 1a25469604..8bd0a3cf67 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/transformer.py +++ b/invokeai/backend/image_util/grounding_segment_anything/groundingdino/models/GroundingDINO/transformer.py @@ -22,14 +22,19 @@ import torch import torch.utils.checkpoint as checkpoint from torch import Tensor, nn -from 
groundingdino.util.misc import inverse_sigmoid +from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import inverse_sigmoid from .fuse_modules import BiAttentionBlock from .ms_deform_attn import MultiScaleDeformableAttention as MSDeformAttn from .transformer_vanilla import TransformerEncoderLayer -from .utils import (MLP, _get_activation_fn, _get_clones, - gen_encoder_output_proposals, gen_sineembed_for_position, - get_sine_pos_embed) +from .utils import ( + MLP, + _get_activation_fn, + _get_clones, + gen_encoder_output_proposals, + gen_sineembed_for_position, + get_sine_pos_embed, +) class Transformer(nn.Module): diff --git a/invokeai/backend/image_util/grounding_segment_anything/groundingdino/util/inference.py b/invokeai/backend/image_util/grounding_segment_anything/groundingdino/util/inference.py index 9bfb2cbca9..1ea1270cda 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/groundingdino/util/inference.py +++ b/invokeai/backend/image_util/grounding_segment_anything/groundingdino/util/inference.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import Dict, List, Tuple import cv2 import numpy as np @@ -7,11 +7,11 @@ import torch from PIL import Image from torchvision.ops import box_convert -import groundingdino.datasets.transforms as T -from groundingdino.models import build_model -from groundingdino.util.misc import clean_state_dict -from groundingdino.util.slconfig import SLConfig -from groundingdino.util.utils import get_phrases_from_posmap +import invokeai.backend.image_util.grounding_segment_anything.groundingdino.datasets.transforms as T +from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models import build_model +from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import clean_state_dict +from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.slconfig import SLConfig +from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.utils import get_phrases_from_posmap # ---------------------------------------------------------------------------------------------------------------------- # OLD API @@ -25,12 +25,11 @@ def preprocess_caption(caption: str) -> str: return result + "." 
-def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"): +def load_model(model_config_path: str, model_state_dict: Dict[str, torch.Tensor], device: str = "cuda"): args = SLConfig.fromfile(model_config_path) args.device = device model = build_model(args) - checkpoint = torch.load(model_checkpoint_path, map_location="cpu") - model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False) + model.load_state_dict(clean_state_dict(model_state_dict["model"]), strict=False) model.eval() return model @@ -98,9 +97,9 @@ def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor class Model: - def __init__(self, model_config_path: str, model_checkpoint_path: str, device: str = "cuda"): + def __init__(self, model_config_path: str, model_state_dict: Dict[str, torch.Tensor], device: str = "cuda"): self.model = load_model( - model_config_path=model_config_path, model_checkpoint_path=model_checkpoint_path, device=device + model_config_path=model_config_path, model_state_dict=model_state_dict, device=device ).to(device) self.device = device diff --git a/invokeai/backend/image_util/grounding_segment_anything/groundingdino/util/logger.py b/invokeai/backend/image_util/grounding_segment_anything/groundingdino/util/logger.py deleted file mode 100644 index 679e0f5926..0000000000 --- a/invokeai/backend/image_util/grounding_segment_anything/groundingdino/util/logger.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -import functools -import logging -import os -import sys - -from termcolor import colored - - -class _ColorfulFormatter(logging.Formatter): - def __init__(self, *args, **kwargs): - self._root_name = kwargs.pop("root_name") + "." - self._abbrev_name = kwargs.pop("abbrev_name", "") - if len(self._abbrev_name): - self._abbrev_name = self._abbrev_name + "." - super(_ColorfulFormatter, self).__init__(*args, **kwargs) - - def formatMessage(self, record): - record.name = record.name.replace(self._root_name, self._abbrev_name) - log = super(_ColorfulFormatter, self).formatMessage(record) - if record.levelno == logging.WARNING: - prefix = colored("WARNING", "red", attrs=["blink"]) - elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: - prefix = colored("ERROR", "red", attrs=["blink", "underline"]) - else: - return log - return prefix + " " + log - - -# so that calling setup_logger multiple times won't add many handlers -@functools.lru_cache() -def setup_logger(output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None): - """ - Initialize the detectron2 logger and set its verbosity level to "INFO". - - Args: - output (str): a file name or a directory to save log. If None, will not save log file. - If ends with ".txt" or ".log", assumed to be a file name. - Otherwise, logs will be saved to `output/log.txt`. 
- name (str): the root module name of this logger - - Returns: - logging.Logger: a logger - """ - logger = logging.getLogger(name) - logger.setLevel(logging.DEBUG) - logger.propagate = False - - if abbrev_name is None: - abbrev_name = name - - plain_formatter = logging.Formatter("[%(asctime)s.%(msecs)03d]: %(message)s", datefmt="%m/%d %H:%M:%S") - # stdout logging: master only - if distributed_rank == 0: - ch = logging.StreamHandler(stream=sys.stdout) - ch.setLevel(logging.DEBUG) - if color: - formatter = _ColorfulFormatter( - colored("[%(asctime)s.%(msecs)03d]: ", "green") + "%(message)s", - datefmt="%m/%d %H:%M:%S", - root_name=name, - abbrev_name=str(abbrev_name), - ) - else: - formatter = plain_formatter - ch.setFormatter(formatter) - logger.addHandler(ch) - - # file logging: all workers - if output is not None: - if output.endswith(".txt") or output.endswith(".log"): - filename = output - else: - filename = os.path.join(output, "log.txt") - if distributed_rank > 0: - filename = filename + f".rank{distributed_rank}" - os.makedirs(os.path.dirname(filename), exist_ok=True) - - fh = logging.StreamHandler(_cached_log_stream(filename)) - fh.setLevel(logging.DEBUG) - fh.setFormatter(plain_formatter) - logger.addHandler(fh) - - return logger - - -# cache the opened file object, so that different calls to `setup_logger` -# with the same file name can safely write to the same file. -@functools.lru_cache(maxsize=None) -def _cached_log_stream(filename): - return open(filename, "a") diff --git a/invokeai/backend/image_util/grounding_segment_anything/groundingdino/util/misc.py b/invokeai/backend/image_util/grounding_segment_anything/groundingdino/util/misc.py index 89e58d144b..f921d00274 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/groundingdino/util/misc.py +++ b/invokeai/backend/image_util/grounding_segment_anything/groundingdino/util/misc.py @@ -161,7 +161,7 @@ def all_gather_cpu(data): dist.all_gather(tensor_list, tensor, group=cpu_group) data_list = [] - for size, tensor in zip(size_list, tensor_list): + for size, tensor in zip(size_list, tensor_list, strict=False): tensor = torch.split(tensor, [size, max_size - size], dim=0)[0] buffer = io.BytesIO(tensor.cpu().numpy()) obj = torch.load(buffer) @@ -210,7 +210,7 @@ def all_gather(data): dist.all_gather(tensor_list, tensor) data_list = [] - for size, tensor in zip(size_list, tensor_list): + for size, tensor in zip(size_list, tensor_list, strict=False): buffer = tensor.cpu().numpy().tobytes()[:size] data_list.append(pickle.loads(buffer)) @@ -240,7 +240,7 @@ def reduce_dict(input_dict, average=True): dist.all_reduce(values) if average: values /= world_size - reduced_dict = {k: v for k, v in zip(names, values)} + reduced_dict = {k: v for k, v in zip(names, values, strict=False)} return reduced_dict @@ -378,7 +378,7 @@ def get_sha(): def collate_fn(batch): # import ipdb; ipdb.set_trace() - batch = list(zip(*batch)) + batch = list(zip(*batch, strict=False)) batch[0] = nested_tensor_from_tensor_list(batch[0]) return tuple(batch) @@ -480,7 +480,7 @@ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): device = tensor_list[0].device tensor = torch.zeros(batch_shape, dtype=dtype, device=device) mask = torch.ones((b, h, w), dtype=torch.bool, device=device) - for img, pad_img, m in zip(tensor_list, tensor, mask): + for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False): pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) m[: img.shape[1], : img.shape[2]] = False else: @@ -505,7 +505,7 @@ def 
_onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTen padded_imgs = [] padded_masks = [] for img in tensor_list: - padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape), strict=False)] padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) padded_imgs.append(padded_img) diff --git a/invokeai/backend/image_util/grounding_segment_anything/groundingdino/util/utils.py b/invokeai/backend/image_util/grounding_segment_anything/groundingdino/util/utils.py index c69ed7d4f1..807c12124c 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/groundingdino/util/utils.py +++ b/invokeai/backend/image_util/grounding_segment_anything/groundingdino/util/utils.py @@ -9,7 +9,7 @@ import numpy as np import torch from transformers import AutoTokenizer -from groundingdino.util.slconfig import SLConfig +from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.slconfig import SLConfig def slprint(x, name="x"): diff --git a/invokeai/backend/image_util/grounding_segment_anything/gsa.py b/invokeai/backend/image_util/grounding_segment_anything/gsa.py index c076e9e86f..3102091bef 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/gsa.py +++ b/invokeai/backend/image_util/grounding_segment_anything/gsa.py @@ -1,44 +1,38 @@ -from typing import Dict, List, Literal, Optional +import pathlib +from typing import Dict, List, Optional -import cv2 import numpy as np import supervision as sv import torch import torchvision +from PIL import Image from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.inference import Model from invokeai.backend.image_util.grounding_segment_anything.segment_anything.build_sam import sam_model_registry from invokeai.backend.image_util.grounding_segment_anything.segment_anything.predictor import SamPredictor -GROUNDING_SEGMENT_ANYTHING_MODELS = { - "groundingdino_swint_ogc": "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth", - "segment_anything_vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", -} - class GroundingSegmentAnythingDetector: - def __init__(self) -> None: - self.grounding_dino_model: Optional[Model] = None - self.segment_anything_model: Optional[SamPredictor] = None - self.grounding_dino_config: str = "./groundingdino/config/GroundingDINO_SwinT_OGC.py" - self.sam_encoder: Literal["vit_h"] = "vit_h" - self.device: torch.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + def __init__(self, grounding_dino_model: Model, segment_anything_model: SamPredictor) -> None: + self.grounding_dino_model: Optional[Model] = grounding_dino_model + self.segment_anything_model: Optional[SamPredictor] = segment_anything_model - def build_grounding_dino(self): + @staticmethod + def build_grounding_dino(grounding_dino_state_dict: Dict[str, torch.Tensor]): + grounding_dino_config = pathlib.Path( + "./invokeai/backend/image_util/grounding_segment_anything/groundingdino/config/GroundingDINO_SwinT_OGC.py" + ) return Model( - model_config_path=self.grounding_dino_config, - model_checkpoint_path="./checkpoints/groundingdino_swint_ogc.pth", + model_state_dict=grounding_dino_state_dict, + model_config_path=grounding_dino_config.as_posix(), ) - def build_segment_anything(self): - sam = sam_model_registry[self.sam_encoder](checkpoint="./checkpoints/sam_vit_h_4b8939.pth") - 
sam.to(device=self.device) + @staticmethod + def build_segment_anything(segment_anything_state_dict: Dict[str, torch.Tensor], device: torch.device): + sam = sam_model_registry["vit_h"](checkpoint=segment_anything_state_dict) + sam.to(device=device) return SamPredictor(sam) - def build_grounding_sam(self): - self.grounding_dino_model = self.build_grounding_dino() - self.segment_anything_model = self.build_segment_anything() - def detect_objects( self, image: np.ndarray, @@ -77,20 +71,18 @@ class GroundingSegmentAnythingDetector: def predict( self, - image: str, + image: Image.Image, prompt: str, box_threshold: float = 0.5, text_threshold: float = 0.5, nms_threshold: float = 0.8, ): - if not self.grounding_dino_model or not self.segment_anything_model: - self.build_grounding_sam() - - image = cv2.imread(image) + open_cv_image = np.array(image) + open_cv_image = open_cv_image[:, :, ::-1].copy() prompts = prompt.split(",") - detections = self.detect_objects(image, prompts, box_threshold, text_threshold, nms_threshold) - segments = self.segment_detections(image, detections, prompts) + detections = self.detect_objects(open_cv_image, prompts, box_threshold, text_threshold, nms_threshold) + segments = self.segment_detections(open_cv_image, detections, prompts) if len(segments) > 0: combined_mask = np.zeros_like(list(segments.values())[0]) @@ -98,15 +90,6 @@ class GroundingSegmentAnythingDetector: combined_mask = np.logical_or(combined_mask, mask) mask_preview = (combined_mask * 255).astype(np.uint8) else: - mask_preview = np.zeros(image.shape, np.uint8) + mask_preview = np.zeros(open_cv_image.shape, np.uint8) - cv2.imwrite("mask.png", mask_preview) - - -if __name__ == "__main__": - gsa = GroundingSegmentAnythingDetector() - image = "./assets/image.webp" - - while True: - prompt = input("Segment: ") - gsa.predict(image, prompt, 0.5, 0.5, 0.8) + return Image.fromarray(mask_preview) diff --git a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/__init__.py b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/__init__.py index 9915b49c41..bf29bfd709 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/__init__.py +++ b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/__init__.py @@ -4,13 +4,22 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
-from .automatic_mask_generator import SamAutomaticMaskGenerator -from .build_sam import build_sam, build_sam_vit_b, build_sam_vit_h, build_sam_vit_l, sam_model_registry -from .build_sam_hq import ( + +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.automatic_mask_generator import ( + SamAutomaticMaskGenerator, +) # noqa +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.build_sam import ( # noqa + build_sam, + build_sam_vit_b, + build_sam_vit_h, + build_sam_vit_l, + sam_model_registry, +) +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.build_sam_hq import ( # noqa build_sam_hq, build_sam_hq_vit_b, build_sam_hq_vit_h, build_sam_hq_vit_l, sam_hq_model_registry, ) -from .predictor import SamPredictor +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.predictor import SamPredictor # noqa diff --git a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/automatic_mask_generator.py b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/automatic_mask_generator.py index 2cd252d770..440bfd5f02 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/automatic_mask_generator.py +++ b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/automatic_mask_generator.py @@ -10,9 +10,9 @@ import numpy as np import torch from torchvision.ops.boxes import batched_nms, box_area # type: ignore -from .modeling import Sam -from .predictor import SamPredictor -from .utils.amg import ( +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling import Sam +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.predictor import SamPredictor +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.utils.amg import ( MaskData, area_from_rle, batch_iterator, diff --git a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/build_sam.py b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/build_sam.py index 012cfdf9af..ee600dcaef 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/build_sam.py +++ b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/build_sam.py @@ -8,7 +8,13 @@ from functools import partial import torch -from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling import ( + ImageEncoderViT, + MaskDecoder, + PromptEncoder, + Sam, + TwoWayTransformer, +) def build_sam_vit_h(checkpoint=None): @@ -101,7 +107,5 @@ def _build_sam( ) sam.eval() if checkpoint is not None: - with open(checkpoint, "rb") as f: - state_dict = torch.load(f) - sam.load_state_dict(state_dict) + sam.load_state_dict(checkpoint) return sam diff --git a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/build_sam_hq.py b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/build_sam_hq.py index 416c6aabb9..670b3ee41b 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/build_sam_hq.py +++ b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/build_sam_hq.py @@ -8,7 +8,13 @@ from functools import partial import torch -from .modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer +from 
invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling import ( + ImageEncoderViT, + MaskDecoderHQ, + PromptEncoder, + Sam, + TwoWayTransformer, +) def build_sam_hq_vit_h(checkpoint=None): diff --git a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/__init__.py b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/__init__.py index 2168e3e83b..027b418940 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/__init__.py +++ b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/__init__.py @@ -4,9 +4,17 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from .image_encoder import ImageEncoderViT -from .mask_decoder import MaskDecoder -from .mask_decoder_hq import MaskDecoderHQ -from .prompt_encoder import PromptEncoder -from .sam import Sam -from .transformer import TwoWayTransformer +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.image_encoder import ( + ImageEncoderViT, +) +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.mask_decoder import MaskDecoder +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.mask_decoder_hq import ( + MaskDecoderHQ, +) +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.prompt_encoder import ( + PromptEncoder, +) +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.sam import Sam +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.transformer import ( + TwoWayTransformer, +) diff --git a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/image_encoder.py b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/image_encoder.py index 6a5c18f43c..c6a4e22755 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/image_encoder.py +++ b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/image_encoder.py @@ -10,7 +10,10 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .common import LayerNorm2d, MLPBlock +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import ( + LayerNorm2d, + MLPBlock, +) # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa diff --git a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/mask_decoder.py b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/mask_decoder.py index 90915101d9..e9e42993e4 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/mask_decoder.py +++ b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/mask_decoder.py @@ -10,7 +10,7 @@ import torch from torch import nn from torch.nn import functional as F -from .common import LayerNorm2d +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import LayerNorm2d class MaskDecoder(nn.Module): diff --git a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/mask_decoder_hq.py 
b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/mask_decoder_hq.py index 5079423465..5b830ef9ba 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/mask_decoder_hq.py +++ b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/mask_decoder_hq.py @@ -11,7 +11,7 @@ import torch from torch import nn from torch.nn import functional as F -from .common import LayerNorm2d +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import LayerNorm2d class MaskDecoderHQ(nn.Module): diff --git a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/prompt_encoder.py b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/prompt_encoder.py index 1887c8d83f..1d00295ff6 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/prompt_encoder.py +++ b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/prompt_encoder.py @@ -10,7 +10,7 @@ import numpy as np import torch from torch import nn -from .common import LayerNorm2d +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import LayerNorm2d class PromptEncoder(nn.Module): diff --git a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/sam.py b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/sam.py index c1ce4ecae2..5dfb03e124 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/sam.py +++ b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/sam.py @@ -10,9 +10,13 @@ import torch from torch import nn from torch.nn import functional as F -from .image_encoder import ImageEncoderViT -from .mask_decoder import MaskDecoder -from .prompt_encoder import PromptEncoder +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.image_encoder import ( + ImageEncoderViT, +) +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.mask_decoder import MaskDecoder +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.prompt_encoder import ( + PromptEncoder, +) class Sam(nn.Module): @@ -98,7 +102,7 @@ class Sam(nn.Module): image_embeddings = self.image_encoder(input_images) outputs = [] - for image_record, curr_embedding in zip(batched_input, image_embeddings): + for image_record, curr_embedding in zip(batched_input, image_embeddings, strict=False): if "point_coords" in image_record: points = (image_record["point_coords"], image_record["point_labels"]) else: diff --git a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/transformer.py b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/transformer.py index 5bdd75377e..5d0054cee7 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/transformer.py +++ b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/modeling/transformer.py @@ -10,7 +10,7 @@ from typing import Tuple, Type import torch from torch import Tensor, nn -from .common import MLPBlock +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling.common import MLPBlock class TwoWayTransformer(nn.Module): diff --git a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/predictor.py 
b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/predictor.py index cd165f8958..2f4f418524 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/predictor.py +++ b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/predictor.py @@ -9,8 +9,8 @@ from typing import Optional, Tuple import numpy as np import torch -from .modeling import Sam -from .utils.transforms import ResizeLongestSide +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.modeling import Sam +from invokeai.backend.image_util.grounding_segment_anything.segment_anything.utils.transforms import ResizeLongestSide class SamPredictor: diff --git a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/utils/amg.py b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/utils/amg.py index 3a137778e4..1c9c491fa1 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/utils/amg.py +++ b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/utils/amg.py @@ -4,14 +4,14 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -import numpy as np -import torch - import math from copy import deepcopy from itertools import product from typing import Any, Dict, Generator, ItemsView, List, Tuple +import numpy as np +import torch + class MaskData: """ @@ -153,9 +153,7 @@ def area_from_rle(rle: Dict[str, Any]) -> int: return sum(rle["counts"][1::2]) -def calculate_stability_score( - masks: torch.Tensor, mask_threshold: float, threshold_offset: float -) -> torch.Tensor: +def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor: """ Computes the stability score for a batch of masks. The stability score is the IoU between the binary masks obtained by thresholding @@ -163,16 +161,8 @@ def calculate_stability_score( """ # One mask is always contained inside the other. 
# Save memory by preventing unnecesary cast to torch.int64 - intersections = ( - (masks > (mask_threshold + threshold_offset)) - .sum(-1, dtype=torch.int16) - .sum(-1, dtype=torch.int32) - ) - unions = ( - (masks > (mask_threshold - threshold_offset)) - .sum(-1, dtype=torch.int16) - .sum(-1, dtype=torch.int32) - ) + intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) + unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) return intersections / unions @@ -186,9 +176,7 @@ def build_point_grid(n_per_side: int) -> np.ndarray: return points -def build_all_layer_point_grids( - n_per_side: int, n_layers: int, scale_per_layer: int -) -> List[np.ndarray]: +def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]: """Generates point grids for all crop layers.""" points_by_layer = [] for i in range(n_layers + 1): @@ -252,9 +240,7 @@ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: return points + offset -def uncrop_masks( - masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int -) -> torch.Tensor: +def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor: x0, y0, x1, y1 = crop_box if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: return masks @@ -264,9 +250,7 @@ def uncrop_masks( return torch.nn.functional.pad(masks, pad, value=0) -def remove_small_regions( - mask: np.ndarray, area_thresh: float, mode: str -) -> Tuple[np.ndarray, bool]: +def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]: """ Removes small disconnected regions and holes in a mask. Returns the mask and an indicator of if the mask has been modified. diff --git a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/utils/onnx.py b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/utils/onnx.py deleted file mode 100644 index 4297b31291..0000000000 --- a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/utils/onnx.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import torch.nn as nn -from torch.nn import functional as F - -from typing import Tuple - -from ..modeling import Sam -from .amg import calculate_stability_score - - -class SamOnnxModel(nn.Module): - """ - This model should not be called directly, but is used in ONNX export. - It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, - with some functions modified to enable model tracing. Also supports extra - options controlling what information. See the ONNX export script for details. 
- """ - - def __init__( - self, - model: Sam, - return_single_mask: bool, - use_stability_score: bool = False, - return_extra_metrics: bool = False, - ) -> None: - super().__init__() - self.mask_decoder = model.mask_decoder - self.model = model - self.img_size = model.image_encoder.img_size - self.return_single_mask = return_single_mask - self.use_stability_score = use_stability_score - self.stability_score_offset = 1.0 - self.return_extra_metrics = return_extra_metrics - - @staticmethod - def resize_longest_image_size( - input_image_size: torch.Tensor, longest_side: int - ) -> torch.Tensor: - input_image_size = input_image_size.to(torch.float32) - scale = longest_side / torch.max(input_image_size) - transformed_size = scale * input_image_size - transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) - return transformed_size - - def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: - point_coords = point_coords + 0.5 - point_coords = point_coords / self.img_size - point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) - point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) - - point_embedding = point_embedding * (point_labels != -1) - point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( - point_labels == -1 - ) - - for i in range(self.model.prompt_encoder.num_point_embeddings): - point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ - i - ].weight * (point_labels == i) - - return point_embedding - - def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: - mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) - mask_embedding = mask_embedding + ( - 1 - has_mask_input - ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) - return mask_embedding - - def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: - masks = F.interpolate( - masks, - size=(self.img_size, self.img_size), - mode="bilinear", - align_corners=False, - ) - - prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size) - masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])] - - orig_im_size = orig_im_size.to(torch.int64) - h, w = orig_im_size[0], orig_im_size[1] - masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) - return masks - - def select_masks( - self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Determine if we should return the multiclick mask or not from the number of points. - # The reweighting is used to avoid control flow. 
- score_reweight = torch.tensor( - [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] - ).to(iou_preds.device) - score = iou_preds + (num_points - 2.5) * score_reweight - best_idx = torch.argmax(score, dim=1) - masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) - iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) - - return masks, iou_preds - - @torch.no_grad() - def forward( - self, - image_embeddings: torch.Tensor, - point_coords: torch.Tensor, - point_labels: torch.Tensor, - mask_input: torch.Tensor, - has_mask_input: torch.Tensor, - orig_im_size: torch.Tensor, - ): - sparse_embedding = self._embed_points(point_coords, point_labels) - dense_embedding = self._embed_masks(mask_input, has_mask_input) - - masks, scores = self.model.mask_decoder.predict_masks( - image_embeddings=image_embeddings, - image_pe=self.model.prompt_encoder.get_dense_pe(), - sparse_prompt_embeddings=sparse_embedding, - dense_prompt_embeddings=dense_embedding, - ) - - if self.use_stability_score: - scores = calculate_stability_score( - masks, self.model.mask_threshold, self.stability_score_offset - ) - - if self.return_single_mask: - masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) - - upscaled_masks = self.mask_postprocessing(masks, orig_im_size) - - if self.return_extra_metrics: - stability_scores = calculate_stability_score( - upscaled_masks, self.model.mask_threshold, self.stability_score_offset - ) - areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) - return upscaled_masks, scores, stability_scores, areas, masks - - return upscaled_masks, scores, masks diff --git a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/utils/transforms.py b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/utils/transforms.py index 3ad346661f..96a4ed6bc2 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/segment_anything/utils/transforms.py +++ b/invokeai/backend/image_util/grounding_segment_anything/segment_anything/utils/transforms.py @@ -4,14 +4,14 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +from copy import deepcopy +from typing import Tuple + import numpy as np import torch from torch.nn import functional as F from torchvision.transforms.functional import resize, to_pil_image # type: ignore -from copy import deepcopy -from typing import Tuple - class ResizeLongestSide: """ @@ -36,9 +36,7 @@ class ResizeLongestSide: original image size in (H, W) format. """ old_h, old_w = original_size - new_h, new_w = self.get_preprocess_shape( - original_size[0], original_size[1], self.target_length - ) + new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length) coords = deepcopy(coords).astype(float) coords[..., 0] = coords[..., 0] * (new_w / old_w) coords[..., 1] = coords[..., 1] * (new_h / old_h) @@ -60,29 +58,21 @@ class ResizeLongestSide: """ # Expects an image in BCHW format. May not exactly match apply_image. target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) - return F.interpolate( - image, target_size, mode="bilinear", align_corners=False, antialias=True - ) + return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True) - def apply_coords_torch( - self, coords: torch.Tensor, original_size: Tuple[int, ...] 
- ) -> torch.Tensor: + def apply_coords_torch(self, coords: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor: """ Expects a torch tensor with length 2 in the last dimension. Requires the original image size in (H, W) format. """ old_h, old_w = original_size - new_h, new_w = self.get_preprocess_shape( - original_size[0], original_size[1], self.target_length - ) + new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length) coords = deepcopy(coords).to(torch.float) coords[..., 0] = coords[..., 0] * (new_w / old_w) coords[..., 1] = coords[..., 1] * (new_h / old_h) return coords - def apply_boxes_torch( - self, boxes: torch.Tensor, original_size: Tuple[int, ...] - ) -> torch.Tensor: + def apply_boxes_torch(self, boxes: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor: """ Expects a torch tensor with shape Bx4. Requires the original image size in (H, W) format. diff --git a/pyproject.toml b/pyproject.toml index 9acaa17e44..74cbfbc172 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,12 +74,13 @@ dependencies = [ "easing-functions", "einops", "facexlib", - "matplotlib", # needed for plotting of Penner easing functions + "matplotlib", # needed for plotting of Penner easing functions "npyscreen", "omegaconf", "picklescan", "pillow", "prompt-toolkit", + "pycocotools", "pympler~=1.0.1", "pypatchmatch", 'pyperclip', @@ -90,6 +91,7 @@ dependencies = [ "scikit-image~=0.21.0", "semver~=3.0.1", "send2trash", + "supervision", "test-tube~=0.7.5", "windows-curses; sys_platform=='win32'", ]
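
Reviewer note (not part of the patch): the new GroundingSegmentAnythingDetector can be smoke-tested outside of an InvokeAI graph with a short script along the following lines. This is a minimal sketch under assumptions that are not in the diff: the two checkpoints have already been downloaded to local paths, the image path and prompt are placeholders, and the script runs from the repository root so that the relative GroundingDINO_SwinT_OGC.py config path used by build_grounding_dino resolves. Note also that build_grounding_dino leaves Model's device argument at its default of "cuda", so as written the GroundingDINO side requires a CUDA device even though build_segment_anything accepts an explicit device.

import torch
from PIL import Image

from invokeai.backend.image_util.grounding_segment_anything.gsa import GroundingSegmentAnythingDetector

# Hypothetical local checkpoint paths; the node itself fetches these via context.models.load_remote_model().
DINO_CKPT = "groundingdino_swint_ogc.pth"
SAM_CKPT = "sam_vit_h_4b8939.pth"

device = torch.device("cuda")  # GroundingDINO side currently assumes CUDA (Model defaults to device="cuda")

# The GroundingDINO checkpoint is saved as {"model": state_dict}; load_model() indexes the "model" key.
grounding_dino_state_dict = torch.load(DINO_CKPT, map_location="cpu")
# The SAM checkpoint is a plain state dict, passed straight to sam.load_state_dict() in _build_sam().
segment_anything_state_dict = torch.load(SAM_CKPT, map_location="cpu")

grounding_dino = GroundingSegmentAnythingDetector.build_grounding_dino(grounding_dino_state_dict)
segment_anything = GroundingSegmentAnythingDetector.build_segment_anything(segment_anything_state_dict, device)
detector = GroundingSegmentAnythingDetector(grounding_dino, segment_anything)

# predict() takes a PIL image and comma-separated keywords, and returns the combined mask as a PIL image.
image = Image.open("example.png").convert("RGB")
mask = detector.predict(image, "cat,dog", box_threshold=0.5, text_threshold=0.5, nms_threshold=0.8)
mask.save("mask.png")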