Add mask_filter and detection_threshold options to the GroundedSAMInvocation.
commit aca2a2fa13 (parent ff6398f7d8)
```diff
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Literal, Optional
 
 import numpy as np
 import numpy.typing as npt
```
```diff
@@ -70,8 +70,8 @@
 class GroundedSAMInvocation(BaseInvocation):
     """Runs Grounded-SAM, as proposed in https://arxiv.org/pdf/2401.14159.
 
-    More specifically, a Grounding DINO model is run to obtain bounding boxes for a text prompt, then the bounding box
-    is passed as a prompt to a Segment Anything model to obtain a segmentation mask.
+    More specifically, a Grounding DINO model is run to obtain bounding boxes for a text prompt, then the bounding boxes
+    are passed as a prompt to a Segment Anything model to obtain a segmentation mask.
 
     Reference:
     - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
```
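For context, the two-stage flow this docstring describes can be sketched end to end with the transformers APIs from the linked reference. This is a condensed illustration, not InvokeAI's code: the model IDs are the public examples from the Hugging Face docs (not necessarily what InvokeAI downloads), and it assumes at least one detection comes back.

```python
import torch
from PIL import Image
from transformers import SamModel, SamProcessor, pipeline

image = Image.open("example.jpg")  # hypothetical input image

# Stage 1: Grounding DINO turns a text prompt into scored bounding boxes.
detector = pipeline("zero-shot-object-detection", model="IDEA-Research/grounding-dino-tiny")
detections = detector(image, candidate_labels=["a cat."], threshold=0.3)
# Each result's "box" dict is ordered xmin, ymin, xmax, ymax.
boxes = [[list(d["box"].values()) for d in detections]]

# Stage 2: the boxes prompt a Segment Anything model to produce masks.
sam = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
inputs = processor(image, input_boxes=boxes, return_tensors="pt")
with torch.no_grad():
    outputs = sam(**inputs)
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
)
```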
```diff
@@ -81,22 +81,38 @@ class GroundedSAMInvocation(BaseInvocation):
     prompt: str = InputField(description="The prompt describing the object to segment.")
     image: ImageField = InputField(description="The image to segment.")
     apply_polygon_refinement: bool = InputField(
-        description="Whether to apply polygon refinement to the mask. This will smooth the edges of the mask slightly "
-        "and ensure that the mask consists of a single closed polygon.",
-        default=False,
+        description="Whether to apply polygon refinement to the masks. This will smooth the edges of the mask slightly and ensure that each mask consists of a single closed polygon (before merging).",
+        default=True,
+    )
+    mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
+        description="The filtering to apply to the detected masks before merging them into a final output.",
+        default="all",
+    )
+    detection_threshold: float = InputField(
+        description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be used.",
+        ge=0.0,
+        le=1.0,
+        default=0.1,
     )
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image_pil = context.images.get_pil(self.image.image_name)
 
-        detections = self._detect(context=context, image=image_pil, labels=[self.prompt])
+        detections = self._detect(
+            context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
+        )
 
-        detections = self._segment(context=context, image=image_pil, detection_results=detections)
+        if len(detections) == 0:
+            combined_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8)
+        else:
+            detections = self._segment(context=context, image=image_pil, detection_results=detections)
 
-        # Extract ouput mask.
-        mask_np = detections[0].mask
-        assert mask_np is not None
+            detections = self._filter_detections(detections)
+            masks = [detection.mask for detection in detections]
+            combined_mask = self._merge_masks(masks)
 
         # Map [0, 1] to [0, 255].
-        mask_np = mask_np * 255
+        mask_np = combined_mask * 255
         mask_pil = Image.fromarray(mask_np)
 
         image_dto = context.images.save(image=mask_pil)
```
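A note on the new fields: `InputField` is InvokeAI's wrapper around pydantic's `Field`, so the `Literal` type and the `ge`/`le` bounds are enforced at validation time. A minimal standalone sketch of the same constraint pattern using plain pydantic (the `GroundedSAMOptions` model is a hypothetical stand-in, not InvokeAI code):

```python
from typing import Literal

from pydantic import BaseModel, Field, ValidationError


class GroundedSAMOptions(BaseModel):  # hypothetical stand-in for the node's fields
    mask_filter: Literal["all", "largest", "highest_box_score"] = "all"
    detection_threshold: float = Field(default=0.1, ge=0.0, le=1.0)


GroundedSAMOptions(mask_filter="largest", detection_threshold=0.3)  # accepted

try:
    GroundedSAMOptions(detection_threshold=1.5)  # rejected: violates le=1.0
except ValidationError as e:
    print(e)
```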
```diff
@@ -119,6 +135,9 @@ class GroundedSAMInvocation(BaseInvocation):
         threshold: float = 0.3,
     ) -> list[DetectionResult]:
         """Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
+        # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
+        # actually makes a difference.
+        labels = [label if label.endswith(".") else label + "." for label in labels]
 
         def load_grounding_dino(model_path: Path):
             grounding_dino_pipeline = pipeline(
```
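The moved "."-handling is a one-liner worth seeing in isolation: the transformers example code for Grounding DINO appends a trailing period to each text label, so labels are normalized before detection.

```python
labels = ["a cat", "a dog."]
labels = [label if label.endswith(".") else label + "." for label in labels]
assert labels == ["a cat.", "a dog."]
```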
```diff
@@ -135,11 +154,7 @@ class GroundedSAMInvocation(BaseInvocation):
         with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector:
             assert isinstance(detector, GroundingDinoPipeline)
-
-            # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
-            # actually makes a difference.
-            labels = [label if label.endswith(".") else label + "." for label in labels]
-
-            results = detector(image, candidate_labels=labels, threshold=threshold)
+            results = detector.detect(image=image, candidate_labels=labels, threshold=threshold)
             results = [DetectionResult.from_dict(result) for result in results]
             return results
 
```
```diff
@@ -172,21 +187,34 @@ class GroundedSAMInvocation(BaseInvocation):
 
             boxes = self._to_box_array(detection_results)
             masks = sam_pipeline.segment(image=image, boxes=boxes)
-            masks = self._refine_masks(masks)
+
+        masks = self._to_numpy_masks(masks)
+        masks = self._apply_polygon_refinement(masks)
 
         for detection_result, mask in zip(detection_results, masks, strict=False):
             detection_result.mask = mask
 
         return detection_results
 
-    def _refine_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
+    def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
+        """Convert the tensor output from the Segment Anything model to a list of numpy masks."""
         masks = masks.cpu().float()
         masks = masks.permute(0, 2, 3, 1)
         masks = masks.mean(axis=-1)
         masks = (masks > 0).int()
         masks = masks.numpy().astype(np.uint8)
         masks = list(masks)
-
+        return masks
+
+    def _apply_polygon_refinement(self, masks: list[npt.NDArray[np.uint8]]) -> list[npt.NDArray[np.uint8]]:
+        """Apply polygon refinement to the masks.
+
+        Convert each mask to a polygon, then back to a mask. This has the following effect:
+        - Smooth the edges of the mask slightly.
+        - Ensure that each mask consists of a single closed polygon
+        - Removes small mask pieces.
+        - Removes holes from the mask.
+        """
         if self.apply_polygon_refinement:
             for idx, mask in enumerate(masks):
                 shape = mask.shape
```
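The shape gymnastics in `_to_numpy_masks` are easier to follow on a dummy tensor. A sketch, assuming the Segment Anything output is shaped `[num_masks, channels, height, width]` as the permute implies:

```python
import numpy as np
import torch

masks = torch.randn(2, 3, 4, 4)     # 2 masks, 3 channels, 4x4 pixels
masks = masks.cpu().float()
masks = masks.permute(0, 2, 3, 1)   # -> [2, 4, 4, 3], channels last
masks = masks.mean(axis=-1)         # average the channels -> [2, 4, 4]
masks = (masks > 0).int()           # binarize to {0, 1}
masks = masks.numpy().astype(np.uint8)
masks = list(masks)                 # list of 2 arrays, each 4x4

assert masks[0].shape == (4, 4) and masks[0].dtype == np.uint8
```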
```diff
@@ -195,3 +223,29 @@ class GroundedSAMInvocation(BaseInvocation):
                 masks[idx] = mask
 
         return masks
+
+    def _filter_detections(self, detections: list[DetectionResult]) -> list[DetectionResult]:
+        """Filter the detected masks based on the specified mask filter."""
+        if self.mask_filter == "all":
+            return detections
+        elif self.mask_filter == "largest":
+            # Find the largest mask.
+            mask_areas = [detection.mask.sum() for detection in detections]
+            largest_mask_idx = mask_areas.index(max(mask_areas))
+            return [detections[largest_mask_idx]]
+        elif self.mask_filter == "highest_box_score":
+            # Find the detection with the highest box score.
+            max_score_detection = detections[0]
+            for detection in detections:
+                if detection.score > max_score_detection.score:
+                    max_score_detection = detection
+            return [max_score_detection]
+        else:
+            raise ValueError(f"Invalid mask filter: {self.mask_filter}")
+
+    def _merge_masks(self, masks: list[npt.NDArray[np.uint8]]) -> npt.NDArray[np.uint8]:
+        """Merge multiple masks into a single mask."""
+        # Merge all masks together.
+        stacked_mask = np.stack(masks, axis=0)
+        combined_mask = np.max(stacked_mask, axis=0)
+        return combined_mask
```
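`_merge_masks` is a pixelwise union of binary masks. In miniature, assuming uint8 masks of equal shape:

```python
import numpy as np

mask_a = np.array([[1, 0], [0, 0]], dtype=np.uint8)
mask_b = np.array([[0, 0], [0, 1]], dtype=np.uint8)

stacked_mask = np.stack([mask_a, mask_b], axis=0)  # -> [2, H, W]
combined_mask = np.max(stacked_mask, axis=0)       # pixelwise max == union

assert (combined_mask == np.array([[1, 0], [0, 1]], dtype=np.uint8)).all()
```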
```diff
@@ -1,6 +1,7 @@
 from typing import Optional
 
 import torch
+from PIL import Image
 from transformers.pipelines import ZeroShotObjectDetectionPipeline
 
 
```
```diff
@@ -12,8 +13,8 @@ class GroundingDinoPipeline:
     def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
         self._pipeline = pipeline
 
-    def __call__(self, *args, **kwargs):
-        return self._pipeline(*args, **kwargs)
+    def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1):
+        return self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
 
     def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":
         self._pipeline.model.to(device=device, dtype=dtype)
```
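This interface change swaps an opaque `__call__(*args, **kwargs)` passthrough for an explicit `detect` signature. A toy illustration of the pattern (both classes are hypothetical stand-ins, not the real transformers pipeline):

```python
class _ToyPipeline:
    """Hypothetical stand-in for ZeroShotObjectDetectionPipeline."""

    def __call__(self, image, candidate_labels: list[str], threshold: float) -> list[dict]:
        return [{"label": label, "score": threshold} for label in candidate_labels]


class ToyGroundingDino:
    def __init__(self, pipeline: _ToyPipeline):
        self._pipeline = pipeline

    def detect(self, image, candidate_labels: list[str], threshold: float = 0.1) -> list[dict]:
        # Explicit keywords: typos and wrong types surface at the wrapper,
        # not deep inside the wrapped pipeline.
        return self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)


print(ToyGroundingDino(_ToyPipeline()).detect(image=None, candidate_labels=["a cat."]))
```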
```diff
@@ -30,6 +30,9 @@ class SegmentAnythingModel:
             masks=outputs.pred_masks,
             original_sizes=inputs.original_sizes,
             reshaped_input_sizes=inputs.reshaped_input_sizes,
-        )[0]
+        )
+
+        # There should be only one batch.
+        assert len(masks) == 1
+        masks = masks[0]
         return masks
```
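The `)[0]` change makes an implicit assumption explicit: the post-processing returns one entry per batch item, and this code always runs a single image. In miniature:

```python
masks = [["mask_for_box_0", "mask_for_box_1"]]  # one batch containing two masks

# There should be only one batch.
assert len(masks) == 1
masks = masks[0]  # unwrap -> the per-mask list for the single image
```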