2024-07-29 17:53:14 +00:00
from pathlib import Path
2024-07-30 19:34:33 +00:00
from typing import Literal
2024-07-29 17:53:14 +00:00
import numpy as np
import numpy . typing as npt
import torch
from PIL import Image
2024-07-31 16:20:23 +00:00
from transformers import AutoModelForMaskGeneration , AutoProcessor
2024-07-29 17:53:14 +00:00
from transformers . models . sam import SamModel
from transformers . models . sam . processing_sam import SamProcessor
from invokeai . app . invocations . baseinvocation import BaseInvocation , invocation
2024-07-31 16:20:23 +00:00
from invokeai . app . invocations . fields import BoundingBoxField , ImageField , InputField
2024-07-29 17:53:14 +00:00
from invokeai . app . invocations . primitives import ImageOutput
from invokeai . app . services . shared . invocation_context import InvocationContext
2024-07-31 14:00:30 +00:00
from invokeai . backend . image_util . grounded_sam . mask_refinement import mask_to_polygon , polygon_to_mask
from invokeai . backend . image_util . grounded_sam . segment_anything_model import SegmentAnythingModel
2024-07-29 17:53:14 +00:00
# Identifier of the Segment Anything checkpoint. It is passed as `source` to
# `context.models.load_remote_model(...)`, so it must not contain any padding whitespace.
SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base"
@invocation(
    "segment_anything_model",
    title="Segment Anything Model",
    tags=["prompt", "segmentation"],
    category="segmentation",
    version="1.0.0",
)
class SegmentAnythingModelInvocation(BaseInvocation):
    """Runs a Segment Anything Model (https://arxiv.org/pdf/2304.02643).

    The model is prompted with a list of bounding boxes and produces one mask per box. The masks are optionally
    refined, filtered according to `mask_filter`, merged into a single mask, and saved as an image.

    Reference:
    - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
    - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
    """

    image: ImageField = InputField(description="The image to segment.")
    bounding_boxes: list[BoundingBoxField] = InputField(
        description="The bounding boxes to prompt the SAM model with."
    )
    apply_polygon_refinement: bool = InputField(
        description=(
            "Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly "
            "and ensure that each mask consists of a single closed polygon (before merging)."
        ),
        default=True,
    )
    mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
        description="The filtering to apply to the detected masks before merging them into a final output.",
        default="all",
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Segment the input image and save the merged mask as a new image.

        Args:
            context: The invocation context used to load the input image, load the model and save the result.

        Returns:
            An ImageOutput built from the saved mask image. Mask pixels are 255, all other pixels are 0.
        """
        # The models expect a 3-channel RGB image.
        image_pil = context.images.get_pil(self.image.image_name, mode="RGB")

        if not self.bounding_boxes:
            # Nothing to prompt the model with: emit an all-zero mask.
            # PIL reports size as (width, height), numpy expects (height, width), hence the reversal.
            combined_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8)
        else:
            masks = self._segment(context=context, image=image_pil)
            masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
            # masks contains binary values of 0 or 1, so we merge them via max-reduce.
            combined_mask = np.maximum.reduce(masks)

        # Map [0, 1] to [0, 255].
        mask_np = combined_mask * 255
        mask_pil = Image.fromarray(mask_np)

        image_dto = context.images.save(image=mask_pil)
        return ImageOutput.build(image_dto)

    @staticmethod
    def _load_sam_model(model_path: Path):
        """Load the SAM model and its processor from a local path and wrap them in a SegmentAnythingModel.

        Args:
            model_path: Path to the locally available model files (loading is done with `local_files_only=True`).

        Returns:
            A SegmentAnythingModel wrapping the loaded SamModel and SamProcessor.
        """
        sam_model = AutoModelForMaskGeneration.from_pretrained(
            model_path,
            local_files_only=True,
            # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
            # model, and figure out how to make it work in the pipeline.
            # torch_dtype=TorchDevice.choose_torch_dtype(),
        )
        assert isinstance(sam_model, SamModel)

        sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
        assert isinstance(sam_processor, SamProcessor)

        return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)

    def _segment(
        self,
        context: InvocationContext,
        image: Image.Image,
    ) -> list[npt.NDArray[np.uint8]]:
        """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.

        Args:
            context: The invocation context used to load the SAM model.
            image: The image to segment.

        Returns:
            A list of binary (0 or 1) masks, optionally polygon-refined.
        """
        # Convert the bounding boxes to the SAM input format.
        sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]

        with context.models.load_remote_model(
            source=SEGMENT_ANYTHING_MODEL_ID, loader=SegmentAnythingModelInvocation._load_sam_model
        ) as sam_pipeline:
            assert isinstance(sam_pipeline, SegmentAnythingModel)
            masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)

        masks = self._to_numpy_masks(masks)
        if self.apply_polygon_refinement:
            masks = self._apply_polygon_refinement(masks)

        return masks

    def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
        """Convert the tensor output from the Segment Anything model to a list of numpy masks.

        The channel dimension is averaged away and the result is thresholded, so a pixel is set to 1 if its mean
        over the channels exceeds a small epsilon.
        """
        eps = 0.0001
        # [num_masks, channels, height, width] -> [num_masks, height, width]
        masks = masks.permute(0, 2, 3, 1).float().mean(dim=-1)
        masks = masks > eps
        np_masks = masks.cpu().numpy().astype(np.uint8)
        return list(np_masks)

    def _apply_polygon_refinement(self, masks: list[npt.NDArray[np.uint8]]) -> list[npt.NDArray[np.uint8]]:
        """Apply polygon refinement to the masks.

        Convert each mask to a polygon, then back to a mask. This has the following effect:
        - Smooth the edges of the mask slightly.
        - Ensure that each mask consists of a single closed polygon
            - Removes small mask pieces.
            - Removes holes from the mask.

        The input list is left untouched; a new list with the refined masks is returned.
        """
        refined_masks: list[npt.NDArray[np.uint8]] = []
        for mask in masks:
            shape = mask.shape
            assert len(shape) == 2  # Assert length to satisfy type checker.
            polygon = mask_to_polygon(mask)
            refined_masks.append(polygon_to_mask(polygon, shape))

        return refined_masks

    def _filter_masks(
        self, masks: list[npt.NDArray[np.uint8]], bounding_boxes: list[BoundingBoxField]
    ) -> list[npt.NDArray[np.uint8]]:
        """Filter the detected masks based on the specified mask filter.

        Args:
            masks: The detected masks, one per bounding box.
            bounding_boxes: The bounding boxes the masks were generated from (same order and length as `masks`).

        Returns:
            All masks ("all"), or a single-element list holding the mask with the largest sum ("largest") or the
            mask belonging to the highest-scoring bounding box ("highest_box_score").

        Raises:
            ValueError: If `self.mask_filter` is not one of the supported values.
        """
        assert len(masks) == len(bounding_boxes)

        if self.mask_filter == "all":
            return masks
        elif self.mask_filter == "largest":
            # Find the largest mask.
            return [max(masks, key=lambda x: x.sum())]
        elif self.mask_filter == "highest_box_score":
            # Find the index of the bounding box with the highest score.
            # Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
            # cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
            # reasonable fallback since the expected score range is [0.0, 1.0].
            def _score_or_fallback(idx: int) -> float:
                score = bounding_boxes[idx].score
                # Compare against None explicitly: a valid score of 0.0 is falsy and must not be replaced.
                return -1.0 if score is None else score

            max_score_idx = max(range(len(bounding_boxes)), key=_score_or_fallback)
            return [masks[max_score_idx]]
        else:
            raise ValueError(f"Invalid mask filter: {self.mask_filter}")