diff --git a/invokeai/app/invocations/grounded_sam.py b/invokeai/app/invocations/grounded_sam.py
index 5bfa5079e9..211016fac4 100644
--- a/invokeai/app/invocations/grounded_sam.py
+++ b/invokeai/app/invocations/grounded_sam.py
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Literal, Optional
 
 import numpy as np
 import numpy.typing as npt
@@ -70,8 +70,8 @@ class DetectionResult:
 class GroundedSAMInvocation(BaseInvocation):
     """Runs Grounded-SAM, as proposed in https://arxiv.org/pdf/2401.14159.
 
-    More specifically, a Grounding DINO model is run to obtain bounding boxes for a text prompt, then the bounding box
-    is passed as a prompt to a Segment Anything model to obtain a segmentation mask.
+    More specifically, a Grounding DINO model is run to obtain bounding boxes for a text prompt, then the bounding boxes
+    are passed as a prompt to a Segment Anything model to obtain a segmentation mask.
 
     Reference:
     - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
@@ -81,22 +81,38 @@ class GroundedSAMInvocation(BaseInvocation):
     prompt: str = InputField(description="The prompt describing the object to segment.")
     image: ImageField = InputField(description="The image to segment.")
     apply_polygon_refinement: bool = InputField(
-        description="Whether to apply polygon refinement to the mask. This will smooth the edges of the mask slightly "
-        "and ensure that the mask consists of a single closed polygon.",
-        default=False,
+        description="Whether to apply polygon refinement to the masks. This will smooth the edges of the mask slightly "
+        "and ensure that each mask consists of a single closed polygon (before merging).",
+        default=True,
+    )
+    mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
+        description="The filtering to apply to the detected masks before merging them into a final output.",
+        default="all",
+    )
+    detection_threshold: float = InputField(
+        description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above "
+        "this threshold will be used.",
+        ge=0.0,
+        le=1.0,
+        default=0.1,
     )
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image_pil = context.images.get_pil(self.image.image_name)
 
-        detections = self._detect(context=context, image=image_pil, labels=[self.prompt])
-        detections = self._segment(context=context, image=image_pil, detection_results=detections)
+        detections = self._detect(
+            context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
+        )
+
+        if len(detections) == 0:
+            combined_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8)
+        else:
+            detections = self._segment(context=context, image=image_pil, detection_results=detections)
+
+            detections = self._filter_detections(detections)
+            masks = [detection.mask for detection in detections]
+            combined_mask = self._merge_masks(masks)
 
-        # Extract ouput mask.
-        mask_np = detections[0].mask
-        assert mask_np is not None
-
         # Map [0, 1] to [0, 255].
-        mask_np = mask_np * 255
+        mask_np = combined_mask * 255
 
         mask_pil = Image.fromarray(mask_np)
 
         image_dto = context.images.save(image=mask_pil)
@@ -119,6 +135,9 @@ class GroundedSAMInvocation(BaseInvocation):
         threshold: float = 0.3,
     ) -> list[DetectionResult]:
         """Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
+        # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
+        # actually makes a difference.
+        labels = [label if label.endswith(".") else label + "." for label in labels]
 
         def load_grounding_dino(model_path: Path):
             grounding_dino_pipeline = pipeline(
@@ -135,11 +154,7 @@ class GroundedSAMInvocation(BaseInvocation):
 
         with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector:
             assert isinstance(detector, GroundingDinoPipeline)
-            # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
-            # actually makes a difference.
-            labels = [label if label.endswith(".") else label + "." for label in labels]
-
-            results = detector(image, candidate_labels=labels, threshold=threshold)
+            results = detector.detect(image=image, candidate_labels=labels, threshold=threshold)
             results = [DetectionResult.from_dict(result) for result in results]
             return results
@@ -172,21 +187,34 @@ class GroundedSAMInvocation(BaseInvocation):
             boxes = self._to_box_array(detection_results)
             masks = sam_pipeline.segment(image=image, boxes=boxes)
 
-        masks = self._refine_masks(masks)
-        for detection_result, mask in zip(detection_results, masks, strict=False):
-            detection_result.mask = mask
+        masks = self._to_numpy_masks(masks)
+        masks = self._apply_polygon_refinement(masks)
 
-        return detection_results
+        for detection_result, mask in zip(detection_results, masks, strict=False):
+            detection_result.mask = mask
 
-    def _refine_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
+        return detection_results
+
+    def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
+        """Convert the tensor output from the Segment Anything model to a list of numpy masks."""
         masks = masks.cpu().float()
         masks = masks.permute(0, 2, 3, 1)
         masks = masks.mean(axis=-1)
         masks = (masks > 0).int()
         masks = masks.numpy().astype(np.uint8)
         masks = list(masks)
+        return masks
 
+    def _apply_polygon_refinement(self, masks: list[npt.NDArray[np.uint8]]) -> list[npt.NDArray[np.uint8]]:
+        """Apply polygon refinement to the masks.
+
+        Convert each mask to a polygon, then back to a mask. This has the following effects:
+        - Smooths the edges of the mask slightly.
+        - Ensures that each mask consists of a single closed polygon.
+        - Removes small mask pieces.
+        - Removes holes from the mask.
+        """
         if self.apply_polygon_refinement:
             for idx, mask in enumerate(masks):
                 shape = mask.shape
@@ -195,3 +223,29 @@ class GroundedSAMInvocation(BaseInvocation):
                 masks[idx] = mask
 
         return masks
+
+    def _filter_detections(self, detections: list[DetectionResult]) -> list[DetectionResult]:
+        """Filter the detected masks based on the specified mask filter."""
+        if self.mask_filter == "all":
+            return detections
+        elif self.mask_filter == "largest":
+            # Find the largest mask.
+            mask_areas = [detection.mask.sum() for detection in detections]
+            largest_mask_idx = mask_areas.index(max(mask_areas))
+            return [detections[largest_mask_idx]]
+        elif self.mask_filter == "highest_box_score":
+            # Find the detection with the highest box score.
+            max_score_detection = detections[0]
+            for detection in detections:
+                if detection.score > max_score_detection.score:
+                    max_score_detection = detection
+            return [max_score_detection]
+        else:
+            raise ValueError(f"Invalid mask filter: {self.mask_filter}")
+
+    def _merge_masks(self, masks: list[npt.NDArray[np.uint8]]) -> npt.NDArray[np.uint8]:
+        """Merge multiple masks into a single mask."""
+        # Take the element-wise max, so a pixel is set if it is set in any mask.
+        stacked_mask = np.stack(masks, axis=0)
+        combined_mask = np.max(stacked_mask, axis=0)
+        return combined_mask
diff --git a/invokeai/backend/grounded_sam/grounding_dino_pipeline.py b/invokeai/backend/grounded_sam/grounding_dino_pipeline.py
index d6c1f6a377..0143e300bc 100644
--- a/invokeai/backend/grounded_sam/grounding_dino_pipeline.py
+++ b/invokeai/backend/grounded_sam/grounding_dino_pipeline.py
@@ -1,6 +1,7 @@
 from typing import Optional
 
 import torch
+from PIL import Image
 from transformers.pipelines import ZeroShotObjectDetectionPipeline
 
 
@@ -12,8 +13,8 @@ class GroundingDinoPipeline:
     def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
         self._pipeline = pipeline
 
-    def __call__(self, *args, **kwargs):
-        return self._pipeline(*args, **kwargs)
+    def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1):
+        return self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
 
     def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":
         self._pipeline.model.to(device=device, dtype=dtype)
diff --git a/invokeai/backend/grounded_sam/segment_anything_model.py b/invokeai/backend/grounded_sam/segment_anything_model.py
index 5072d1bc35..1c5b1cf456 100644
--- a/invokeai/backend/grounded_sam/segment_anything_model.py
+++ b/invokeai/backend/grounded_sam/segment_anything_model.py
@@ -30,6 +30,9 @@ class SegmentAnythingModel:
             masks=outputs.pred_masks,
             original_sizes=inputs.original_sizes,
             reshaped_input_sizes=inputs.reshaped_input_sizes,
-        )[0]
+        )
+        # There should be only one batch.
+        assert len(masks) == 1
+        masks = masks[0]
 
         return masks
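
For reviewers: a minimal sketch of how the two non-trivial mask_filter modes added above can disagree. Det is a hypothetical stand-in for DetectionResult (not part of this diff), just enough to exercise the selection logic in _filter_detections().

    from dataclasses import dataclass

    import numpy as np
    import numpy.typing as npt


    # Hypothetical stand-in for DetectionResult (score + mask only).
    @dataclass
    class Det:
        score: float
        mask: npt.NDArray[np.uint8]


    dets = [
        Det(score=0.9, mask=np.ones((2, 2), dtype=np.uint8)),  # small mask, high score
        Det(score=0.4, mask=np.ones((4, 4), dtype=np.uint8)),  # large mask, low score
    ]

    # "largest": keep the mask covering the most pixels, as in _filter_detections().
    areas = [d.mask.sum() for d in dets]
    largest = [dets[areas.index(max(areas))]]

    # "highest_box_score": keep the detection with the best Grounding DINO box score.
    best = [max(dets, key=lambda d: d.score)]

    print(largest[0].score, best[0].score)  # 0.4 0.9 -- different detections win

The two modes select different detections here, which is why filtering runs as its own step before merging rather than being folded into the merge.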
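
And a self-contained sketch of the merge-and-scale behavior in _merge_masks() and invoke(), using two hypothetical 4x4 binary masks:

    import numpy as np

    # Two hypothetical binary masks, as produced per detection by _to_numpy_masks().
    mask_a = np.zeros((4, 4), dtype=np.uint8)
    mask_a[0:2, 0:2] = 1
    mask_b = np.zeros((4, 4), dtype=np.uint8)
    mask_b[2:4, 2:4] = 1

    # Mirrors _merge_masks(): stack along a new axis and take the element-wise max,
    # so a pixel is set in the combined mask if it is set in any input mask (a union).
    combined = np.max(np.stack([mask_a, mask_b], axis=0), axis=0)

    # As in invoke(): map [0, 1] to [0, 255] before saving as a grayscale image.
    print(combined * 255)

For binary masks the element-wise max is equivalent to a logical OR; using np.max keeps the uint8 dtype and would also generalize to soft masks.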