Add mask_filter and detection_threshold options to the GroundedSAMInvocation.

Ryan Dick 2024-07-30 14:22:40 -04:00
parent ff6398f7d8
commit aca2a2fa13
3 changed files with 83 additions and 25 deletions

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional
from typing import Any, Literal, Optional
import numpy as np
import numpy.typing as npt
@@ -70,8 +70,8 @@ class DetectionResult:
class GroundedSAMInvocation(BaseInvocation):
"""Runs Grounded-SAM, as proposed in https://arxiv.org/pdf/2401.14159.
More specifically, a Grounding DINO model is run to obtain bounding boxes for a text prompt, then the bounding box
is passed as a prompt to a Segment Anything model to obtain a segmentation mask.
More specifically, a Grounding DINO model is run to obtain bounding boxes for a text prompt, then the bounding boxes
are passed as a prompt to a Segment Anything model to obtain a segmentation mask.
Reference:
- https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
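For orientation, here is a minimal standalone sketch of the two-stage flow the docstring describes, written against the public transformers API rather than InvokeAI's model-loading stack; the model IDs, input path, and prompt are illustrative assumptions, not InvokeAI's configuration.

# Minimal sketch of the Grounding DINO -> SAM flow (model IDs, input path, and
# prompt are illustrative assumptions).
import torch
from PIL import Image
from transformers import SamModel, SamProcessor, pipeline

image = Image.open("input.png").convert("RGB")

# Stage 1: zero-shot object detection yields bounding boxes for the text prompt.
detector = pipeline("zero-shot-object-detection", model="IDEA-Research/grounding-dino-tiny")
detections = detector(image, candidate_labels=["a cat."], threshold=0.1)
boxes = [[[d["box"]["xmin"], d["box"]["ymin"], d["box"]["xmax"], d["box"]["ymax"]] for d in detections]]

# Stage 2: the boxes are passed as prompts to a Segment Anything model.
sam = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
inputs = processor(image, input_boxes=boxes, return_tensors="pt")
with torch.no_grad():
    outputs = sam(**inputs)
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
)[0]  # one entry per input image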
@@ -81,22 +81,38 @@ class GroundedSAMInvocation(BaseInvocation):
prompt: str = InputField(description="The prompt describing the object to segment.")
image: ImageField = InputField(description="The image to segment.")
apply_polygon_refinement: bool = InputField(
description="Whether to apply polygon refinement to the mask. This will smooth the edges of the mask slightly "
"and ensure that the mask consists of a single closed polygon.",
default=False,
description="Whether to apply polygon refinement to the masks. This will smooth the edges of the mask slightly and ensure that each mask consists of a single closed polygon (before merging).",
default=True,
)
mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
description="The filtering to apply to the detected masks before merging them into a final output.",
default="all",
)
detection_threshold: float = InputField(
description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be used.",
ge=0.0,
le=1.0,
default=0.1,
)
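Taken together, the new fields might be exercised like this. This is a hypothetical direct construction for illustration only: the imports and node id that the workflow graph normally supplies are omitted, and the image name is made up.

# Hypothetical construction showing the new fields; not how a workflow
# normally instantiates the node.
node = GroundedSAMInvocation(
    prompt="a red car",
    image=ImageField(image_name="example.png"),  # made-up image name
    apply_polygon_refinement=True,
    mask_filter="largest",  # one of "all", "largest", "highest_box_score"
    detection_threshold=0.3,  # keep boxes scoring above 0.3
)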
def invoke(self, context: InvocationContext) -> ImageOutput:
image_pil = context.images.get_pil(self.image.image_name)
detections = self._detect(context=context, image=image_pil, labels=[self.prompt])
detections = self._segment(context=context, image=image_pil, detection_results=detections)
detections = self._detect(
context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
)
if len(detections) == 0:
combined_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8)
else:
detections = self._segment(context=context, image=image_pil, detection_results=detections)
detections = self._filter_detections(detections)
masks = [detection.mask for detection in detections]
combined_mask = self._merge_masks(masks)
# Extract output mask.
mask_np = detections[0].mask
assert mask_np is not None
# Map [0, 1] to [0, 255].
mask_np = mask_np * 255
mask_np = combined_mask * 255
mask_pil = Image.fromarray(mask_np)
image_dto = context.images.save(image=mask_pil)
@@ -119,6 +135,9 @@ class GroundedSAMInvocation(BaseInvocation):
threshold: float = 0.3,
) -> list[DetectionResult]:
"""Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
# TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
# actually makes a difference.
labels = [label if label.endswith(".") else label + "." for label in labels]
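# e.g. ["a cat", "a dog."] -> ["a cat.", "a dog."]; labels already ending in "." pass through unchanged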
def load_grounding_dino(model_path: Path):
grounding_dino_pipeline = pipeline(
@@ -135,11 +154,7 @@ class GroundedSAMInvocation(BaseInvocation):
with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector:
assert isinstance(detector, GroundingDinoPipeline)
# TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
# actually makes a difference.
labels = [label if label.endswith(".") else label + "." for label in labels]
results = detector(image, candidate_labels=labels, threshold=threshold)
results = detector.detect(image=image, candidate_labels=labels, threshold=threshold)
results = [DetectionResult.from_dict(result) for result in results]
return results
@@ -172,21 +187,34 @@ class GroundedSAMInvocation(BaseInvocation):
boxes = self._to_box_array(detection_results)
masks = sam_pipeline.segment(image=image, boxes=boxes)
masks = self._refine_masks(masks)
masks = self._to_numpy_masks(masks)
masks = self._apply_polygon_refinement(masks)
for detection_result, mask in zip(detection_results, masks, strict=False):
detection_result.mask = mask
return detection_results
def _refine_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
"""Convert the tensor output from the Segment Anything model to a list of numpy masks."""
masks = masks.cpu().float()
masks = masks.permute(0, 2, 3, 1)
masks = masks.mean(axis=-1)
masks = (masks > 0).int()
masks = masks.numpy().astype(np.uint8)
masks = list(masks)
return masks
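As a sanity check, here are the same steps applied to a dummy tensor of SAM-like shape; the sizes are made up.

# Dummy walk-through: (num_masks, C, H, W) float tensor -> list of (H, W) uint8 masks.
import numpy as np
import torch

dummy = torch.randn(2, 3, 4, 4)  # 2 masks, 3 mask channels, 4x4 pixels
m = dummy.cpu().float().permute(0, 2, 3, 1).mean(axis=-1)  # average the channels
m = (m > 0).int().numpy().astype(np.uint8)  # binarize at 0
out = list(m)
assert len(out) == 2 and out[0].shape == (4, 4)
assert set(np.unique(out[0]).tolist()) <= {0, 1}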
def _apply_polygon_refinement(self, masks: list[npt.NDArray[np.uint8]]) -> list[npt.NDArray[np.uint8]]:
"""Apply polygon refinement to the masks.
Convert each mask to a polygon, then back to a mask. This has the following effects:
- Smooths the edges of the mask slightly.
- Ensures that each mask consists of a single closed polygon.
- Removes small mask pieces.
- Removes holes from the mask.
"""
if self.apply_polygon_refinement:
for idx, mask in enumerate(masks):
shape = mask.shape
@@ -195,3 +223,29 @@ class GroundedSAMInvocation(BaseInvocation):
masks[idx] = mask
return masks
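The mask-to-polygon helpers themselves fall outside this hunk. A plausible OpenCV equivalent of one mask -> polygon -> mask round trip is sketched below; it is an assumption, since the diff does not show the actual helpers the commit uses.

# Assumed OpenCV sketch of the mask -> polygon -> mask round trip.
import cv2
import numpy as np

def refine_mask(mask: np.ndarray) -> np.ndarray:
    """mask: (H, W) uint8 array with values in {0, 1}."""
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return mask
    # Keeping only the largest contour drops small pieces; filling it removes holes.
    largest = max(contours, key=cv2.contourArea)
    refined = np.zeros_like(mask)
    cv2.fillPoly(refined, [largest], 1)
    return refined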
def _filter_detections(self, detections: list[DetectionResult]) -> list[DetectionResult]:
"""Filter the detected masks based on the specified mask filter."""
if self.mask_filter == "all":
return detections
elif self.mask_filter == "largest":
# Find the largest mask.
mask_areas = [detection.mask.sum() for detection in detections]
largest_mask_idx = mask_areas.index(max(mask_areas))
return [detections[largest_mask_idx]]
elif self.mask_filter == "highest_box_score":
# Find the detection with the highest box score.
max_score_detection = detections[0]
for detection in detections:
if detection.score > max_score_detection.score:
max_score_detection = detection
return [max_score_detection]
else:
raise ValueError(f"Invalid mask filter: {self.mask_filter}")
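Note that the highest_box_score branch is equivalent to max(detections, key=lambda d: d.score); the explicit loop just spells the comparison out.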
def _merge_masks(self, masks: list[npt.NDArray[np.uint8]]) -> npt.NDArray[np.uint8]:
"""Merge multiple masks into a single mask."""
# Merge all masks together.
stacked_mask = np.stack(masks, axis=0)
combined_mask = np.max(stacked_mask, axis=0)
return combined_mask
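Taking the element-wise max over the stacked masks acts as a union of the binary masks. A tiny worked example:

# Union of two binary masks via element-wise max.
import numpy as np

a = np.array([[1, 0], [0, 0]], dtype=np.uint8)
b = np.array([[0, 0], [0, 1]], dtype=np.uint8)
merged = np.max(np.stack([a, b], axis=0), axis=0)
assert (merged == np.array([[1, 0], [0, 1]], dtype=np.uint8)).all()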

View File

@@ -1,6 +1,7 @@
from typing import Optional
import torch
from PIL import Image
from transformers.pipelines import ZeroShotObjectDetectionPipeline
@@ -12,8 +13,8 @@ class GroundingDinoPipeline:
def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
self._pipeline = pipeline
def __call__(self, *args, **kwargs):
return self._pipeline(*args, **kwargs)
def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1):
return self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":
self._pipeline.model.to(device=device, dtype=dtype)
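Replacing the generic __call__ passthrough with an explicit detect(image, candidate_labels, threshold) method gives call sites a typed signature instead of *args/**kwargs, which is what allows the _detect caller above to switch to keyword arguments.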

View File

@@ -30,6 +30,9 @@ class SegmentAnythingModel:
masks=outputs.pred_masks,
original_sizes=inputs.original_sizes,
reshaped_input_sizes=inputs.reshaped_input_sizes,
)[0]
)
# There should be only one batch.
assert len(masks) == 1
masks = masks[0]
return masks
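post_process_masks returns one entry per image in the batch; asserting len(masks) == 1 before unwrapping makes the previously implicit [0] indexing explicit and fails loudly if the single-image assumption ever changes.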