2024-07-29 17:53:14 +00:00
from pathlib import Path
2024-07-30 19:34:33 +00:00
from typing import Literal
2024-07-29 17:53:14 +00:00
import numpy as np
import torch
from PIL import Image
2024-07-31 16:20:23 +00:00
from transformers import AutoModelForMaskGeneration , AutoProcessor
2024-07-29 17:53:14 +00:00
from transformers . models . sam import SamModel
from transformers . models . sam . processing_sam import SamProcessor
from invokeai . app . invocations . baseinvocation import BaseInvocation , invocation
2024-07-31 21:15:48 +00:00
from invokeai . app . invocations . fields import BoundingBoxField , ImageField , InputField , TensorField
from invokeai . app . invocations . primitives import MaskOutput
2024-07-29 17:53:14 +00:00
from invokeai . app . services . shared . invocation_context import InvocationContext
2024-07-31 16:28:47 +00:00
from invokeai . backend . image_util . segment_anything . mask_refinement import mask_to_polygon , polygon_to_mask
2024-08-01 13:57:47 +00:00
from invokeai . backend . image_util . segment_anything . segment_anything_pipeline import SegmentAnythingPipeline
2024-07-29 17:53:14 +00:00
# Model source handed to `context.models.load_remote_model` when the SAM pipeline is loaded
# (the format looks like a Hugging Face Hub repo id — confirm against the model manager).
SEGMENT_ANYTHING_MODEL_ID: str = "facebook/sam-vit-base"
@invocation(
    "segment_anything",
    title="Segment Anything",
    tags=["prompt", "segmentation"],
    category="segmentation",
    version="1.0.0",
)
class SegmentAnythingInvocation(BaseInvocation):
    """Runs a Segment Anything Model."""

    # Reference:
    # - https://arxiv.org/pdf/2304.02643
    # - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
    # - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb

    # Input image; it is loaded as RGB in `invoke`.
    image: ImageField = InputField(description="The image to segment.")
    # One SAM prompt per box. An empty list yields an all-False mask.
    bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
    # When True, each per-box mask is round-tripped through a polygon before filtering/merging.
    apply_polygon_refinement: bool = InputField(
        description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
        default=True,
    )
    # Selects which per-box masks survive before they are OR-merged into the output mask.
    mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
        description="The filtering to apply to the detected masks before merging them into a final output.",
        default="all",
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> MaskOutput:
        """Segment `self.image` with SAM and return the merged boolean mask.

        Returns:
            A MaskOutput referencing a saved bool tensor of shape (height, width).
        """
        # The models expect a 3-channel RGB image.
        image_pil = context.images.get_pil(self.image.image_name, mode="RGB")

        if not self.bounding_boxes:
            # No prompts, so there is nothing to segment. PIL's `size` is (width, height);
            # reverse it to get a (height, width) tensor.
            combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
        else:
            masks = self._segment(context=context, image=image_pil)
            masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
            # The masks are boolean, so merging them is a pixel-wise logical OR across the stack.
            combined_mask = torch.stack(masks).any(dim=0)

        mask_tensor_name = context.tensors.save(combined_mask)
        height, width = combined_mask.shape
        return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)

    @staticmethod
    def _load_sam_model(model_path: Path):
        """Loader callback for `load_remote_model`: build a SegmentAnythingPipeline from local files at `model_path`."""
        sam_model = AutoModelForMaskGeneration.from_pretrained(
            model_path,
            local_files_only=True,
            # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
            # model, and figure out how to make it work in the pipeline.
            # torch_dtype=TorchDevice.choose_torch_dtype(),
        )
        assert isinstance(sam_model, SamModel)
        sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
        assert isinstance(sam_processor, SamProcessor)
        return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)

    def _segment(
        self,
        context: InvocationContext,
        image: Image.Image,
    ) -> list[torch.Tensor]:
        """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.

        Returns:
            One CPU bool mask per entry in `self.bounding_boxes`, in the same order.
        """
        # Convert the bounding boxes to the SAM input format.
        sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]

        with context.models.load_remote_model(
            source=SEGMENT_ANYTHING_MODEL_ID, loader=SegmentAnythingInvocation._load_sam_model
        ) as sam_pipeline:
            assert isinstance(sam_pipeline, SegmentAnythingPipeline)
            masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)

        masks = self._process_masks(masks)
        if self.apply_polygon_refinement:
            masks = self._apply_polygon_refinement(masks)

        return masks

    def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]:
        """Convert the tensor output from the Segment Anything model from a tensor of shape
        [num_masks, channels, height, width] to a list of tensors of shape [height, width].
        """
        assert masks.dtype == torch.bool
        # [num_masks, channels, height, width] -> [num_masks, height, width]
        # For bool tensors, `any` over the channel dim is the same as a max-reduce, without the unused indices.
        masks = masks.any(dim=1)
        # Split the first dimension into a list of masks.
        return list(masks.cpu().unbind(dim=0))

    def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
        """Apply polygon refinement to the masks.

        Convert each mask to a polygon, then back to a mask. This has the following effect:
        - Smooth the edges of the mask slightly.
        - Ensure that each mask consists of a single closed polygon
        - Removes small mask pieces.
        - Removes holes from the mask.
        """
        refined_masks: list[torch.Tensor] = []
        for mask in masks:
            np_mask = mask.cpu().numpy().astype(np.uint8)
            shape = np_mask.shape
            assert len(shape) == 2  # Assert length to satisfy type checker.
            # An all-empty mask has no outline to refine, and its refined form would still be empty.
            # NOTE(review): skipping it avoids depending on how `mask_to_polygon` behaves when no contour exists.
            if np_mask.any():
                polygon = mask_to_polygon(np_mask)
                np_mask = polygon_to_mask(polygon, shape)
            refined_masks.append(torch.tensor(np_mask, dtype=torch.bool))
        return refined_masks

    def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
        """Filter the detected masks based on the specified mask filter.

        Args:
            masks: One mask per bounding box, in the same order as `bounding_boxes`.
            bounding_boxes: The boxes that prompted `masks`; only their `score` is read here.

        Returns:
            All masks for "all"; otherwise a single-element list with the selected mask.

        Raises:
            ValueError: If `self.mask_filter` is not a recognised filter name.
        """
        assert len(masks) == len(bounding_boxes)

        if self.mask_filter == "all":
            return masks
        elif self.mask_filter == "largest":
            # Keep the mask with the most set pixels.
            return [max(masks, key=lambda x: float(x.sum()))]
        elif self.mask_filter == "highest_box_score":
            # Keep the mask whose bounding box has the highest score. A None score falls back to -1.0, which is a
            # reasonable fallback since the expected score range is [0.0, 1.0]. The check is explicitly against
            # None: `score or -1.0` would also demote a legitimate score of 0.0. On ties, the first box wins.
            best_mask, _ = max(
                zip(masks, bounding_boxes),
                key=lambda pair: pair[1].score if pair[1].score is not None else -1.0,
            )
            return [best_mask]
        else:
            raise ValueError(f"Invalid mask filter: {self.mask_filter}")