Split GroundedSamInvocation into GroundingDinoInvocation and SegmentAnythingModelInvocation.

This commit is contained in:
Ryan Dick 2024-07-31 12:20:23 -04:00
parent 73386826d6
commit 0193267a53
6 changed files with 180 additions and 93 deletions

View File

@ -242,6 +242,23 @@ class ConditioningField(BaseModel):
) )
class BoundingBoxField(BaseModel):
"""A bounding box primitive value."""
x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
score: Optional[float] = Field(
default=None,
ge=0.0,
le=1.0,
description="The score associated with the bounding box. In the range [0, 1]. This value is typically set "
"when the bounding box was produced by a detector and has an associated confidence score.",
)
class MetadataField(RootModel[dict[str, Any]]): class MetadataField(RootModel[dict[str, Any]]):
""" """
Pydantic model for metadata with custom root of type dict[str, Any]. Pydantic model for metadata with custom root of type dict[str, Any].

View File

@ -0,0 +1,95 @@
from pathlib import Path
import torch
from PIL import Image
from transformers import pipeline
from transformers.pipelines import ZeroShotObjectDetectionPipeline
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
from invokeai.app.invocations.primitives import BoundingBoxCollectionOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult
from invokeai.backend.image_util.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline
GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
@invocation(
"grounding_dino",
title="Grounding DINO (Text Prompt Object Detection)",
tags=["prompt", "object detection"],
category="image",
version="1.0.0",
)
class GroundingDinoInvocation(BaseInvocation):
"""Runs a Grounding DINO model (https://arxiv.org/pdf/2303.05499). Performs zero-shot bounding-box object detection
from a text prompt.
Reference:
- https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
- https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
"""
prompt: str = InputField(description="The prompt describing the object to segment.")
image: ImageField = InputField(description="The image to segment.")
detection_threshold: float = InputField(
description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be returned.",
ge=0.0,
le=1.0,
default=0.3,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> BoundingBoxCollectionOutput:
# The model expects a 3-channel RGB image.
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
detections = self._detect(
context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
)
# Convert detections to BoundingBoxCollectionOutput.
bounding_boxes: list[BoundingBoxField] = []
for detection in detections:
bounding_boxes.append(
BoundingBoxField(
x_min=detection.box.xmin,
x_max=detection.box.xmax,
y_min=detection.box.ymin,
y_max=detection.box.ymax,
score=detection.score,
)
)
return BoundingBoxCollectionOutput(collection=bounding_boxes)
@staticmethod
def _load_grounding_dino(model_path: Path):
grounding_dino_pipeline = pipeline(
model=str(model_path),
task="zero-shot-object-detection",
local_files_only=True,
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
# model, and figure out how to make it work in the pipeline.
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
return GroundingDinoPipeline(grounding_dino_pipeline)
def _detect(
self,
context: InvocationContext,
image: Image.Image,
labels: list[str],
threshold: float = 0.3,
) -> list[DetectionResult]:
"""Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
# TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
# actually makes a difference.
labels = [label if label.endswith(".") else label + "." for label in labels]
with context.models.load_remote_model(
source=GROUNDING_DINO_MODEL_ID, loader=GroundingDinoInvocation._load_grounding_dino
) as detector:
assert isinstance(detector, GroundingDinoPipeline)
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)

View File

@ -7,6 +7,7 @@ import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
BoundingBoxField,
ColorField, ColorField,
ConditioningField, ConditioningField,
DenoiseMaskField, DenoiseMaskField,
@ -469,3 +470,24 @@ class ConditioningCollectionInvocation(BaseInvocation):
# endregion # endregion
# region BoundingBox
@invocation_output("bounding_box_output")
class BoundingBoxOutput(BaseInvocationOutput):
"""Base class for nodes that output a single bounding box"""
bounding_box: BoundingBoxField = OutputField(description="The output bounding box.")
@invocation_output("bounding_box_collection_output")
class BoundingBoxCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of bounding boxes"""
collection: list[BoundingBoxField] = OutputField(
description="The output bounding boxes.",
)
# endregion

View File

@ -5,75 +5,56 @@ import numpy as np
import numpy.typing as npt import numpy.typing as npt
import torch import torch
from PIL import Image from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline from transformers import AutoModelForMaskGeneration, AutoProcessor
from transformers.models.sam import SamModel from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor from transformers.models.sam.processing_sam import SamProcessor
from transformers.pipelines import ZeroShotObjectDetectionPipeline
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult
from invokeai.backend.image_util.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.image_util.grounded_sam.mask_refinement import mask_to_polygon, polygon_to_mask from invokeai.backend.image_util.grounded_sam.mask_refinement import mask_to_polygon, polygon_to_mask
from invokeai.backend.image_util.grounded_sam.segment_anything_model import SegmentAnythingModel from invokeai.backend.image_util.grounded_sam.segment_anything_model import SegmentAnythingModel
GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base" SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base"
@invocation( @invocation(
"grounded_segment_anything", "segment_anything_model",
title="Segment Anything (Text Prompt)", title="Segment Anything Model",
tags=["prompt", "segmentation"], tags=["prompt", "segmentation"],
category="segmentation", category="segmentation",
version="1.0.0", version="1.0.0",
) )
class GroundedSAMInvocation(BaseInvocation): class SegmentAnythingModelInvocation(BaseInvocation):
"""Runs Grounded-SAM, as proposed in https://arxiv.org/pdf/2401.14159. """Runs a Segment Anything Model (https://arxiv.org/pdf/2304.02643).
More specifically, a Grounding DINO model is run to obtain bounding boxes for a text prompt, then the bounding boxes
are passed as a prompt to a Segment Anything model to obtain a segmentation mask.
Reference: Reference:
- https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
- https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
""" """
prompt: str = InputField(description="The prompt describing the object to segment.")
image: ImageField = InputField(description="The image to segment.") image: ImageField = InputField(description="The image to segment.")
bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
apply_polygon_refinement: bool = InputField( apply_polygon_refinement: bool = InputField(
description="Whether to apply polygon refinement to the masks. This will smooth the edges of the mask slightly and ensure that each mask consists of a single closed polygon (before merging).", description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
default=True, default=True,
) )
mask_filter: Literal["all", "largest", "highest_box_score"] = InputField( mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
description="The filtering to apply to the detected masks before merging them into a final output.", description="The filtering to apply to the detected masks before merging them into a final output.",
default="all", default="all",
) )
detection_threshold: float = InputField(
description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be used.",
ge=0.0,
le=1.0,
default=0.3,
)
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
# The models expect a 3-channel RGB image. # The models expect a 3-channel RGB image.
image_pil = context.images.get_pil(self.image.image_name, mode="RGB") image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
detections = self._detect( if len(self.bounding_boxes) == 0:
context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
)
if len(detections) == 0:
combined_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8) combined_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8)
else: else:
detections = self._segment(context=context, image=image_pil, detection_results=detections) masks = self._segment(context=context, image=image_pil)
masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
detections = self._filter_detections(detections)
masks = [detection.mask for detection in detections]
# masks contains binary values of 0 or 1, so we merge them via max-reduce. # masks contains binary values of 0 or 1, so we merge them via max-reduce.
combined_mask = np.maximum.reduce(masks) combined_mask = np.maximum.reduce(masks)
@ -84,19 +65,6 @@ class GroundedSAMInvocation(BaseInvocation):
image_dto = context.images.save(image=mask_pil) image_dto = context.images.save(image=mask_pil)
return ImageOutput.build(image_dto) return ImageOutput.build(image_dto)
@staticmethod
def _load_grounding_dino(model_path: Path):
grounding_dino_pipeline = pipeline(
model=str(model_path),
task="zero-shot-object-detection",
local_files_only=True,
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
# model, and figure out how to make it work in the pipeline.
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
return GroundingDinoPipeline(grounding_dino_pipeline)
@staticmethod @staticmethod
def _load_sam_model(model_path: Path): def _load_sam_model(model_path: Path):
sam_model = AutoModelForMaskGeneration.from_pretrained( sam_model = AutoModelForMaskGeneration.from_pretrained(
@ -112,47 +80,28 @@ class GroundedSAMInvocation(BaseInvocation):
assert isinstance(sam_processor, SamProcessor) assert isinstance(sam_processor, SamProcessor)
return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor) return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
def _detect(
self,
context: InvocationContext,
image: Image.Image,
labels: list[str],
threshold: float = 0.3,
) -> list[DetectionResult]:
"""Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
# TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
# actually makes a difference.
labels = [label if label.endswith(".") else label + "." for label in labels]
with context.models.load_remote_model(
source=GROUNDING_DINO_MODEL_ID, loader=GroundedSAMInvocation._load_grounding_dino
) as detector:
assert isinstance(detector, GroundingDinoPipeline)
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
def _segment( def _segment(
self, self,
context: InvocationContext, context: InvocationContext,
image: Image.Image, image: Image.Image,
detection_results: list[DetectionResult], ) -> list[npt.NDArray[np.uint8]]:
) -> list[DetectionResult]:
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.""" """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
# Convert the bounding boxes to the SAM input format.
sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
with ( with (
context.models.load_remote_model( context.models.load_remote_model(
source=SEGMENT_ANYTHING_MODEL_ID, loader=GroundedSAMInvocation._load_sam_model source=SEGMENT_ANYTHING_MODEL_ID, loader=SegmentAnythingModelInvocation._load_sam_model
) as sam_pipeline, ) as sam_pipeline,
): ):
assert isinstance(sam_pipeline, SegmentAnythingModel) assert isinstance(sam_pipeline, SegmentAnythingModel)
masks = sam_pipeline.segment(image=image, detection_results=detection_results) masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
masks = self._to_numpy_masks(masks) masks = self._to_numpy_masks(masks)
if self.apply_polygon_refinement: if self.apply_polygon_refinement:
masks = self._apply_polygon_refinement(masks) masks = self._apply_polygon_refinement(masks)
for detection_result, mask in zip(detection_results, masks, strict=True): return masks
detection_result.mask = mask
return detection_results
def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]: def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
"""Convert the tensor output from the Segment Anything model to a list of numpy masks.""" """Convert the tensor output from the Segment Anything model to a list of numpy masks."""
@ -181,15 +130,23 @@ class GroundedSAMInvocation(BaseInvocation):
return masks return masks
def _filter_detections(self, detections: list[DetectionResult]) -> list[DetectionResult]: def _filter_masks(
self, masks: list[npt.NDArray[np.uint8]], bounding_boxes: list[BoundingBoxField]
) -> list[npt.NDArray[np.uint8]]:
"""Filter the detected masks based on the specified mask filter.""" """Filter the detected masks based on the specified mask filter."""
assert len(masks) == len(bounding_boxes)
if self.mask_filter == "all": if self.mask_filter == "all":
return detections return masks
elif self.mask_filter == "largest": elif self.mask_filter == "largest":
# Find the largest mask. # Find the largest mask.
return [max(detections, key=lambda x: x.mask.sum())] return [max(masks, key=lambda x: x.sum())]
elif self.mask_filter == "highest_box_score": elif self.mask_filter == "highest_box_score":
# Find the detection with the highest box score. # Find the index of the bounding box with the highest score.
return [max(detections, key=lambda x: x.score)] # Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
# cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
# reasonable fallback since the expected score range is [0.0, 1.0].
max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
return [masks[max_score_idx]]
else: else:
raise ValueError(f"Invalid mask filter: {self.mask_filter}") raise ValueError(f"Invalid mask filter: {self.mask_filter}")

View File

@ -1,6 +1,3 @@
from typing import Any, Optional
import numpy.typing as npt
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@ -12,18 +9,13 @@ class BoundingBox(BaseModel):
xmax: int xmax: int
ymax: int ymax: int
def to_box(self) -> list[int]:
"""Convert to the array notation expected by SAM."""
return [self.xmin, self.ymin, self.xmax, self.ymax]
class DetectionResult(BaseModel): class DetectionResult(BaseModel):
"""Detection result from Grounding DINO or Grounded SAM.""" """Detection result from Grounding DINO."""
score: float score: float
label: str label: str
box: BoundingBox box: BoundingBox
mask: Optional[npt.NDArray[Any]] = None
model_config = ConfigDict( model_config = ConfigDict(
# Allow arbitrary types for mask, since it will be a numpy array. # Allow arbitrary types for mask, since it will be a numpy array.
arbitrary_types_allowed=True arbitrary_types_allowed=True

View File

@ -5,7 +5,6 @@ from PIL import Image
from transformers.models.sam import SamModel from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor from transformers.models.sam.processing_sam import SamProcessor
from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult
from invokeai.backend.raw_model import RawModel from invokeai.backend.raw_model import RawModel
@ -28,8 +27,19 @@ class SegmentAnythingModel(RawModel):
return calc_module_size(self._sam_model) return calc_module_size(self._sam_model)
def segment(self, image: Image.Image, detection_results: list[DetectionResult]) -> torch.Tensor: def segment(self, image: Image.Image, bounding_boxes: list[list[int]]) -> torch.Tensor:
boxes = self._to_box_array(detection_results) """Run the SAM model.
Args:
image (Image.Image): The image to segment.
bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
[xmin, ymin, xmax, ymax].
Returns:
torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
"""
# Add batch dimension of 1 to the bounding boxes.
boxes = [bounding_boxes]
inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device) inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device)
outputs = self._sam_model(**inputs) outputs = self._sam_model(**inputs)
masks = self._sam_processor.post_process_masks( masks = self._sam_processor.post_process_masks(
@ -40,10 +50,4 @@ class SegmentAnythingModel(RawModel):
# There should be only one batch. # There should be only one batch.
assert len(masks) == 1 assert len(masks) == 1
masks = masks[0] return masks[0]
return masks
def _to_box_array(self, detection_results: list[DetectionResult]) -> list[list[list[int]]]:
"""Convert a list of DetectionResults to the bbox format expected by the Segment Anything model."""
boxes = [result.box.to_box() for result in detection_results]
return [boxes]