From ff6398f7d8efb62f8a74ed344727b00d5f5d422a Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 29 Jul 2024 13:53:14 -0400 Subject: [PATCH] Add a GroundedSamInvocation for image segmentation from a text prompt (Grounding DINO + Segment Anything Model). --- invokeai/app/invocations/grounded_sam.py | 197 ++++++++++++++++++ invokeai/backend/grounded_sam/__init__.py | 0 .../grounded_sam/grounding_dino_pipeline.py | 27 +++ .../backend/grounded_sam/mask_refinement.py | 50 +++++ .../grounded_sam/segment_anything_model.py | 35 ++++ .../backend/model_manager/load/model_util.py | 14 +- 6 files changed, 322 insertions(+), 1 deletion(-) create mode 100644 invokeai/app/invocations/grounded_sam.py create mode 100644 invokeai/backend/grounded_sam/__init__.py create mode 100644 invokeai/backend/grounded_sam/grounding_dino_pipeline.py create mode 100644 invokeai/backend/grounded_sam/mask_refinement.py create mode 100644 invokeai/backend/grounded_sam/segment_anything_model.py diff --git a/invokeai/app/invocations/grounded_sam.py b/invokeai/app/invocations/grounded_sam.py new file mode 100644 index 0000000000..5bfa5079e9 --- /dev/null +++ b/invokeai/app/invocations/grounded_sam.py @@ -0,0 +1,197 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional + +import numpy as np +import numpy.typing as npt +import torch +from PIL import Image +from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline +from transformers.models.sam import SamModel +from transformers.models.sam.processing_sam import SamProcessor +from transformers.pipelines import ZeroShotObjectDetectionPipeline + +from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation +from invokeai.app.invocations.fields import ImageField, InputField +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline +from invokeai.backend.grounded_sam.mask_refinement import mask_to_polygon, polygon_to_mask +from invokeai.backend.grounded_sam.segment_anything_model import SegmentAnythingModel + +GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny" +SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base" + + +@dataclass +class BoundingBox: + """Bounding box helper class used locally for the Grounding DINO outputs.""" + + xmin: int + ymin: int + xmax: int + ymax: int + + def to_box(self) -> list[int]: + """Convert to the array notation expected by SAM.""" + return [self.xmin, self.ymin, self.xmax, self.ymax] + + +@dataclass +class DetectionResult: + """Detection result from Grounding DINO or Grounded SAM.""" + + score: float + label: str + box: BoundingBox + mask: Optional[npt.NDArray[Any]] = None + + @classmethod + def from_dict(cls, detection_dict: dict[str, Any]): + return cls( + score=detection_dict["score"], + label=detection_dict["label"], + box=BoundingBox( + xmin=detection_dict["box"]["xmin"], + ymin=detection_dict["box"]["ymin"], + xmax=detection_dict["box"]["xmax"], + ymax=detection_dict["box"]["ymax"], + ), + ) + + +@invocation( + "grounded_segment_anything", + title="Segment Anything (Text Prompt)", + tags=["prompt", "segmentation"], + category="segmentation", + version="1.0.0", +) +class GroundedSAMInvocation(BaseInvocation): + """Runs Grounded-SAM, as proposed in https://arxiv.org/pdf/2401.14159. 
+
+    More specifically, a Grounding DINO model is run to obtain bounding boxes for a text prompt, then the bounding
+    boxes are passed as prompts to a Segment Anything model to obtain segmentation masks.
+
+    Reference:
+    - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
+    - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
+    """
+
+    prompt: str = InputField(description="The prompt describing the object to segment.")
+    image: ImageField = InputField(description="The image to segment.")
+    apply_polygon_refinement: bool = InputField(
+        description="Whether to apply polygon refinement to the mask. This will smooth the edges of the mask slightly "
+        "and ensure that the mask consists of a single closed polygon.",
+        default=False,
+    )
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        image_pil = context.images.get_pil(self.image.image_name)
+
+        detections = self._detect(context=context, image=image_pil, labels=[self.prompt])
+        detections = self._segment(context=context, image=image_pil, detection_results=detections)
+
+        # Extract the output mask.
+        mask_np = detections[0].mask
+        assert mask_np is not None
+        # Map [0, 1] to [0, 255].
+        mask_np = mask_np * 255
+        mask_pil = Image.fromarray(mask_np)
+
+        image_dto = context.images.save(image=mask_pil)
+        return ImageOutput.build(image_dto)
+
+    def _to_box_array(self, detection_results: list[DetectionResult]) -> list[list[list[int]]]:
+        """Convert a list of DetectionResults to the format expected by the Segment Anything model.
+
+        Args:
+            detection_results (list[DetectionResult]): The Grounding DINO detection results.
+        """
+        boxes = [result.box.to_box() for result in detection_results]
+        return [boxes]
+
+    def _detect(
+        self,
+        context: InvocationContext,
+        image: Image.Image,
+        labels: list[str],
+        threshold: float = 0.3,
+    ) -> list[DetectionResult]:
+        """Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
+
+        def load_grounding_dino(model_path: Path):
+            grounding_dino_pipeline = pipeline(
+                model=str(model_path),
+                task="zero-shot-object-detection",
+                local_files_only=True,
+                # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
+                # model, and figure out how to make it work in the pipeline.
+                # torch_dtype=TorchDevice.choose_torch_dtype(),
+            )
+            assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
+            return GroundingDinoPipeline(grounding_dino_pipeline)
+
+        with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector:
+            assert isinstance(detector, GroundingDinoPipeline)
+
+            # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
+            # actually makes a difference.
+            labels = [label if label.endswith(".") else label + "." for label in labels]
+
+            results = detector(image, candidate_labels=labels, threshold=threshold)
+            results = [DetectionResult.from_dict(result) for result in results]
+            return results
+
+    def _segment(
+        self,
+        context: InvocationContext,
+        image: Image.Image,
+        detection_results: list[DetectionResult],
+    ) -> list[DetectionResult]:
+        """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
+
+        def load_sam_model(model_path: Path):
+            sam_model = AutoModelForMaskGeneration.from_pretrained(
+                model_path,
+                local_files_only=True,
+                # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
+                # model, and figure out how to make it work in the pipeline.
+                # torch_dtype=TorchDevice.choose_torch_dtype(),
+            )
+            assert isinstance(sam_model, SamModel)
+
+            sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
+            assert isinstance(sam_processor, SamProcessor)
+            return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
+
+        with (
+            context.models.load_remote_model(source=SEGMENT_ANYTHING_MODEL_ID, loader=load_sam_model) as sam_pipeline,
+        ):
+            assert isinstance(sam_pipeline, SegmentAnythingModel)
+
+            boxes = self._to_box_array(detection_results)
+            masks = sam_pipeline.segment(image=image, boxes=boxes)
+            masks = self._refine_masks(masks)
+
+            for detection_result, mask in zip(detection_results, masks, strict=False):
+                detection_result.mask = mask
+
+            return detection_results
+
+    def _refine_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
+        masks = masks.cpu().float()
+        masks = masks.permute(0, 2, 3, 1)
+        masks = masks.mean(axis=-1)
+        masks = (masks > 0).int()
+        masks = masks.numpy().astype(np.uint8)
+        masks = list(masks)
+
+        if self.apply_polygon_refinement:
+            for idx, mask in enumerate(masks):
+                shape = mask.shape
+                polygon = mask_to_polygon(mask)
+                mask = polygon_to_mask(polygon, shape)
+                masks[idx] = mask
+
+        return masks
diff --git a/invokeai/backend/grounded_sam/__init__.py b/invokeai/backend/grounded_sam/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/invokeai/backend/grounded_sam/grounding_dino_pipeline.py b/invokeai/backend/grounded_sam/grounding_dino_pipeline.py
new file mode 100644
index 0000000000..d6c1f6a377
--- /dev/null
+++ b/invokeai/backend/grounded_sam/grounding_dino_pipeline.py
@@ -0,0 +1,27 @@
+from typing import Optional
+
+import torch
+from transformers.pipelines import ZeroShotObjectDetectionPipeline
+
+
+class GroundingDinoPipeline:
+    """A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory
+    management system.
+    """
+
+    def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
+        self._pipeline = pipeline
+
+    def __call__(self, *args, **kwargs):
+        return self._pipeline(*args, **kwargs)
+
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":
+        self._pipeline.model.to(device=device, dtype=dtype)
+        self._pipeline.device = self._pipeline.model.device
+        return self
+
+    def calc_size(self) -> int:
+        # HACK(ryand): Fix the circular import issue.
+        from invokeai.backend.model_manager.load.model_util import calc_module_size
+
+        return calc_module_size(self._pipeline.model)
diff --git a/invokeai/backend/grounded_sam/mask_refinement.py b/invokeai/backend/grounded_sam/mask_refinement.py
new file mode 100644
index 0000000000..2c8cf077d1
--- /dev/null
+++ b/invokeai/backend/grounded_sam/mask_refinement.py
@@ -0,0 +1,50 @@
+# This file contains utilities for Grounded-SAM mask refinement based on:
+# https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
+
+
+import cv2
+import numpy as np
+import numpy.typing as npt
+
+
+def mask_to_polygon(mask: npt.NDArray[np.uint8]) -> list[tuple[int, int]]:
+    """Convert a binary mask to a polygon.
+
+    Returns:
+        list[tuple[int, int]]: List of (x, y) coordinates representing the vertices of the polygon.
+    """
+    # Find contours in the binary mask.
+    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+    # Find the contour with the largest area.
+    largest_contour = max(contours, key=cv2.contourArea)
+
+    # Extract the vertices of the contour.
+    polygon = largest_contour.reshape(-1, 2).tolist()
+
+    return polygon
+
+
+def polygon_to_mask(
+    polygon: list[tuple[int, int]], image_shape: tuple[int, int], fill_value: int = 1
+) -> npt.NDArray[np.uint8]:
+    """Convert a polygon to a segmentation mask.
+
+    Args:
+        polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
+        image_shape (tuple): Shape of the image (height, width) for the mask.
+        fill_value (int): Value to fill the polygon with.
+
+    Returns:
+        np.ndarray: Segmentation mask with the polygon filled with `fill_value`.
+    """
+    # Create an empty mask.
+    mask = np.zeros(image_shape, dtype=np.uint8)
+
+    # Convert polygon to an array of points.
+    pts = np.array(polygon, dtype=np.int32)
+
+    # Fill the polygon with `fill_value`.
+    cv2.fillPoly(mask, [pts], color=(fill_value,))
+
+    return mask
diff --git a/invokeai/backend/grounded_sam/segment_anything_model.py b/invokeai/backend/grounded_sam/segment_anything_model.py
new file mode 100644
index 0000000000..5072d1bc35
--- /dev/null
+++ b/invokeai/backend/grounded_sam/segment_anything_model.py
@@ -0,0 +1,35 @@
+from typing import Optional
+
+import torch
+from PIL import Image
+from transformers.models.sam import SamModel
+from transformers.models.sam.processing_sam import SamProcessor
+
+
+class SegmentAnythingModel:
+    """A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
+
+    def __init__(self, sam_model: SamModel, sam_processor: SamProcessor):
+        self._sam_model = sam_model
+        self._sam_processor = sam_processor
+
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "SegmentAnythingModel":
+        self._sam_model.to(device=device, dtype=dtype)
+        return self
+
+    def calc_size(self) -> int:
+        # HACK(ryand): Fix the circular import issue.
+ from invokeai.backend.model_manager.load.model_util import calc_module_size + + return calc_module_size(self._sam_model) + + def segment(self, image: Image.Image, boxes: list[list[list[int]]]) -> torch.Tensor: + inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device) + outputs = self._sam_model(**inputs) + masks = self._sam_processor.post_process_masks( + masks=outputs.pred_masks, + original_sizes=inputs.original_sizes, + reshaped_input_sizes=inputs.reshaped_input_sizes, + )[0] + + return masks diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py index f070a42965..22d493f7a0 100644 --- a/invokeai/backend/model_manager/load/model_util.py +++ b/invokeai/backend/model_manager/load/model_util.py @@ -11,6 +11,8 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.schedulers.scheduling_utils import SchedulerMixin from transformers import CLIPTokenizer +from invokeai.backend.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline +from invokeai.backend.grounded_sam.segment_anything_model import SegmentAnythingModel from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager.config import AnyModel @@ -34,7 +36,17 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int: elif isinstance(model, CLIPTokenizer): # TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now. return 0 - elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel)): + elif isinstance( + model, + ( + TextualInversionModelRaw, + IPAdapter, + LoRAModelRaw, + SpandrelImageToImageModel, + GroundingDinoPipeline, + SegmentAnythingModel, + ), + ): return model.calc_size() else: # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
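
Note: the flow wired up above can be exercised outside of InvokeAI with plain transformers, which is handy for
sanity-checking the two models before trying the node in a workflow. The sketch below mirrors what
GroundedSAMInvocation does (Grounding DINO for text-prompted boxes, then SAM prompted with those boxes). It is only
illustrative: it assumes the two Hub model IDs above can be downloaded directly (no model-manager caching, CPU, no
polygon refinement), and "example.png" / "mask.png" are placeholder file names.

import torch
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline

image = Image.open("example.png").convert("RGB")
prompt = "a red car."  # Grounding DINO prompts should be lowercase and end with a ".".

# 1) Grounding DINO: text prompt -> bounding boxes.
detector = pipeline(model="IDEA-Research/grounding-dino-tiny", task="zero-shot-object-detection")
detections = detector(image, candidate_labels=[prompt], threshold=0.3)
boxes = [[[d["box"]["xmin"], d["box"]["ymin"], d["box"]["xmax"], d["box"]["ymax"]] for d in detections]]

# 2) SAM: bounding boxes -> segmentation masks.
sam_model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-base")
sam_processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
inputs = sam_processor(images=image, input_boxes=boxes, return_tensors="pt")
with torch.no_grad():
    outputs = sam_model(**inputs)
masks = sam_processor.post_process_masks(
    masks=outputs.pred_masks,
    original_sizes=inputs.original_sizes,
    reshaped_input_sizes=inputs.reshaped_input_sizes,
)[0]  # (num_boxes, 3, H, W) boolean tensor for the single input image

# Collapse the three SAM mask channels and save the first mask, mirroring _refine_masks() and invoke().
mask_np = (masks[0].float().mean(dim=0) > 0).to(torch.uint8).numpy() * 255
Image.fromarray(mask_np).save("mask.png")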