Add mask_filter and detection_threshold options to the GroundedSAMInvocation.

Ryan Dick 2024-07-30 14:22:40 -04:00
parent ff6398f7d8
commit aca2a2fa13
3 changed files with 83 additions and 25 deletions
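For orientation before the diffs: the two new inputs behave like any other node fields. A minimal sketch, assuming direct construction of the node (the prompt, image name, and values here are illustrative, not from this commit, and a real node also carries graph bookkeeping fields such as id):

    node = GroundedSAMInvocation(
        prompt="a cat.",  # text prompt passed to Grounding DINO
        image=ImageField(image_name="example.png"),  # hypothetical image name
        mask_filter="largest",  # new: one of "all", "largest", "highest_box_score"
        detection_threshold=0.3,  # new: keep boxes scoring above this value
    )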

View File

@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Literal, Optional

 import numpy as np
 import numpy.typing as npt
@@ -70,8 +70,8 @@ class DetectionResult:
 class GroundedSAMInvocation(BaseInvocation):
     """Runs Grounded-SAM, as proposed in https://arxiv.org/pdf/2401.14159.

-    More specifically, a Grounding DINO model is run to obtain bounding boxes for a text prompt, then the bounding box
-    is passed as a prompt to a Segment Anything model to obtain a segmentation mask.
+    More specifically, a Grounding DINO model is run to obtain bounding boxes for a text prompt, then the bounding boxes
+    are passed as a prompt to a Segment Anything model to obtain a segmentation mask.

     Reference:
     - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
@@ -81,22 +81,38 @@ class GroundedSAMInvocation(BaseInvocation):
     prompt: str = InputField(description="The prompt describing the object to segment.")
     image: ImageField = InputField(description="The image to segment.")
     apply_polygon_refinement: bool = InputField(
-        description="Whether to apply polygon refinement to the mask. This will smooth the edges of the mask slightly "
-        "and ensure that the mask consists of a single closed polygon.",
-        default=False,
+        description="Whether to apply polygon refinement to the masks. This will smooth the edges of the mask slightly and ensure that each mask consists of a single closed polygon (before merging).",
+        default=True,
+    )
+    mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
+        description="The filtering to apply to the detected masks before merging them into a final output.",
+        default="all",
+    )
+    detection_threshold: float = InputField(
+        description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be used.",
+        ge=0.0,
+        le=1.0,
+        default=0.1,
     )

     def invoke(self, context: InvocationContext) -> ImageOutput:
         image_pil = context.images.get_pil(self.image.image_name)

-        detections = self._detect(context=context, image=image_pil, labels=[self.prompt])
-        detections = self._segment(context=context, image=image_pil, detection_results=detections)
-
-        # Extract ouput mask.
-        mask_np = detections[0].mask
-        assert mask_np is not None
+        detections = self._detect(
+            context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
+        )
+
+        if len(detections) == 0:
+            combined_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8)
+        else:
+            detections = self._segment(context=context, image=image_pil, detection_results=detections)
+            detections = self._filter_detections(detections)
+            masks = [detection.mask for detection in detections]
+            combined_mask = self._merge_masks(masks)
+
         # Map [0, 1] to [0, 255].
-        mask_np = mask_np * 255
+        mask_np = combined_mask * 255

         mask_pil = Image.fromarray(mask_np)

         image_dto = context.images.save(image=mask_pil)
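A note on the empty-detection branch above: PIL's Image.size is (width, height), while numpy arrays are indexed (height, width), which is why the shape is reversed with [::-1]. An illustrative check (not part of the commit):

    from PIL import Image
    import numpy as np

    img = Image.new("L", (640, 480))  # PIL size is (width, height)
    empty_mask = np.zeros(img.size[::-1], dtype=np.uint8)
    assert empty_mask.shape == (480, 640)  # numpy shape is (height, width)
    assert empty_mask.shape == np.asarray(img).shape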
@@ -119,6 +135,9 @@ class GroundedSAMInvocation(BaseInvocation):
         threshold: float = 0.3,
     ) -> list[DetectionResult]:
         """Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
+        # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
+        # actually makes a difference.
+        labels = [label if label.endswith(".") else label + "." for label in labels]

         def load_grounding_dino(model_path: Path):
             grounding_dino_pipeline = pipeline(
@@ -135,11 +154,7 @@ class GroundedSAMInvocation(BaseInvocation):
         with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector:
             assert isinstance(detector, GroundingDinoPipeline)

-            # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
-            # actually makes a difference.
-            labels = [label if label.endswith(".") else label + "." for label in labels]
-
-            results = detector(image, candidate_labels=labels, threshold=threshold)
+            results = detector.detect(image=image, candidate_labels=labels, threshold=threshold)
             results = [DetectionResult.from_dict(result) for result in results]
             return results
@@ -172,21 +187,34 @@ class GroundedSAMInvocation(BaseInvocation):
             boxes = self._to_box_array(detection_results)
             masks = sam_pipeline.segment(image=image, boxes=boxes)

-        masks = self._refine_masks(masks)
+        masks = self._to_numpy_masks(masks)
+        masks = self._apply_polygon_refinement(masks)
         for detection_result, mask in zip(detection_results, masks, strict=False):
             detection_result.mask = mask

         return detection_results

-    def _refine_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
+    def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
+        """Convert the tensor output from the Segment Anything model to a list of numpy masks."""
         masks = masks.cpu().float()
         masks = masks.permute(0, 2, 3, 1)
         masks = masks.mean(axis=-1)
         masks = (masks > 0).int()
         masks = masks.numpy().astype(np.uint8)
         masks = list(masks)
+        return masks
+
+    def _apply_polygon_refinement(self, masks: list[npt.NDArray[np.uint8]]) -> list[npt.NDArray[np.uint8]]:
+        """Apply polygon refinement to the masks.
+
+        Convert each mask to a polygon, then back to a mask. This has the following effect:
+        - Smooth the edges of the mask slightly.
+        - Ensure that each mask consists of a single closed polygon:
+            - Removes small mask pieces.
+            - Removes holes from the mask.
+        """
         if self.apply_polygon_refinement:
             for idx, mask in enumerate(masks):
                 shape = mask.shape
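The mask -> polygon -> mask round trip described in the docstring above can be sketched with OpenCV. This is an illustrative stand-in for the repository's actual refinement helpers (which this diff does not show), keeping only the largest external contour:

    import cv2
    import numpy as np
    import numpy.typing as npt

    def refine_mask(mask: npt.NDArray[np.uint8]) -> npt.NDArray[np.uint8]:
        # RETR_EXTERNAL ignores interior contours, which is what removes holes.
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return mask  # nothing to refine; leave the mask unchanged
        # Keeping only the largest contour drops small disconnected pieces.
        largest = max(contours, key=cv2.contourArea)
        refined = np.zeros_like(mask)
        cv2.fillPoly(refined, [largest], color=1)
        return refined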
@@ -195,3 +223,29 @@ class GroundedSAMInvocation(BaseInvocation):
                 masks[idx] = mask

         return masks
+
+    def _filter_detections(self, detections: list[DetectionResult]) -> list[DetectionResult]:
+        """Filter the detected masks based on the specified mask filter."""
+        if self.mask_filter == "all":
+            return detections
+        elif self.mask_filter == "largest":
+            # Find the largest mask.
+            mask_areas = [detection.mask.sum() for detection in detections]
+            largest_mask_idx = mask_areas.index(max(mask_areas))
+            return [detections[largest_mask_idx]]
+        elif self.mask_filter == "highest_box_score":
+            # Find the detection with the highest box score.
+            max_score_detection = detections[0]
+            for detection in detections:
+                if detection.score > max_score_detection.score:
+                    max_score_detection = detection
+            return [max_score_detection]
+        else:
+            raise ValueError(f"Invalid mask filter: {self.mask_filter}")
+
+    def _merge_masks(self, masks: list[npt.NDArray[np.uint8]]) -> npt.NDArray[np.uint8]:
+        """Merge multiple masks into a single mask."""
+        # Merge all masks together.
+        stacked_mask = np.stack(masks, axis=0)
+        combined_mask = np.max(stacked_mask, axis=0)
+        return combined_mask
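To make the new filter-then-merge flow concrete, here is a toy walkthrough (masks invented for illustration):

    import numpy as np

    mask_a = np.array([[1, 1], [1, 0]], dtype=np.uint8)  # pixel area 3
    mask_b = np.array([[0, 0], [0, 1]], dtype=np.uint8)  # pixel area 1

    # mask_filter="all": every mask survives filtering, and _merge_masks takes
    # the element-wise max, i.e. the union of the masks.
    union = np.max(np.stack([mask_a, mask_b], axis=0), axis=0)
    assert union.tolist() == [[1, 1], [1, 1]]

    # mask_filter="largest": only the mask with the greatest pixel area survives,
    # so merging a single mask is a no-op.
    areas = [m.sum() for m in (mask_a, mask_b)]
    assert areas.index(max(areas)) == 0  # mask_a wins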

View File

@@ -1,6 +1,7 @@
 from typing import Optional

 import torch
+from PIL import Image
 from transformers.pipelines import ZeroShotObjectDetectionPipeline
@@ -12,8 +13,8 @@ class GroundingDinoPipeline:
     def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
         self._pipeline = pipeline

-    def __call__(self, *args, **kwargs):
-        return self._pipeline(*args, **kwargs)
+    def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1):
+        return self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)

     def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":
         self._pipeline.model.to(device=device, dtype=dtype)
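Replacing the opaque __call__(*args, **kwargs) passthrough with a typed detect() method is a small API hardening: call sites such as the _detect() change above now get argument names and types checked rather than forwarded blindly. A hypothetical call site (variable names assumed):

    results = detector.detect(image=image_pil, candidate_labels=["a cat."], threshold=0.1)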

View File

@@ -30,6 +30,9 @@ class SegmentAnythingModel:
             masks=outputs.pred_masks,
             original_sizes=inputs.original_sizes,
             reshaped_input_sizes=inputs.reshaped_input_sizes,
-        )[0]
+        )
+        # There should be only one batch.
+        assert len(masks) == 1
+        masks = masks[0]

         return masks
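This change makes the single-image assumption explicit: the SAM processor's post_process_masks in transformers returns a list with one tensor of masks per input image, so the old [0] indexing silently assumed a batch of one. The new assert fails loudly if a multi-image batch ever reaches this wrapper, instead of quietly dropping masks.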