Add a GroundedSamInvocation for image segmentation from a text prompt (Grounding DINO + Segment Anything Model).

This commit is contained in:
Ryan Dick 2024-07-29 13:53:14 -04:00
parent 2ad13ac7eb
commit ff6398f7d8
6 changed files with 322 additions and 1 deletions

View File

@ -0,0 +1,197 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional
import numpy as np
import numpy.typing as npt
import torch
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
from transformers.pipelines import ZeroShotObjectDetectionPipeline
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.grounded_sam.mask_refinement import mask_to_polygon, polygon_to_mask
from invokeai.backend.grounded_sam.segment_anything_model import SegmentAnythingModel
GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base"
@dataclass
class BoundingBox:
"""Bounding box helper class used locally for the Grounding DINO outputs."""
xmin: int
ymin: int
xmax: int
ymax: int
def to_box(self) -> list[int]:
"""Convert to the array notation expected by SAM."""
return [self.xmin, self.ymin, self.xmax, self.ymax]
@dataclass
class DetectionResult:
"""Detection result from Grounding DINO or Grounded SAM."""
score: float
label: str
box: BoundingBox
mask: Optional[npt.NDArray[Any]] = None
@classmethod
def from_dict(cls, detection_dict: dict[str, Any]):
return cls(
score=detection_dict["score"],
label=detection_dict["label"],
box=BoundingBox(
xmin=detection_dict["box"]["xmin"],
ymin=detection_dict["box"]["ymin"],
xmax=detection_dict["box"]["xmax"],
ymax=detection_dict["box"]["ymax"],
),
)
@invocation(
"grounded_segment_anything",
title="Segment Anything (Text Prompt)",
tags=["prompt", "segmentation"],
category="segmentation",
version="1.0.0",
)
class GroundedSAMInvocation(BaseInvocation):
"""Runs Grounded-SAM, as proposed in https://arxiv.org/pdf/2401.14159.
More specifically, a Grounding DINO model is run to obtain bounding boxes for a text prompt, then the bounding box
is passed as a prompt to a Segment Anything model to obtain a segmentation mask.
Reference:
- https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
- https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
"""
prompt: str = InputField(description="The prompt describing the object to segment.")
image: ImageField = InputField(description="The image to segment.")
apply_polygon_refinement: bool = InputField(
description="Whether to apply polygon refinement to the mask. This will smooth the edges of the mask slightly "
"and ensure that the mask consists of a single closed polygon.",
default=False,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image_pil = context.images.get_pil(self.image.image_name)
detections = self._detect(context=context, image=image_pil, labels=[self.prompt])
detections = self._segment(context=context, image=image_pil, detection_results=detections)
# Extract ouput mask.
mask_np = detections[0].mask
assert mask_np is not None
# Map [0, 1] to [0, 255].
mask_np = mask_np * 255
mask_pil = Image.fromarray(mask_np)
image_dto = context.images.save(image=mask_pil)
return ImageOutput.build(image_dto)
def _to_box_array(self, detection_results: list[DetectionResult]) -> list[list[list[int]]]:
"""Convert a list of DetectionResults to the format expected by the Segment Anything model.
Args:
detection_results (list[DetectionResult]): The Grounding DINO detection results.
"""
boxes = [result.box.to_box() for result in detection_results]
return [boxes]
def _detect(
self,
context: InvocationContext,
image: Image.Image,
labels: list[str],
threshold: float = 0.3,
) -> list[DetectionResult]:
"""Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
def load_grounding_dino(model_path: Path):
grounding_dino_pipeline = pipeline(
model=str(model_path),
task="zero-shot-object-detection",
local_files_only=True,
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
# model, and figure out how to make it work in the pipeline.
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
return GroundingDinoPipeline(grounding_dino_pipeline)
with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector:
assert isinstance(detector, GroundingDinoPipeline)
# TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
# actually makes a difference.
labels = [label if label.endswith(".") else label + "." for label in labels]
results = detector(image, candidate_labels=labels, threshold=threshold)
results = [DetectionResult.from_dict(result) for result in results]
return results
def _segment(
self,
context: InvocationContext,
image: Image.Image,
detection_results: list[DetectionResult],
) -> list[DetectionResult]:
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
def load_sam_model(model_path: Path):
sam_model = AutoModelForMaskGeneration.from_pretrained(
model_path,
local_files_only=True,
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
# model, and figure out how to make it work in the pipeline.
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
assert isinstance(sam_model, SamModel)
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
assert isinstance(sam_processor, SamProcessor)
return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
with (
context.models.load_remote_model(source=SEGMENT_ANYTHING_MODEL_ID, loader=load_sam_model) as sam_pipeline,
):
assert isinstance(sam_pipeline, SegmentAnythingModel)
boxes = self._to_box_array(detection_results)
masks = sam_pipeline.segment(image=image, boxes=boxes)
masks = self._refine_masks(masks)
for detection_result, mask in zip(detection_results, masks, strict=False):
detection_result.mask = mask
return detection_results
def _refine_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
masks = masks.cpu().float()
masks = masks.permute(0, 2, 3, 1)
masks = masks.mean(axis=-1)
masks = (masks > 0).int()
masks = masks.numpy().astype(np.uint8)
masks = list(masks)
if self.apply_polygon_refinement:
for idx, mask in enumerate(masks):
shape = mask.shape
polygon = mask_to_polygon(mask)
mask = polygon_to_mask(polygon, shape)
masks[idx] = mask
return masks

View File

@ -0,0 +1,27 @@
from typing import Optional
import torch
from transformers.pipelines import ZeroShotObjectDetectionPipeline
class GroundingDinoPipeline:
"""A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory
management system.
"""
def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
self._pipeline = pipeline
def __call__(self, *args, **kwargs):
return self._pipeline(*args, **kwargs)
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":
self._pipeline.model.to(device=device, dtype=dtype)
self._pipeline.device = self._pipeline.model.device
return self
def calc_size(self) -> int:
# HACK(ryand): Fix the circular import issue.
from invokeai.backend.model_manager.load.model_util import calc_module_size
return calc_module_size(self._pipeline.model)

View File

@ -0,0 +1,50 @@
# This file contains utilities for Grounded-SAM mask refinement based on:
# https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
import cv2
import numpy as np
import numpy.typing as npt
def mask_to_polygon(mask: npt.NDArray[np.uint8]) -> list[tuple[int, int]]:
"""Convert a binary mask to a polygon.
Returns:
list[list[int]]: List of (x, y) coordinates representing the vertices of the polygon.
"""
# Find contours in the binary mask.
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# Find the contour with the largest area.
largest_contour = max(contours, key=cv2.contourArea)
# Extract the vertices of the contour.
polygon = largest_contour.reshape(-1, 2).tolist()
return polygon
def polygon_to_mask(
polygon: list[tuple[int, int]], image_shape: tuple[int, int], fill_value: int = 1
) -> npt.NDArray[np.uint8]:
"""Convert a polygon to a segmentation mask.
Args:
polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
image_shape (tuple): Shape of the image (height, width) for the mask.
fill_value (int): Value to fill the polygon with.
Returns:
np.ndarray: Segmentation mask with the polygon filled (with value 255).
"""
# Create an empty mask.
mask = np.zeros(image_shape, dtype=np.uint8)
# Convert polygon to an array of points.
pts = np.array(polygon, dtype=np.int32)
# Fill the polygon with white color (255).
cv2.fillPoly(mask, [pts], color=(fill_value,))
return mask

View File

@ -0,0 +1,35 @@
from typing import Optional
import torch
from PIL import Image
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
class SegmentAnythingModel:
"""A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
def __init__(self, sam_model: SamModel, sam_processor: SamProcessor):
self._sam_model = sam_model
self._sam_processor = sam_processor
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "SegmentAnythingModel":
self._sam_model.to(device=device, dtype=dtype)
return self
def calc_size(self) -> int:
# HACK(ryand): Fix the circular import issue.
from invokeai.backend.model_manager.load.model_util import calc_module_size
return calc_module_size(self._sam_model)
def segment(self, image: Image.Image, boxes: list[list[list[int]]]) -> torch.Tensor:
inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device)
outputs = self._sam_model(**inputs)
masks = self._sam_processor.post_process_masks(
masks=outputs.pred_masks,
original_sizes=inputs.original_sizes,
reshaped_input_sizes=inputs.reshaped_input_sizes,
)[0]
return masks

View File

@ -11,6 +11,8 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SchedulerMixin from diffusers.schedulers.scheduling_utils import SchedulerMixin
from transformers import CLIPTokenizer from transformers import CLIPTokenizer
from invokeai.backend.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.grounded_sam.segment_anything_model import SegmentAnythingModel
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager.config import AnyModel from invokeai.backend.model_manager.config import AnyModel
@ -34,7 +36,17 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
elif isinstance(model, CLIPTokenizer): elif isinstance(model, CLIPTokenizer):
# TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now. # TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
return 0 return 0
elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel)): elif isinstance(
model,
(
TextualInversionModelRaw,
IPAdapter,
LoRAModelRaw,
SpandrelImageToImageModel,
GroundingDinoPipeline,
SegmentAnythingModel,
),
):
return model.calc_size() return model.calc_size()
else: else:
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the