from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, Optional

import numpy as np
import numpy.typing as npt
import torch
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
from transformers.pipelines import ZeroShotObjectDetectionPipeline

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.grounded_sam.mask_refinement import mask_to_polygon, polygon_to_mask
from invokeai.backend.grounded_sam.segment_anything_model import SegmentAnythingModel

GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base"

@dataclass
class BoundingBox:
    """Bounding box helper class used locally for the Grounding DINO outputs."""

    xmin: int
    ymin: int
    xmax: int
    ymax: int

    def to_box(self) -> list[int]:
        """Convert to the array notation expected by SAM."""
        return [self.xmin, self.ymin, self.xmax, self.ymax]

@dataclass
class DetectionResult:
    """Detection result from Grounding DINO or Grounded SAM."""

    score: float
    label: str
    box: BoundingBox
    mask: Optional[npt.NDArray[Any]] = None

    @classmethod
    def from_dict(cls, detection_dict: dict[str, Any]) -> "DetectionResult":
        """Build a DetectionResult from a detection dict produced by the Grounding DINO pipeline."""
        return cls(
            score=detection_dict["score"],
            label=detection_dict["label"],
            box=BoundingBox(
                xmin=detection_dict["box"]["xmin"],
                ymin=detection_dict["box"]["ymin"],
                xmax=detection_dict["box"]["xmax"],
                ymax=detection_dict["box"]["ymax"],
            ),
        )
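

# For reference, `DetectionResult.from_dict` expects the dict shape produced by the transformers
# zero-shot-object-detection pipeline; the values here are illustrative, not real output:
#   {"score": 0.73, "label": "dog.", "box": {"xmin": 10, "ymin": 20, "xmax": 110, "ymax": 220}}
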
@invocation(
    "grounded_segment_anything",
    title="Segment Anything (Text Prompt)",
    tags=["prompt", "segmentation"],
    category="segmentation",
    version="1.0.0",
)
class GroundedSAMInvocation(BaseInvocation):
    """Runs Grounded-SAM, as proposed in https://arxiv.org/pdf/2401.14159.

    More specifically, a Grounding DINO model is run to obtain bounding boxes for a text prompt, then the bounding
    boxes are passed as a prompt to a Segment Anything model to obtain a segmentation mask.

    Reference:
    - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
    - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
    """

    prompt: str = InputField(description="The prompt describing the object to segment.")
    image: ImageField = InputField(description="The image to segment.")
    apply_polygon_refinement: bool = InputField(
        description="Whether to apply polygon refinement to the masks. This will smooth the edges of the mask "
        "slightly and ensure that each mask consists of a single closed polygon (before merging).",
        default=True,
    )
    mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
        description="The filtering to apply to the detected masks before merging them into a final output.",
        default="all",
    )
    detection_threshold: float = InputField(
        description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores "
        "above this threshold will be used.",
        ge=0.0,
        le=1.0,
        default=0.1,
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image_pil = context.images.get_pil(self.image.image_name)

        detections = self._detect(
            context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
        )

        if len(detections) == 0:
            combined_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8)
        else:
            detections = self._segment(context=context, image=image_pil, detection_results=detections)
            detections = self._filter_detections(detections)
            masks = [detection.mask for detection in detections]
            combined_mask = self._merge_masks(masks)

        # Map [0, 1] to [0, 255].
        mask_np = combined_mask * 255

        mask_pil = Image.fromarray(mask_np)
        image_dto = context.images.save(image=mask_pil)
        return ImageOutput.build(image_dto)

    def _to_box_array(self, detection_results: list[DetectionResult]) -> list[list[list[int]]]:
        """Convert a list of DetectionResults to the format expected by the Segment Anything model.

        Args:
            detection_results (list[DetectionResult]): The Grounding DINO detection results.
        """
        boxes = [result.box.to_box() for result in detection_results]
        return [boxes]
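
    # Illustrative shape of `_to_box_array`'s return value, assuming two detections (coordinates
    # are made up):
    #   [[[10, 20, 110, 220], [30, 40, 90, 100]]]
    # The extra outer list is the single-image batch dimension in SAM's box-prompt format.
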
    def _detect(
        self,
        context: InvocationContext,
        image: Image.Image,
        labels: list[str],
        threshold: float = 0.3,
    ) -> list[DetectionResult]:
        """Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
        # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
        # actually makes a difference.
        # Grounding DINO reportedly expects each candidate label to end with a period, e.g. "dog" -> "dog.".
        labels = [label if label.endswith(".") else label + "." for label in labels]

        def load_grounding_dino(model_path: Path):
            grounding_dino_pipeline = pipeline(
                model=str(model_path),
                task="zero-shot-object-detection",
                local_files_only=True,
                # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
                # model, and figure out how to make it work in the pipeline.
                # torch_dtype=TorchDevice.choose_torch_dtype(),
            )
            assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
            return GroundingDinoPipeline(grounding_dino_pipeline)

        with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector:
            assert isinstance(detector, GroundingDinoPipeline)
            results = detector.detect(image=image, candidate_labels=labels, threshold=threshold)
            results = [DetectionResult.from_dict(result) for result in results]
            return results

    def _segment(
        self,
        context: InvocationContext,
        image: Image.Image,
        detection_results: list[DetectionResult],
    ) -> list[DetectionResult]:
        """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""

        def load_sam_model(model_path: Path):
            sam_model = AutoModelForMaskGeneration.from_pretrained(
                model_path,
                local_files_only=True,
                # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
                # model, and figure out how to make it work in the pipeline.
                # torch_dtype=TorchDevice.choose_torch_dtype(),
            )
            assert isinstance(sam_model, SamModel)

            sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
            assert isinstance(sam_processor, SamProcessor)
            return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)

        with context.models.load_remote_model(source=SEGMENT_ANYTHING_MODEL_ID, loader=load_sam_model) as sam_pipeline:
            assert isinstance(sam_pipeline, SegmentAnythingModel)

            boxes = self._to_box_array(detection_results)
            masks = sam_pipeline.segment(image=image, boxes=boxes)

        masks = self._to_numpy_masks(masks)
        masks = self._apply_polygon_refinement(masks)

        # Attach each mask to its corresponding detection.
        for detection_result, mask in zip(detection_results, masks, strict=False):
            detection_result.mask = mask

        return detection_results

    def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
        """Convert the tensor output from the Segment Anything model to a list of numpy masks."""
        masks = masks.cpu().float()
        # Move the channel dimension last, then average across channels: (N, C, H, W) -> (N, H, W).
        masks = masks.permute(0, 2, 3, 1)
        masks = masks.mean(axis=-1)
        # Threshold at 0 to binarize: positive values become 1, everything else 0.
        masks = (masks > 0).int()
        masks = masks.numpy().astype(np.uint8)
        masks = list(masks)
        return masks
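
    # Illustrative example of the conversion above with a hypothetical 1x1x2x2 tensor:
    #   torch.tensor([[[[0.9, -0.2], [0.1, -0.7]]]]) -> [array([[1, 0], [1, 0]], dtype=uint8)]
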
    def _apply_polygon_refinement(self, masks: list[npt.NDArray[np.uint8]]) -> list[npt.NDArray[np.uint8]]:
        """Apply polygon refinement to the masks.

        Convert each mask to a polygon, then back to a mask. This has the following effects:
        - Smooths the edges of the mask slightly.
        - Ensures that each mask consists of a single closed polygon.
        - Removes small mask pieces.
        - Removes holes from the mask.
        """
        if self.apply_polygon_refinement:
            for idx, mask in enumerate(masks):
                shape = mask.shape
                polygon = mask_to_polygon(mask)
                mask = polygon_to_mask(polygon, shape)
                masks[idx] = mask

        return masks

    def _filter_detections(self, detections: list[DetectionResult]) -> list[DetectionResult]:
        """Filter the detected masks based on the specified mask filter."""
        if self.mask_filter == "all":
            return detections
        elif self.mask_filter == "largest":
            # Find the detection with the largest mask area (sum of non-zero mask pixels).
            mask_areas = [detection.mask.sum() for detection in detections]
            largest_mask_idx = mask_areas.index(max(mask_areas))
            return [detections[largest_mask_idx]]
        elif self.mask_filter == "highest_box_score":
            # Find the detection with the highest Grounding DINO box score.
            max_score_detection = detections[0]
            for detection in detections:
                if detection.score > max_score_detection.score:
                    max_score_detection = detection
            return [max_score_detection]
        else:
            raise ValueError(f"Invalid mask filter: {self.mask_filter}")

    def _merge_masks(self, masks: list[npt.NDArray[np.uint8]]) -> npt.NDArray[np.uint8]:
        """Merge multiple masks into a single mask."""
        # Stack the masks and take the element-wise max, i.e. the union of the binary masks.
        stacked_mask = np.stack(masks, axis=0)
        combined_mask = np.max(stacked_mask, axis=0)
        return combined_mask
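
    # Illustrative example of `_merge_masks` with two made-up 1x2 binary masks:
    #   np.max(np.stack([np.array([[1, 0]]), np.array([[0, 1]])], axis=0), axis=0) -> array([[1, 1]])
    # i.e. the union of the two masks.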