From ff6398f7d8efb62f8a74ed344727b00d5f5d422a Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 29 Jul 2024 13:53:14 -0400 Subject: [PATCH 01/28] Add a GroundedSamInvocation for image segmentation from a text prompt (Grounding DINO + Segment Anything Model). --- invokeai/app/invocations/grounded_sam.py | 197 ++++++++++++++++++ invokeai/backend/grounded_sam/__init__.py | 0 .../grounded_sam/grounding_dino_pipeline.py | 27 +++ .../backend/grounded_sam/mask_refinement.py | 50 +++++ .../grounded_sam/segment_anything_model.py | 35 ++++ .../backend/model_manager/load/model_util.py | 14 +- 6 files changed, 322 insertions(+), 1 deletion(-) create mode 100644 invokeai/app/invocations/grounded_sam.py create mode 100644 invokeai/backend/grounded_sam/__init__.py create mode 100644 invokeai/backend/grounded_sam/grounding_dino_pipeline.py create mode 100644 invokeai/backend/grounded_sam/mask_refinement.py create mode 100644 invokeai/backend/grounded_sam/segment_anything_model.py diff --git a/invokeai/app/invocations/grounded_sam.py b/invokeai/app/invocations/grounded_sam.py new file mode 100644 index 0000000000..5bfa5079e9 --- /dev/null +++ b/invokeai/app/invocations/grounded_sam.py @@ -0,0 +1,197 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional + +import numpy as np +import numpy.typing as npt +import torch +from PIL import Image +from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline +from transformers.models.sam import SamModel +from transformers.models.sam.processing_sam import SamProcessor +from transformers.pipelines import ZeroShotObjectDetectionPipeline + +from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation +from invokeai.app.invocations.fields import ImageField, InputField +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline +from invokeai.backend.grounded_sam.mask_refinement import mask_to_polygon, polygon_to_mask +from invokeai.backend.grounded_sam.segment_anything_model import SegmentAnythingModel + +GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny" +SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base" + + +@dataclass +class BoundingBox: + """Bounding box helper class used locally for the Grounding DINO outputs.""" + + xmin: int + ymin: int + xmax: int + ymax: int + + def to_box(self) -> list[int]: + """Convert to the array notation expected by SAM.""" + return [self.xmin, self.ymin, self.xmax, self.ymax] + + +@dataclass +class DetectionResult: + """Detection result from Grounding DINO or Grounded SAM.""" + + score: float + label: str + box: BoundingBox + mask: Optional[npt.NDArray[Any]] = None + + @classmethod + def from_dict(cls, detection_dict: dict[str, Any]): + return cls( + score=detection_dict["score"], + label=detection_dict["label"], + box=BoundingBox( + xmin=detection_dict["box"]["xmin"], + ymin=detection_dict["box"]["ymin"], + xmax=detection_dict["box"]["xmax"], + ymax=detection_dict["box"]["ymax"], + ), + ) + + +@invocation( + "grounded_segment_anything", + title="Segment Anything (Text Prompt)", + tags=["prompt", "segmentation"], + category="segmentation", + version="1.0.0", +) +class GroundedSAMInvocation(BaseInvocation): + """Runs Grounded-SAM, as proposed in https://arxiv.org/pdf/2401.14159. 
+ + More specifically, a Grounding DINO model is run to obtain bounding boxes for a text prompt, then the bounding box + is passed as a prompt to a Segment Anything model to obtain a segmentation mask. + + Reference: + - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam + - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb + """ + + prompt: str = InputField(description="The prompt describing the object to segment.") + image: ImageField = InputField(description="The image to segment.") + apply_polygon_refinement: bool = InputField( + description="Whether to apply polygon refinement to the mask. This will smooth the edges of the mask slightly " + "and ensure that the mask consists of a single closed polygon.", + default=False, + ) + + def invoke(self, context: InvocationContext) -> ImageOutput: + image_pil = context.images.get_pil(self.image.image_name) + + detections = self._detect(context=context, image=image_pil, labels=[self.prompt]) + detections = self._segment(context=context, image=image_pil, detection_results=detections) + + # Extract ouput mask. + mask_np = detections[0].mask + assert mask_np is not None + # Map [0, 1] to [0, 255]. + mask_np = mask_np * 255 + mask_pil = Image.fromarray(mask_np) + + image_dto = context.images.save(image=mask_pil) + return ImageOutput.build(image_dto) + + def _to_box_array(self, detection_results: list[DetectionResult]) -> list[list[list[int]]]: + """Convert a list of DetectionResults to the format expected by the Segment Anything model. + + Args: + detection_results (list[DetectionResult]): The Grounding DINO detection results. + """ + boxes = [result.box.to_box() for result in detection_results] + return [boxes] + + def _detect( + self, + context: InvocationContext, + image: Image.Image, + labels: list[str], + threshold: float = 0.3, + ) -> list[DetectionResult]: + """Use Grounding DINO to detect bounding boxes for a set of labels in an image.""" + + def load_grounding_dino(model_path: Path): + grounding_dino_pipeline = pipeline( + model=str(model_path), + task="zero-shot-object-detection", + local_files_only=True, + # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the + # model, and figure out how to make it work in the pipeline. + # torch_dtype=TorchDevice.choose_torch_dtype(), + ) + assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline) + return GroundingDinoPipeline(grounding_dino_pipeline) + + with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector: + assert isinstance(detector, GroundingDinoPipeline) + + # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it + # actually makes a difference. + labels = [label if label.endswith(".") else label + "." 
for label in labels] + + results = detector(image, candidate_labels=labels, threshold=threshold) + results = [DetectionResult.from_dict(result) for result in results] + return results + + def _segment( + self, + context: InvocationContext, + image: Image.Image, + detection_results: list[DetectionResult], + ) -> list[DetectionResult]: + """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.""" + + def load_sam_model(model_path: Path): + sam_model = AutoModelForMaskGeneration.from_pretrained( + model_path, + local_files_only=True, + # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the + # model, and figure out how to make it work in the pipeline. + # torch_dtype=TorchDevice.choose_torch_dtype(), + ) + assert isinstance(sam_model, SamModel) + + sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True) + assert isinstance(sam_processor, SamProcessor) + return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor) + + with ( + context.models.load_remote_model(source=SEGMENT_ANYTHING_MODEL_ID, loader=load_sam_model) as sam_pipeline, + ): + assert isinstance(sam_pipeline, SegmentAnythingModel) + + boxes = self._to_box_array(detection_results) + masks = sam_pipeline.segment(image=image, boxes=boxes) + masks = self._refine_masks(masks) + + for detection_result, mask in zip(detection_results, masks, strict=False): + detection_result.mask = mask + + return detection_results + + def _refine_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]: + masks = masks.cpu().float() + masks = masks.permute(0, 2, 3, 1) + masks = masks.mean(axis=-1) + masks = (masks > 0).int() + masks = masks.numpy().astype(np.uint8) + masks = list(masks) + + if self.apply_polygon_refinement: + for idx, mask in enumerate(masks): + shape = mask.shape + polygon = mask_to_polygon(mask) + mask = polygon_to_mask(polygon, shape) + masks[idx] = mask + + return masks diff --git a/invokeai/backend/grounded_sam/__init__.py b/invokeai/backend/grounded_sam/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/invokeai/backend/grounded_sam/grounding_dino_pipeline.py b/invokeai/backend/grounded_sam/grounding_dino_pipeline.py new file mode 100644 index 0000000000..d6c1f6a377 --- /dev/null +++ b/invokeai/backend/grounded_sam/grounding_dino_pipeline.py @@ -0,0 +1,27 @@ +from typing import Optional + +import torch +from transformers.pipelines import ZeroShotObjectDetectionPipeline + + +class GroundingDinoPipeline: + """A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory + management system. + """ + + def __init__(self, pipeline: ZeroShotObjectDetectionPipeline): + self._pipeline = pipeline + + def __call__(self, *args, **kwargs): + return self._pipeline(*args, **kwargs) + + def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline": + self._pipeline.model.to(device=device, dtype=dtype) + self._pipeline.device = self._pipeline.model.device + return self + + def calc_size(self) -> int: + # HACK(ryand): Fix the circular import issue. 
+ from invokeai.backend.model_manager.load.model_util import calc_module_size + + return calc_module_size(self._pipeline.model) diff --git a/invokeai/backend/grounded_sam/mask_refinement.py b/invokeai/backend/grounded_sam/mask_refinement.py new file mode 100644 index 0000000000..2c8cf077d1 --- /dev/null +++ b/invokeai/backend/grounded_sam/mask_refinement.py @@ -0,0 +1,50 @@ +# This file contains utilities for Grounded-SAM mask refinement based on: +# https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb + + +import cv2 +import numpy as np +import numpy.typing as npt + + +def mask_to_polygon(mask: npt.NDArray[np.uint8]) -> list[tuple[int, int]]: + """Convert a binary mask to a polygon. + + Returns: + list[list[int]]: List of (x, y) coordinates representing the vertices of the polygon. + """ + # Find contours in the binary mask. + contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + # Find the contour with the largest area. + largest_contour = max(contours, key=cv2.contourArea) + + # Extract the vertices of the contour. + polygon = largest_contour.reshape(-1, 2).tolist() + + return polygon + + +def polygon_to_mask( + polygon: list[tuple[int, int]], image_shape: tuple[int, int], fill_value: int = 1 +) -> npt.NDArray[np.uint8]: + """Convert a polygon to a segmentation mask. + + Args: + polygon (list): List of (x, y) coordinates representing the vertices of the polygon. + image_shape (tuple): Shape of the image (height, width) for the mask. + fill_value (int): Value to fill the polygon with. + + Returns: + np.ndarray: Segmentation mask with the polygon filled (with value 255). + """ + # Create an empty mask. + mask = np.zeros(image_shape, dtype=np.uint8) + + # Convert polygon to an array of points. + pts = np.array(polygon, dtype=np.int32) + + # Fill the polygon with white color (255). + cv2.fillPoly(mask, [pts], color=(fill_value,)) + + return mask diff --git a/invokeai/backend/grounded_sam/segment_anything_model.py b/invokeai/backend/grounded_sam/segment_anything_model.py new file mode 100644 index 0000000000..5072d1bc35 --- /dev/null +++ b/invokeai/backend/grounded_sam/segment_anything_model.py @@ -0,0 +1,35 @@ +from typing import Optional + +import torch +from PIL import Image +from transformers.models.sam import SamModel +from transformers.models.sam.processing_sam import SamProcessor + + +class SegmentAnythingModel: + """A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager.""" + + def __init__(self, sam_model: SamModel, sam_processor: SamProcessor): + self._sam_model = sam_model + self._sam_processor = sam_processor + + def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "SegmentAnythingModel": + self._sam_model.to(device=device, dtype=dtype) + return self + + def calc_size(self) -> int: + # HACK(ryand): Fix the circular import issue. 
+ from invokeai.backend.model_manager.load.model_util import calc_module_size + + return calc_module_size(self._sam_model) + + def segment(self, image: Image.Image, boxes: list[list[list[int]]]) -> torch.Tensor: + inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device) + outputs = self._sam_model(**inputs) + masks = self._sam_processor.post_process_masks( + masks=outputs.pred_masks, + original_sizes=inputs.original_sizes, + reshaped_input_sizes=inputs.reshaped_input_sizes, + )[0] + + return masks diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py index f070a42965..22d493f7a0 100644 --- a/invokeai/backend/model_manager/load/model_util.py +++ b/invokeai/backend/model_manager/load/model_util.py @@ -11,6 +11,8 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.schedulers.scheduling_utils import SchedulerMixin from transformers import CLIPTokenizer +from invokeai.backend.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline +from invokeai.backend.grounded_sam.segment_anything_model import SegmentAnythingModel from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager.config import AnyModel @@ -34,7 +36,17 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int: elif isinstance(model, CLIPTokenizer): # TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now. return 0 - elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel)): + elif isinstance( + model, + ( + TextualInversionModelRaw, + IPAdapter, + LoRAModelRaw, + SpandrelImageToImageModel, + GroundingDinoPipeline, + SegmentAnythingModel, + ), + ): return model.calc_size() else: # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the From aca2a2fa131afafc15713a9a00c7a7dcf52f470d Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 30 Jul 2024 14:22:40 -0400 Subject: [PATCH 02/28] Add mask_filter and detection_threshold options to the GroundedSAMInvocation. --- invokeai/app/invocations/grounded_sam.py | 98 ++++++++++++++----- .../grounded_sam/grounding_dino_pipeline.py | 5 +- .../grounded_sam/segment_anything_model.py | 5 +- 3 files changed, 83 insertions(+), 25 deletions(-) diff --git a/invokeai/app/invocations/grounded_sam.py b/invokeai/app/invocations/grounded_sam.py index 5bfa5079e9..211016fac4 100644 --- a/invokeai/app/invocations/grounded_sam.py +++ b/invokeai/app/invocations/grounded_sam.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from pathlib import Path -from typing import Any, Optional +from typing import Any, Literal, Optional import numpy as np import numpy.typing as npt @@ -70,8 +70,8 @@ class DetectionResult: class GroundedSAMInvocation(BaseInvocation): """Runs Grounded-SAM, as proposed in https://arxiv.org/pdf/2401.14159. - More specifically, a Grounding DINO model is run to obtain bounding boxes for a text prompt, then the bounding box - is passed as a prompt to a Segment Anything model to obtain a segmentation mask. + More specifically, a Grounding DINO model is run to obtain bounding boxes for a text prompt, then the bounding boxes + are passed as a prompt to a Segment Anything model to obtain a segmentation mask. 
Reference: - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam @@ -81,22 +81,38 @@ class GroundedSAMInvocation(BaseInvocation): prompt: str = InputField(description="The prompt describing the object to segment.") image: ImageField = InputField(description="The image to segment.") apply_polygon_refinement: bool = InputField( - description="Whether to apply polygon refinement to the mask. This will smooth the edges of the mask slightly " - "and ensure that the mask consists of a single closed polygon.", - default=False, + description="Whether to apply polygon refinement to the masks. This will smooth the edges of the mask slightly and ensure that each mask consists of a single closed polygon (before merging).", + default=True, + ) + mask_filter: Literal["all", "largest", "highest_box_score"] = InputField( + description="The filtering to apply to the detected masks before merging them into a final output.", + default="all", + ) + detection_threshold: float = InputField( + description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be used.", + ge=0.0, + le=1.0, + default=0.1, ) def invoke(self, context: InvocationContext) -> ImageOutput: image_pil = context.images.get_pil(self.image.image_name) - detections = self._detect(context=context, image=image_pil, labels=[self.prompt]) - detections = self._segment(context=context, image=image_pil, detection_results=detections) + detections = self._detect( + context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold + ) + + if len(detections) == 0: + combined_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8) + else: + detections = self._segment(context=context, image=image_pil, detection_results=detections) + + detections = self._filter_detections(detections) + masks = [detection.mask for detection in detections] + combined_mask = self._merge_masks(masks) - # Extract ouput mask. - mask_np = detections[0].mask - assert mask_np is not None # Map [0, 1] to [0, 255]. - mask_np = mask_np * 255 + mask_np = combined_mask * 255 mask_pil = Image.fromarray(mask_np) image_dto = context.images.save(image=mask_pil) @@ -119,6 +135,9 @@ class GroundedSAMInvocation(BaseInvocation): threshold: float = 0.3, ) -> list[DetectionResult]: """Use Grounding DINO to detect bounding boxes for a set of labels in an image.""" + # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it + # actually makes a difference. + labels = [label if label.endswith(".") else label + "." for label in labels] def load_grounding_dino(model_path: Path): grounding_dino_pipeline = pipeline( @@ -135,11 +154,7 @@ class GroundedSAMInvocation(BaseInvocation): with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector: assert isinstance(detector, GroundingDinoPipeline) - # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it - # actually makes a difference. - labels = [label if label.endswith(".") else label + "." 
for label in labels] - - results = detector(image, candidate_labels=labels, threshold=threshold) + results = detector.detect(image=image, candidate_labels=labels, threshold=threshold) results = [DetectionResult.from_dict(result) for result in results] return results @@ -172,21 +187,34 @@ class GroundedSAMInvocation(BaseInvocation): boxes = self._to_box_array(detection_results) masks = sam_pipeline.segment(image=image, boxes=boxes) - masks = self._refine_masks(masks) - for detection_result, mask in zip(detection_results, masks, strict=False): - detection_result.mask = mask + masks = self._to_numpy_masks(masks) + masks = self._apply_polygon_refinement(masks) - return detection_results + for detection_result, mask in zip(detection_results, masks, strict=False): + detection_result.mask = mask - def _refine_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]: + return detection_results + + def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]: + """Convert the tensor output from the Segment Anything model to a list of numpy masks.""" masks = masks.cpu().float() masks = masks.permute(0, 2, 3, 1) masks = masks.mean(axis=-1) masks = (masks > 0).int() masks = masks.numpy().astype(np.uint8) masks = list(masks) + return masks + def _apply_polygon_refinement(self, masks: list[npt.NDArray[np.uint8]]) -> list[npt.NDArray[np.uint8]]: + """Apply polygon refinement to the masks. + + Convert each mask to a polygon, then back to a mask. This has the following effect: + - Smooth the edges of the mask slightly. + - Ensure that each mask consists of a single closed polygon + - Removes small mask pieces. + - Removes holes from the mask. + """ if self.apply_polygon_refinement: for idx, mask in enumerate(masks): shape = mask.shape @@ -195,3 +223,29 @@ class GroundedSAMInvocation(BaseInvocation): masks[idx] = mask return masks + + def _filter_detections(self, detections: list[DetectionResult]) -> list[DetectionResult]: + """Filter the detected masks based on the specified mask filter.""" + if self.mask_filter == "all": + return detections + elif self.mask_filter == "largest": + # Find the largest mask. + mask_areas = [detection.mask.sum() for detection in detections] + largest_mask_idx = mask_areas.index(max(mask_areas)) + return [detections[largest_mask_idx]] + elif self.mask_filter == "highest_box_score": + # Find the detection with the highest box score. + max_score_detection = detections[0] + for detection in detections: + if detection.score > max_score_detection.score: + max_score_detection = detection + return [max_score_detection] + else: + raise ValueError(f"Invalid mask filter: {self.mask_filter}") + + def _merge_masks(self, masks: list[npt.NDArray[np.uint8]]) -> npt.NDArray[np.uint8]: + """Merge multiple masks into a single mask.""" + # Merge all masks together. 
+ stacked_mask = np.stack(masks, axis=0) + combined_mask = np.max(stacked_mask, axis=0) + return combined_mask diff --git a/invokeai/backend/grounded_sam/grounding_dino_pipeline.py b/invokeai/backend/grounded_sam/grounding_dino_pipeline.py index d6c1f6a377..0143e300bc 100644 --- a/invokeai/backend/grounded_sam/grounding_dino_pipeline.py +++ b/invokeai/backend/grounded_sam/grounding_dino_pipeline.py @@ -1,6 +1,7 @@ from typing import Optional import torch +from PIL import Image from transformers.pipelines import ZeroShotObjectDetectionPipeline @@ -12,8 +13,8 @@ class GroundingDinoPipeline: def __init__(self, pipeline: ZeroShotObjectDetectionPipeline): self._pipeline = pipeline - def __call__(self, *args, **kwargs): - return self._pipeline(*args, **kwargs) + def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1): + return self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold) def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline": self._pipeline.model.to(device=device, dtype=dtype) diff --git a/invokeai/backend/grounded_sam/segment_anything_model.py b/invokeai/backend/grounded_sam/segment_anything_model.py index 5072d1bc35..1c5b1cf456 100644 --- a/invokeai/backend/grounded_sam/segment_anything_model.py +++ b/invokeai/backend/grounded_sam/segment_anything_model.py @@ -30,6 +30,9 @@ class SegmentAnythingModel: masks=outputs.pred_masks, original_sizes=inputs.original_sizes, reshaped_input_sizes=inputs.reshaped_input_sizes, - )[0] + ) + # There should be only one batch. + assert len(masks) == 1 + masks = masks[0] return masks From 918f77bce015e9734bf7de1a3c2d8c2a05952945 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 30 Jul 2024 15:34:33 -0400 Subject: [PATCH 03/28] Move some logic from GroundedSAMInvocation to the backend classes. 
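For orientation, the end-to-end flow that these invocation and backend classes wrap can be reduced to a short sketch. It is illustrative only: it uses the same transformers calls and model IDs as the diffs in this series, while the image path and the candidate label are placeholders.

# Minimal illustrative sketch of the Grounding DINO -> SAM flow wrapped by
# GroundingDinoPipeline / SegmentAnythingModel. Same transformers APIs and model IDs
# as the patches; "example.png" and the candidate label are placeholders.
import torch
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline

image = Image.open("example.png").convert("RGB")  # the models expect a 3-channel RGB image

# 1. Grounding DINO: text prompt -> scored bounding boxes.
detector = pipeline(task="zero-shot-object-detection", model="IDEA-Research/grounding-dino-tiny")
detections = detector(image, candidate_labels=["a cat."], threshold=0.3)

# 2. Re-pack the boxes into the nested [batch, num_boxes, 4] list that SAM expects.
boxes = [[[d["box"]["xmin"], d["box"]["ymin"], d["box"]["xmax"], d["box"]["ymax"]] for d in detections]]

# 3. SAM: bounding-box prompts -> one mask per box.
sam_model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-base")
sam_processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
inputs = sam_processor(images=image, input_boxes=boxes, return_tensors="pt")
with torch.no_grad():
    outputs = sam_model(**inputs)
masks = sam_processor.post_process_masks(
    masks=outputs.pred_masks,
    original_sizes=inputs.original_sizes,
    reshaped_input_sizes=inputs.reshaped_input_sizes,
)[0]  # one image in the batch -> tensor of shape [num_boxes, channels, height, width]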
--- invokeai/app/invocations/grounded_sam.py | 63 +++---------------- .../grounded_sam/grounding_dino_pipeline.py | 8 ++- .../grounded_sam/segment_anything_model.py | 10 ++- 3 files changed, 22 insertions(+), 59 deletions(-) diff --git a/invokeai/app/invocations/grounded_sam.py b/invokeai/app/invocations/grounded_sam.py index 211016fac4..8eb8770e47 100644 --- a/invokeai/app/invocations/grounded_sam.py +++ b/invokeai/app/invocations/grounded_sam.py @@ -1,6 +1,5 @@ -from dataclasses import dataclass from pathlib import Path -from typing import Any, Literal, Optional +from typing import Literal import numpy as np import numpy.typing as npt @@ -15,6 +14,7 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.fields import ImageField, InputField from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.grounded_sam.detection_result import DetectionResult from invokeai.backend.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline from invokeai.backend.grounded_sam.mask_refinement import mask_to_polygon, polygon_to_mask from invokeai.backend.grounded_sam.segment_anything_model import SegmentAnythingModel @@ -23,43 +23,6 @@ GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny" SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base" -@dataclass -class BoundingBox: - """Bounding box helper class used locally for the Grounding DINO outputs.""" - - xmin: int - ymin: int - xmax: int - ymax: int - - def to_box(self) -> list[int]: - """Convert to the array notation expected by SAM.""" - return [self.xmin, self.ymin, self.xmax, self.ymax] - - -@dataclass -class DetectionResult: - """Detection result from Grounding DINO or Grounded SAM.""" - - score: float - label: str - box: BoundingBox - mask: Optional[npt.NDArray[Any]] = None - - @classmethod - def from_dict(cls, detection_dict: dict[str, Any]): - return cls( - score=detection_dict["score"], - label=detection_dict["label"], - box=BoundingBox( - xmin=detection_dict["box"]["xmin"], - ymin=detection_dict["box"]["ymin"], - xmax=detection_dict["box"]["xmax"], - ymax=detection_dict["box"]["ymax"], - ), - ) - - @invocation( "grounded_segment_anything", title="Segment Anything (Text Prompt)", @@ -92,9 +55,10 @@ class GroundedSAMInvocation(BaseInvocation): description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be used.", ge=0.0, le=1.0, - default=0.1, + default=0.3, ) + @torch.no_grad() def invoke(self, context: InvocationContext) -> ImageOutput: image_pil = context.images.get_pil(self.image.image_name) @@ -118,15 +82,6 @@ class GroundedSAMInvocation(BaseInvocation): image_dto = context.images.save(image=mask_pil) return ImageOutput.build(image_dto) - def _to_box_array(self, detection_results: list[DetectionResult]) -> list[list[list[int]]]: - """Convert a list of DetectionResults to the format expected by the Segment Anything model. - - Args: - detection_results (list[DetectionResult]): The Grounding DINO detection results. 
- """ - boxes = [result.box.to_box() for result in detection_results] - return [boxes] - def _detect( self, context: InvocationContext, @@ -153,10 +108,7 @@ class GroundedSAMInvocation(BaseInvocation): with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector: assert isinstance(detector, GroundingDinoPipeline) - - results = detector.detect(image=image, candidate_labels=labels, threshold=threshold) - results = [DetectionResult.from_dict(result) for result in results] - return results + return detector.detect(image=image, candidate_labels=labels, threshold=threshold) def _segment( self, @@ -185,8 +137,7 @@ class GroundedSAMInvocation(BaseInvocation): ): assert isinstance(sam_pipeline, SegmentAnythingModel) - boxes = self._to_box_array(detection_results) - masks = sam_pipeline.segment(image=image, boxes=boxes) + masks = sam_pipeline.segment(image=image, detection_results=detection_results) masks = self._to_numpy_masks(masks) masks = self._apply_polygon_refinement(masks) @@ -200,7 +151,7 @@ class GroundedSAMInvocation(BaseInvocation): """Convert the tensor output from the Segment Anything model to a list of numpy masks.""" masks = masks.cpu().float() masks = masks.permute(0, 2, 3, 1) - masks = masks.mean(axis=-1) + masks = masks.mean(dim=-1) masks = (masks > 0).int() masks = masks.numpy().astype(np.uint8) masks = list(masks) diff --git a/invokeai/backend/grounded_sam/grounding_dino_pipeline.py b/invokeai/backend/grounded_sam/grounding_dino_pipeline.py index 0143e300bc..028aeb283d 100644 --- a/invokeai/backend/grounded_sam/grounding_dino_pipeline.py +++ b/invokeai/backend/grounded_sam/grounding_dino_pipeline.py @@ -4,6 +4,8 @@ import torch from PIL import Image from transformers.pipelines import ZeroShotObjectDetectionPipeline +from invokeai.backend.grounded_sam.detection_result import DetectionResult + class GroundingDinoPipeline: """A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory @@ -13,8 +15,10 @@ class GroundingDinoPipeline: def __init__(self, pipeline: ZeroShotObjectDetectionPipeline): self._pipeline = pipeline - def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1): - return self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold) + def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1) -> list[DetectionResult]: + results = self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold) + results = [DetectionResult.from_dict(result) for result in results] + return results def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline": self._pipeline.model.to(device=device, dtype=dtype) diff --git a/invokeai/backend/grounded_sam/segment_anything_model.py b/invokeai/backend/grounded_sam/segment_anything_model.py index 1c5b1cf456..86106b869c 100644 --- a/invokeai/backend/grounded_sam/segment_anything_model.py +++ b/invokeai/backend/grounded_sam/segment_anything_model.py @@ -5,6 +5,8 @@ from PIL import Image from transformers.models.sam import SamModel from transformers.models.sam.processing_sam import SamProcessor +from invokeai.backend.grounded_sam.detection_result import DetectionResult + class SegmentAnythingModel: """A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager.""" @@ -23,7 +25,8 @@ class SegmentAnythingModel: return 
calc_module_size(self._sam_model) - def segment(self, image: Image.Image, boxes: list[list[list[int]]]) -> torch.Tensor: + def segment(self, image: Image.Image, detection_results: list[DetectionResult]) -> torch.Tensor: + boxes = self._to_box_array(detection_results) inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device) outputs = self._sam_model(**inputs) masks = self._sam_processor.post_process_masks( @@ -36,3 +39,8 @@ class SegmentAnythingModel: assert len(masks) == 1 masks = masks[0] return masks + + def _to_box_array(self, detection_results: list[DetectionResult]) -> list[list[list[int]]]: + """Convert a list of DetectionResults to the bbox format expected by the Segment Anything model.""" + boxes = [result.box.to_box() for result in detection_results] + return [boxes] From 6b10b59abe9bfc1f4c125dadd5e70bbaa0b5326d Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 30 Jul 2024 15:55:57 -0400 Subject: [PATCH 04/28] Make GroundedSAMInvocation work with any input image mode (RGB, RGBA, grayscale). --- invokeai/app/invocations/grounded_sam.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/invokeai/app/invocations/grounded_sam.py b/invokeai/app/invocations/grounded_sam.py index 8eb8770e47..411ec4a91f 100644 --- a/invokeai/app/invocations/grounded_sam.py +++ b/invokeai/app/invocations/grounded_sam.py @@ -60,7 +60,8 @@ class GroundedSAMInvocation(BaseInvocation): @torch.no_grad() def invoke(self, context: InvocationContext) -> ImageOutput: - image_pil = context.images.get_pil(self.image.image_name) + # The models expect a 3-channel RGB image. + image_pil = context.images.get_pil(self.image.image_name, mode="RGB") detections = self._detect( context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold From 2da9f913f3b4d05f6ce02819751fc74feef50796 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 30 Jul 2024 16:04:29 -0400 Subject: [PATCH 05/28] Add detection_result.py - was forgotten in a prior commit --- .../backend/grounded_sam/detection_result.py | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 invokeai/backend/grounded_sam/detection_result.py diff --git a/invokeai/backend/grounded_sam/detection_result.py b/invokeai/backend/grounded_sam/detection_result.py new file mode 100644 index 0000000000..40e4254385 --- /dev/null +++ b/invokeai/backend/grounded_sam/detection_result.py @@ -0,0 +1,41 @@ +from dataclasses import dataclass +from typing import Any, Optional + +import numpy.typing as npt + + +@dataclass +class BoundingBox: + """Bounding box helper class.""" + + xmin: int + ymin: int + xmax: int + ymax: int + + def to_box(self) -> list[int]: + """Convert to the array notation expected by SAM.""" + return [self.xmin, self.ymin, self.xmax, self.ymax] + + +@dataclass +class DetectionResult: + """Detection result from Grounding DINO or Grounded SAM.""" + + score: float + label: str + box: BoundingBox + mask: Optional[npt.NDArray[Any]] = None + + @classmethod + def from_dict(cls, detection_dict: dict[str, Any]): + return cls( + score=detection_dict["score"], + label=detection_dict["label"], + box=BoundingBox( + xmin=detection_dict["box"]["xmin"], + ymin=detection_dict["box"]["ymin"], + xmax=detection_dict["box"]["xmax"], + ymax=detection_dict["box"]["ymax"], + ), + ) From 5701c79fab40927ae94b30fcdf79337bb0ad32e8 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 30 Jul 2024 23:04:15 +0200 Subject: [PATCH 06/28] Prevent Grounding DINO and Segment Anything from 
being moved to MPS - they don't work on MPS devices. --- invokeai/backend/grounded_sam/grounding_dino_pipeline.py | 4 ++++ invokeai/backend/grounded_sam/segment_anything_model.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/invokeai/backend/grounded_sam/grounding_dino_pipeline.py b/invokeai/backend/grounded_sam/grounding_dino_pipeline.py index 028aeb283d..1fc92b5e12 100644 --- a/invokeai/backend/grounded_sam/grounding_dino_pipeline.py +++ b/invokeai/backend/grounded_sam/grounding_dino_pipeline.py @@ -21,6 +21,10 @@ class GroundingDinoPipeline: return results def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline": + # HACK(ryand): The GroundingDinoPipeline does not work on MPS devices. We only allow it to be moved to CPU or + # CUDA. + if device is not None and device.type not in {"cpu", "cuda"}: + device = None self._pipeline.model.to(device=device, dtype=dtype) self._pipeline.device = self._pipeline.model.device return self diff --git a/invokeai/backend/grounded_sam/segment_anything_model.py b/invokeai/backend/grounded_sam/segment_anything_model.py index 86106b869c..1cc105c5fd 100644 --- a/invokeai/backend/grounded_sam/segment_anything_model.py +++ b/invokeai/backend/grounded_sam/segment_anything_model.py @@ -16,6 +16,9 @@ class SegmentAnythingModel: self._sam_processor = sam_processor def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "SegmentAnythingModel": + # HACK(ryand): The SAM pipeline does not work on MPS devices. We only allow it to be moved to CPU or CUDA. + if device is not None and device.type not in {"cpu", "cuda"}: + device = None self._sam_model.to(device=device, dtype=dtype) return self From 67c32f3d6c77b3d8b8d8dc8b471c28c4495062ec Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 31 Jul 2024 08:15:28 -0400 Subject: [PATCH 07/28] Fix typo: zip(..., strict=True) --- invokeai/app/invocations/grounded_sam.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/invocations/grounded_sam.py b/invokeai/app/invocations/grounded_sam.py index 411ec4a91f..74fc5c73e0 100644 --- a/invokeai/app/invocations/grounded_sam.py +++ b/invokeai/app/invocations/grounded_sam.py @@ -143,7 +143,7 @@ class GroundedSAMInvocation(BaseInvocation): masks = self._to_numpy_masks(masks) masks = self._apply_polygon_refinement(masks) - for detection_result, mask in zip(detection_results, masks, strict=False): + for detection_result, mask in zip(detection_results, masks, strict=True): detection_result.mask = mask return detection_results From bdae81e429c546343c4d4e08c5fc168b29e355f6 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 31 Jul 2024 08:25:19 -0400 Subject: [PATCH 08/28] (minor) Simplify GroundedSAMInvocation._filter_detections() --- invokeai/app/invocations/grounded_sam.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/invokeai/app/invocations/grounded_sam.py b/invokeai/app/invocations/grounded_sam.py index 74fc5c73e0..1c97cea17e 100644 --- a/invokeai/app/invocations/grounded_sam.py +++ b/invokeai/app/invocations/grounded_sam.py @@ -182,16 +182,10 @@ class GroundedSAMInvocation(BaseInvocation): return detections elif self.mask_filter == "largest": # Find the largest mask. 
- mask_areas = [detection.mask.sum() for detection in detections] - largest_mask_idx = mask_areas.index(max(mask_areas)) - return [detections[largest_mask_idx]] + return [max(detections, key=lambda x: x.mask.sum())] elif self.mask_filter == "highest_box_score": # Find the detection with the highest box score. - max_score_detection = detections[0] - for detection in detections: - if detection.score > max_score_detection.score: - max_score_detection = detection - return [max_score_detection] + return [max(detections, key=lambda x: x.score)] else: raise ValueError(f"Invalid mask filter: {self.mask_filter}") From cec739936696013a74179f8e7f53963e1d89eedd Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 31 Jul 2024 08:27:01 -0400 Subject: [PATCH 09/28] (minor) Use a new variable name to satisfy type checks. --- invokeai/app/invocations/grounded_sam.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/invokeai/app/invocations/grounded_sam.py b/invokeai/app/invocations/grounded_sam.py index 1c97cea17e..dbf6d740bc 100644 --- a/invokeai/app/invocations/grounded_sam.py +++ b/invokeai/app/invocations/grounded_sam.py @@ -154,9 +154,8 @@ class GroundedSAMInvocation(BaseInvocation): masks = masks.permute(0, 2, 3, 1) masks = masks.mean(dim=-1) masks = (masks > 0).int() - masks = masks.numpy().astype(np.uint8) - masks = list(masks) - return masks + np_masks = masks.numpy().astype(np.uint8) + return list(np_masks) def _apply_polygon_refinement(self, masks: list[npt.NDArray[np.uint8]]) -> list[npt.NDArray[np.uint8]]: """Apply polygon refinement to the masks. From 33e8604b5753cadd1b07ef8a6e995552986516f6 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 31 Jul 2024 08:47:00 -0400 Subject: [PATCH 10/28] Make Grounding DINO DetectionResult a Pydantic model. --- .../backend/grounded_sam/detection_result.py | 25 ++++++------------- .../grounded_sam/grounding_dino_pipeline.py | 2 +- 2 files changed, 8 insertions(+), 19 deletions(-) diff --git a/invokeai/backend/grounded_sam/detection_result.py b/invokeai/backend/grounded_sam/detection_result.py index 40e4254385..a9f2bdd65f 100644 --- a/invokeai/backend/grounded_sam/detection_result.py +++ b/invokeai/backend/grounded_sam/detection_result.py @@ -1,11 +1,10 @@ -from dataclasses import dataclass from typing import Any, Optional import numpy.typing as npt +from pydantic import BaseModel, ConfigDict -@dataclass -class BoundingBox: +class BoundingBox(BaseModel): """Bounding box helper class.""" xmin: int @@ -18,24 +17,14 @@ class BoundingBox: return [self.xmin, self.ymin, self.xmax, self.ymax] -@dataclass -class DetectionResult: +class DetectionResult(BaseModel): """Detection result from Grounding DINO or Grounded SAM.""" score: float label: str box: BoundingBox mask: Optional[npt.NDArray[Any]] = None - - @classmethod - def from_dict(cls, detection_dict: dict[str, Any]): - return cls( - score=detection_dict["score"], - label=detection_dict["label"], - box=BoundingBox( - xmin=detection_dict["box"]["xmin"], - ymin=detection_dict["box"]["ymin"], - xmax=detection_dict["box"]["xmax"], - ymax=detection_dict["box"]["ymax"], - ), - ) + model_config = ConfigDict( + # Allow arbitrary types for mask, since it will be a numpy array. 
+ arbitrary_types_allowed=True + ) diff --git a/invokeai/backend/grounded_sam/grounding_dino_pipeline.py b/invokeai/backend/grounded_sam/grounding_dino_pipeline.py index 1fc92b5e12..97c92f9249 100644 --- a/invokeai/backend/grounded_sam/grounding_dino_pipeline.py +++ b/invokeai/backend/grounded_sam/grounding_dino_pipeline.py @@ -17,7 +17,7 @@ class GroundingDinoPipeline: def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1) -> list[DetectionResult]: results = self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold) - results = [DetectionResult.from_dict(result) for result in results] + results = [DetectionResult.model_validate(result) for result in results] return results def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline": From e8ecf5e15554e1367e3ca972ad04700107e564e1 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 31 Jul 2024 08:50:56 -0400 Subject: [PATCH 11/28] (minor) Move apply_polygon_refinement condition up a layer. --- invokeai/app/invocations/grounded_sam.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/invokeai/app/invocations/grounded_sam.py b/invokeai/app/invocations/grounded_sam.py index dbf6d740bc..4b1181f59b 100644 --- a/invokeai/app/invocations/grounded_sam.py +++ b/invokeai/app/invocations/grounded_sam.py @@ -141,7 +141,8 @@ class GroundedSAMInvocation(BaseInvocation): masks = sam_pipeline.segment(image=image, detection_results=detection_results) masks = self._to_numpy_masks(masks) - masks = self._apply_polygon_refinement(masks) + if self.apply_polygon_refinement: + masks = self._apply_polygon_refinement(masks) for detection_result, mask in zip(detection_results, masks, strict=True): detection_result.mask = mask @@ -166,12 +167,12 @@ class GroundedSAMInvocation(BaseInvocation): - Removes small mask pieces. - Removes holes from the mask. """ - if self.apply_polygon_refinement: - for idx, mask in enumerate(masks): - shape = mask.shape - polygon = mask_to_polygon(mask) - mask = polygon_to_mask(polygon, shape) - masks[idx] = mask + for idx, mask in enumerate(masks): + shape = mask.shape + assert len(shape) == 2 # Assert length to satisfy type checker. + polygon = mask_to_polygon(mask) + mask = polygon_to_mask(polygon, shape) + masks[idx] = mask return masks From 0a7048f6503c1f1e1b4d641911824aec827466fb Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 31 Jul 2024 08:58:51 -0400 Subject: [PATCH 12/28] (minor) Simplify GroundedSAMInvocation._merge_masks(...). --- invokeai/app/invocations/grounded_sam.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/invokeai/app/invocations/grounded_sam.py b/invokeai/app/invocations/grounded_sam.py index 4b1181f59b..b23811afa6 100644 --- a/invokeai/app/invocations/grounded_sam.py +++ b/invokeai/app/invocations/grounded_sam.py @@ -74,7 +74,8 @@ class GroundedSAMInvocation(BaseInvocation): detections = self._filter_detections(detections) masks = [detection.mask for detection in detections] - combined_mask = self._merge_masks(masks) + # masks contains binary values of 0 or 1, so we merge them via max-reduce. + combined_mask = np.maximum.reduce(masks) # Map [0, 1] to [0, 255]. 
mask_np = combined_mask * 255 @@ -188,10 +189,3 @@ class GroundedSAMInvocation(BaseInvocation): return [max(detections, key=lambda x: x.score)] else: raise ValueError(f"Invalid mask filter: {self.mask_filter}") - - def _merge_masks(self, masks: list[npt.NDArray[np.uint8]]) -> npt.NDArray[np.uint8]: - """Merge multiple masks into a single mask.""" - # Merge all masks together. - stacked_mask = np.stack(masks, axis=0) - combined_mask = np.max(stacked_mask, axis=0) - return combined_mask From e206890e25f52f165aaf76c3120dbbc61ea47cc1 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 31 Jul 2024 09:28:52 -0400 Subject: [PATCH 13/28] Use staticmethods rather than inner functions for the Grounding DINO and SAM model loaders. --- invokeai/app/invocations/grounded_sam.py | 64 +++++++++++++----------- 1 file changed, 34 insertions(+), 30 deletions(-) diff --git a/invokeai/app/invocations/grounded_sam.py b/invokeai/app/invocations/grounded_sam.py index b23811afa6..cac338c160 100644 --- a/invokeai/app/invocations/grounded_sam.py +++ b/invokeai/app/invocations/grounded_sam.py @@ -84,6 +84,34 @@ class GroundedSAMInvocation(BaseInvocation): image_dto = context.images.save(image=mask_pil) return ImageOutput.build(image_dto) + @staticmethod + def _load_grounding_dino(model_path: Path): + grounding_dino_pipeline = pipeline( + model=str(model_path), + task="zero-shot-object-detection", + local_files_only=True, + # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the + # model, and figure out how to make it work in the pipeline. + # torch_dtype=TorchDevice.choose_torch_dtype(), + ) + assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline) + return GroundingDinoPipeline(grounding_dino_pipeline) + + @staticmethod + def _load_sam_model(model_path: Path): + sam_model = AutoModelForMaskGeneration.from_pretrained( + model_path, + local_files_only=True, + # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the + # model, and figure out how to make it work in the pipeline. + # torch_dtype=TorchDevice.choose_torch_dtype(), + ) + assert isinstance(sam_model, SamModel) + + sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True) + assert isinstance(sam_processor, SamProcessor) + return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor) + def _detect( self, context: InvocationContext, @@ -96,19 +124,9 @@ class GroundedSAMInvocation(BaseInvocation): # actually makes a difference. labels = [label if label.endswith(".") else label + "." for label in labels] - def load_grounding_dino(model_path: Path): - grounding_dino_pipeline = pipeline( - model=str(model_path), - task="zero-shot-object-detection", - local_files_only=True, - # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the - # model, and figure out how to make it work in the pipeline. 
- # torch_dtype=TorchDevice.choose_torch_dtype(), - ) - assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline) - return GroundingDinoPipeline(grounding_dino_pipeline) - - with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector: + with context.models.load_remote_model( + source=GROUNDING_DINO_MODEL_ID, loader=GroundedSAMInvocation._load_grounding_dino + ) as detector: assert isinstance(detector, GroundingDinoPipeline) return detector.detect(image=image, candidate_labels=labels, threshold=threshold) @@ -119,26 +137,12 @@ class GroundedSAMInvocation(BaseInvocation): detection_results: list[DetectionResult], ) -> list[DetectionResult]: """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.""" - - def load_sam_model(model_path: Path): - sam_model = AutoModelForMaskGeneration.from_pretrained( - model_path, - local_files_only=True, - # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the - # model, and figure out how to make it work in the pipeline. - # torch_dtype=TorchDevice.choose_torch_dtype(), - ) - assert isinstance(sam_model, SamModel) - - sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True) - assert isinstance(sam_processor, SamProcessor) - return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor) - with ( - context.models.load_remote_model(source=SEGMENT_ANYTHING_MODEL_ID, loader=load_sam_model) as sam_pipeline, + context.models.load_remote_model( + source=SEGMENT_ANYTHING_MODEL_ID, loader=GroundedSAMInvocation._load_sam_model + ) as sam_pipeline, ): assert isinstance(sam_pipeline, SegmentAnythingModel) - masks = sam_pipeline.segment(image=image, detection_results=detection_results) masks = self._to_numpy_masks(masks) From bcd1483a14e2af4b5cede1c65318628388d46cb8 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 31 Jul 2024 09:51:14 -0400 Subject: [PATCH 14/28] Re-order GroundedSAMInvocation._to_numpy_masks(...) to do slightly more work on the GPU. 
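For context, the complete mask post-processing path at this point in the series is condensed below into a runnable sketch: threshold the SAM mask tensor into binary numpy masks (doing the reduction on the tensor before moving to CPU), optionally round-trip each mask through its largest polygon, then merge. `sam_masks` is a placeholder for the tensor returned by SegmentAnythingModel.segment(); the random values stand in for real model output.

# Condensed, illustrative sketch of the mask post-processing path after this change.
# `sam_masks` is a placeholder for the [num_masks, channels, H, W] tensor returned by
# SegmentAnythingModel.segment(); random values stand in for real model output.
import cv2
import numpy as np
import torch

sam_masks = torch.rand(2, 3, 64, 64) - 0.5  # placeholder for real SAM output

# Tensor -> binary uint8 masks. Averaging and thresholding happen on the tensor
# (GPU if available); only the small binary result is moved to the CPU.
eps = 0.0001
binary = sam_masks.permute(0, 2, 3, 1).float().mean(dim=-1) > eps
np_masks = list(binary.cpu().numpy().astype(np.uint8))

# Optional polygon refinement: keep the largest contour of each mask and re-fill it,
# which smooths the edges and drops small pieces and holes.
refined = []
for mask in np_masks:
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    polygon = max(contours, key=cv2.contourArea).reshape(-1, 2)
    filled = np.zeros_like(mask)
    cv2.fillPoly(filled, [polygon.astype(np.int32)], color=(1,))
    refined.append(filled)

# The binary masks (values 0/1) merge via max-reduce, then map to [0, 255] for output.
combined = np.maximum.reduce(refined) * 255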
--- invokeai/app/invocations/grounded_sam.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/invokeai/app/invocations/grounded_sam.py b/invokeai/app/invocations/grounded_sam.py index cac338c160..3a1cb9be56 100644 --- a/invokeai/app/invocations/grounded_sam.py +++ b/invokeai/app/invocations/grounded_sam.py @@ -156,11 +156,11 @@ class GroundedSAMInvocation(BaseInvocation): def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]: """Convert the tensor output from the Segment Anything model to a list of numpy masks.""" - masks = masks.cpu().float() - masks = masks.permute(0, 2, 3, 1) - masks = masks.mean(dim=-1) - masks = (masks > 0).int() - np_masks = masks.numpy().astype(np.uint8) + eps = 0.0001 + # [num_masks, channels, height, width] -> [num_masks, height, width] + masks = masks.permute(0, 2, 3, 1).float().mean(dim=-1) + masks = masks > eps + np_masks = masks.cpu().numpy().astype(np.uint8) return list(np_masks) def _apply_polygon_refinement(self, masks: list[npt.NDArray[np.uint8]]) -> list[npt.NDArray[np.uint8]]: From 9f448fecb7cc8bb07a7741ee898294661fc92cc9 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 31 Jul 2024 10:00:30 -0400 Subject: [PATCH 15/28] Move invokeai/backend/grounded_sam -> invokeai/backend/image_util/grounded_sam --- invokeai/app/invocations/grounded_sam.py | 8 ++++---- .../backend/{ => image_util}/grounded_sam/__init__.py | 0 .../{ => image_util}/grounded_sam/detection_result.py | 0 .../grounded_sam/grounding_dino_pipeline.py | 2 +- .../{ => image_util}/grounded_sam/mask_refinement.py | 0 .../grounded_sam/segment_anything_model.py | 2 +- invokeai/backend/model_manager/load/model_util.py | 4 ++-- 7 files changed, 8 insertions(+), 8 deletions(-) rename invokeai/backend/{ => image_util}/grounded_sam/__init__.py (100%) rename invokeai/backend/{ => image_util}/grounded_sam/detection_result.py (100%) rename invokeai/backend/{ => image_util}/grounded_sam/grounding_dino_pipeline.py (94%) rename invokeai/backend/{ => image_util}/grounded_sam/mask_refinement.py (100%) rename invokeai/backend/{ => image_util}/grounded_sam/segment_anything_model.py (95%) diff --git a/invokeai/app/invocations/grounded_sam.py b/invokeai/app/invocations/grounded_sam.py index 3a1cb9be56..d741c3d6ad 100644 --- a/invokeai/app/invocations/grounded_sam.py +++ b/invokeai/app/invocations/grounded_sam.py @@ -14,10 +14,10 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.fields import ImageField, InputField from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.grounded_sam.detection_result import DetectionResult -from invokeai.backend.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline -from invokeai.backend.grounded_sam.mask_refinement import mask_to_polygon, polygon_to_mask -from invokeai.backend.grounded_sam.segment_anything_model import SegmentAnythingModel +from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult +from invokeai.backend.image_util.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline +from invokeai.backend.image_util.grounded_sam.mask_refinement import mask_to_polygon, polygon_to_mask +from invokeai.backend.image_util.grounded_sam.segment_anything_model import SegmentAnythingModel GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny" SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base" diff --git 
a/invokeai/backend/grounded_sam/__init__.py b/invokeai/backend/image_util/grounded_sam/__init__.py similarity index 100% rename from invokeai/backend/grounded_sam/__init__.py rename to invokeai/backend/image_util/grounded_sam/__init__.py diff --git a/invokeai/backend/grounded_sam/detection_result.py b/invokeai/backend/image_util/grounded_sam/detection_result.py similarity index 100% rename from invokeai/backend/grounded_sam/detection_result.py rename to invokeai/backend/image_util/grounded_sam/detection_result.py diff --git a/invokeai/backend/grounded_sam/grounding_dino_pipeline.py b/invokeai/backend/image_util/grounded_sam/grounding_dino_pipeline.py similarity index 94% rename from invokeai/backend/grounded_sam/grounding_dino_pipeline.py rename to invokeai/backend/image_util/grounded_sam/grounding_dino_pipeline.py index 97c92f9249..c03a3d8cab 100644 --- a/invokeai/backend/grounded_sam/grounding_dino_pipeline.py +++ b/invokeai/backend/image_util/grounded_sam/grounding_dino_pipeline.py @@ -4,7 +4,7 @@ import torch from PIL import Image from transformers.pipelines import ZeroShotObjectDetectionPipeline -from invokeai.backend.grounded_sam.detection_result import DetectionResult +from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult class GroundingDinoPipeline: diff --git a/invokeai/backend/grounded_sam/mask_refinement.py b/invokeai/backend/image_util/grounded_sam/mask_refinement.py similarity index 100% rename from invokeai/backend/grounded_sam/mask_refinement.py rename to invokeai/backend/image_util/grounded_sam/mask_refinement.py diff --git a/invokeai/backend/grounded_sam/segment_anything_model.py b/invokeai/backend/image_util/grounded_sam/segment_anything_model.py similarity index 95% rename from invokeai/backend/grounded_sam/segment_anything_model.py rename to invokeai/backend/image_util/grounded_sam/segment_anything_model.py index 1cc105c5fd..c9959424f2 100644 --- a/invokeai/backend/grounded_sam/segment_anything_model.py +++ b/invokeai/backend/image_util/grounded_sam/segment_anything_model.py @@ -5,7 +5,7 @@ from PIL import Image from transformers.models.sam import SamModel from transformers.models.sam.processing_sam import SamProcessor -from invokeai.backend.grounded_sam.detection_result import DetectionResult +from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult class SegmentAnythingModel: diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py index 22d493f7a0..351331176c 100644 --- a/invokeai/backend/model_manager/load/model_util.py +++ b/invokeai/backend/model_manager/load/model_util.py @@ -11,8 +11,8 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.schedulers.scheduling_utils import SchedulerMixin from transformers import CLIPTokenizer -from invokeai.backend.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline -from invokeai.backend.grounded_sam.segment_anything_model import SegmentAnythingModel +from invokeai.backend.image_util.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline +from invokeai.backend.image_util.grounded_sam.segment_anything_model import SegmentAnythingModel from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager.config import AnyModel From 73386826d6bd434742d994158653b4b7886dfa11 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 31 Jul 2024 10:25:34 -0400 Subject: [PATCH 16/28] Make 
GroundingDinoPipeline and SegmentAnythingModel subclasses of RawModel for type checking purposes. --- .../image_util/grounded_sam/grounding_dino_pipeline.py | 6 +++--- .../image_util/grounded_sam/segment_anything_model.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/invokeai/backend/image_util/grounded_sam/grounding_dino_pipeline.py b/invokeai/backend/image_util/grounded_sam/grounding_dino_pipeline.py index c03a3d8cab..d998eb1d57 100644 --- a/invokeai/backend/image_util/grounded_sam/grounding_dino_pipeline.py +++ b/invokeai/backend/image_util/grounded_sam/grounding_dino_pipeline.py @@ -5,9 +5,10 @@ from PIL import Image from transformers.pipelines import ZeroShotObjectDetectionPipeline from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult +from invokeai.backend.raw_model import RawModel -class GroundingDinoPipeline: +class GroundingDinoPipeline(RawModel): """A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory management system. """ @@ -20,14 +21,13 @@ class GroundingDinoPipeline: results = [DetectionResult.model_validate(result) for result in results] return results - def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline": + def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): # HACK(ryand): The GroundingDinoPipeline does not work on MPS devices. We only allow it to be moved to CPU or # CUDA. if device is not None and device.type not in {"cpu", "cuda"}: device = None self._pipeline.model.to(device=device, dtype=dtype) self._pipeline.device = self._pipeline.model.device - return self def calc_size(self) -> int: # HACK(ryand): Fix the circular import issue. diff --git a/invokeai/backend/image_util/grounded_sam/segment_anything_model.py b/invokeai/backend/image_util/grounded_sam/segment_anything_model.py index c9959424f2..5147141303 100644 --- a/invokeai/backend/image_util/grounded_sam/segment_anything_model.py +++ b/invokeai/backend/image_util/grounded_sam/segment_anything_model.py @@ -6,21 +6,21 @@ from transformers.models.sam import SamModel from transformers.models.sam.processing_sam import SamProcessor from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult +from invokeai.backend.raw_model import RawModel -class SegmentAnythingModel: +class SegmentAnythingModel(RawModel): """A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager.""" def __init__(self, sam_model: SamModel, sam_processor: SamProcessor): self._sam_model = sam_model self._sam_processor = sam_processor - def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "SegmentAnythingModel": + def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): # HACK(ryand): The SAM pipeline does not work on MPS devices. We only allow it to be moved to CPU or CUDA. if device is not None and device.type not in {"cpu", "cuda"}: device = None self._sam_model.to(device=device, dtype=dtype) - return self def calc_size(self) -> int: # HACK(ryand): Fix the circular import issue. From 0193267a53b8020626fb6908a63c1e12e1cb2d17 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 31 Jul 2024 12:20:23 -0400 Subject: [PATCH 17/28] Split GroundedSamInvocation into GroundingDinoInvocation and SegmentAnythingModelInvocation. 
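The hand-off between the two new invocations is carried by the BoundingBoxField primitive added to fields.py below. A minimal illustrative sketch of that data path follows; the class body mirrors the fields.py addition and the detection values are made up.

# Minimal illustrative sketch of the data hand-off enabled by this split. The class
# mirrors the BoundingBoxField added to fields.py below; the detection values are made up.
from typing import Optional

from pydantic import BaseModel, Field


class BoundingBoxField(BaseModel):
    """A bounding box primitive value (as added to fields.py)."""

    x_min: int = Field(ge=0)
    x_max: int = Field(ge=0)
    y_min: int = Field(ge=0)
    y_max: int = Field(ge=0)
    score: Optional[float] = Field(default=None, ge=0.0, le=1.0)


# Detection side: one zero-shot-object-detection result dict -> BoundingBoxField.
detection = {"score": 0.42, "box": {"xmin": 10, "ymin": 20, "xmax": 110, "ymax": 220}}
bbox = BoundingBoxField(
    x_min=detection["box"]["xmin"],
    x_max=detection["box"]["xmax"],
    y_min=detection["box"]["ymin"],
    y_max=detection["box"]["ymax"],
    score=detection["score"],
)

# Segmentation side: back to the nested [batch, num_boxes, 4] list the SAM processor
# expects as `input_boxes`.
input_boxes = [[[bbox.x_min, bbox.y_min, bbox.x_max, bbox.y_max]]]

With the boxes exposed as a first-class primitive, the segmentation step no longer has to be paired with Grounding DINO as its only box source.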
--- invokeai/app/invocations/fields.py | 17 +++ invokeai/app/invocations/grounding_dino.py | 95 ++++++++++++++++ invokeai/app/invocations/primitives.py | 22 ++++ ...unded_sam.py => segment_anything_model.py} | 105 ++++++------------ .../grounded_sam/detection_result.py | 10 +- .../grounded_sam/segment_anything_model.py | 24 ++-- 6 files changed, 180 insertions(+), 93 deletions(-) create mode 100644 invokeai/app/invocations/grounding_dino.py rename invokeai/app/invocations/{grounded_sam.py => segment_anything_model.py} (56%) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index f9a483f84c..fe97123fef 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -242,6 +242,23 @@ class ConditioningField(BaseModel): ) +class BoundingBoxField(BaseModel): + """A bounding box primitive value.""" + + x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).") + x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).") + y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).") + y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).") + + score: Optional[float] = Field( + default=None, + ge=0.0, + le=1.0, + description="The score associated with the bounding box. In the range [0, 1]. This value is typically set " + "when the bounding box was produced by a detector and has an associated confidence score.", + ) + + class MetadataField(RootModel[dict[str, Any]]): """ Pydantic model for metadata with custom root of type dict[str, Any]. diff --git a/invokeai/app/invocations/grounding_dino.py b/invokeai/app/invocations/grounding_dino.py new file mode 100644 index 0000000000..ade5b1456e --- /dev/null +++ b/invokeai/app/invocations/grounding_dino.py @@ -0,0 +1,95 @@ +from pathlib import Path + +import torch +from PIL import Image +from transformers import pipeline +from transformers.pipelines import ZeroShotObjectDetectionPipeline + +from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation +from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField +from invokeai.app.invocations.primitives import BoundingBoxCollectionOutput +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult +from invokeai.backend.image_util.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline + +GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny" + + +@invocation( + "grounding_dino", + title="Grounding DINO (Text Prompt Object Detection)", + tags=["prompt", "object detection"], + category="image", + version="1.0.0", +) +class GroundingDinoInvocation(BaseInvocation): + """Runs a Grounding DINO model (https://arxiv.org/pdf/2303.05499). Performs zero-shot bounding-box object detection + from a text prompt. 
+ + Reference: + - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam + - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb + """ + + prompt: str = InputField(description="The prompt describing the object to segment.") + image: ImageField = InputField(description="The image to segment.") + detection_threshold: float = InputField( + description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be returned.", + ge=0.0, + le=1.0, + default=0.3, + ) + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> BoundingBoxCollectionOutput: + # The model expects a 3-channel RGB image. + image_pil = context.images.get_pil(self.image.image_name, mode="RGB") + + detections = self._detect( + context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold + ) + + # Convert detections to BoundingBoxCollectionOutput. + bounding_boxes: list[BoundingBoxField] = [] + for detection in detections: + bounding_boxes.append( + BoundingBoxField( + x_min=detection.box.xmin, + x_max=detection.box.xmax, + y_min=detection.box.ymin, + y_max=detection.box.ymax, + score=detection.score, + ) + ) + return BoundingBoxCollectionOutput(collection=bounding_boxes) + + @staticmethod + def _load_grounding_dino(model_path: Path): + grounding_dino_pipeline = pipeline( + model=str(model_path), + task="zero-shot-object-detection", + local_files_only=True, + # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the + # model, and figure out how to make it work in the pipeline. + # torch_dtype=TorchDevice.choose_torch_dtype(), + ) + assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline) + return GroundingDinoPipeline(grounding_dino_pipeline) + + def _detect( + self, + context: InvocationContext, + image: Image.Image, + labels: list[str], + threshold: float = 0.3, + ) -> list[DetectionResult]: + """Use Grounding DINO to detect bounding boxes for a set of labels in an image.""" + # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it + # actually makes a difference. + labels = [label if label.endswith(".") else label + "." 
for label in labels] + + with context.models.load_remote_model( + source=GROUNDING_DINO_MODEL_ID, loader=GroundingDinoInvocation._load_grounding_dino + ) as detector: + assert isinstance(detector, GroundingDinoPipeline) + return detector.detect(image=image, candidate_labels=labels, threshold=threshold) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index e5056e3775..4e1e00bfd8 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -7,6 +7,7 @@ import torch from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.app.invocations.fields import ( + BoundingBoxField, ColorField, ConditioningField, DenoiseMaskField, @@ -469,3 +470,24 @@ class ConditioningCollectionInvocation(BaseInvocation): # endregion + +# region BoundingBox + + +@invocation_output("bounding_box_output") +class BoundingBoxOutput(BaseInvocationOutput): + """Base class for nodes that output a single bounding box""" + + bounding_box: BoundingBoxField = OutputField(description="The output bounding box.") + + +@invocation_output("bounding_box_collection_output") +class BoundingBoxCollectionOutput(BaseInvocationOutput): + """Base class for nodes that output a collection of bounding boxes""" + + collection: list[BoundingBoxField] = OutputField( + description="The output bounding boxes.", + ) + + +# endregion diff --git a/invokeai/app/invocations/grounded_sam.py b/invokeai/app/invocations/segment_anything_model.py similarity index 56% rename from invokeai/app/invocations/grounded_sam.py rename to invokeai/app/invocations/segment_anything_model.py index d741c3d6ad..39e2ee6d7f 100644 --- a/invokeai/app/invocations/grounded_sam.py +++ b/invokeai/app/invocations/segment_anything_model.py @@ -5,75 +5,56 @@ import numpy as np import numpy.typing as npt import torch from PIL import Image -from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline +from transformers import AutoModelForMaskGeneration, AutoProcessor from transformers.models.sam import SamModel from transformers.models.sam.processing_sam import SamProcessor -from transformers.pipelines import ZeroShotObjectDetectionPipeline from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation -from invokeai.app.invocations.fields import ImageField, InputField +from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult -from invokeai.backend.image_util.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline from invokeai.backend.image_util.grounded_sam.mask_refinement import mask_to_polygon, polygon_to_mask from invokeai.backend.image_util.grounded_sam.segment_anything_model import SegmentAnythingModel -GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny" SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base" @invocation( - "grounded_segment_anything", - title="Segment Anything (Text Prompt)", + "segment_anything_model", + title="Segment Anything Model", tags=["prompt", "segmentation"], category="segmentation", version="1.0.0", ) -class GroundedSAMInvocation(BaseInvocation): - """Runs Grounded-SAM, as proposed in https://arxiv.org/pdf/2401.14159. 
- - More specifically, a Grounding DINO model is run to obtain bounding boxes for a text prompt, then the bounding boxes - are passed as a prompt to a Segment Anything model to obtain a segmentation mask. +class SegmentAnythingModelInvocation(BaseInvocation): + """Runs a Segment Anything Model (https://arxiv.org/pdf/2304.02643). Reference: - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb """ - prompt: str = InputField(description="The prompt describing the object to segment.") image: ImageField = InputField(description="The image to segment.") + bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.") apply_polygon_refinement: bool = InputField( - description="Whether to apply polygon refinement to the masks. This will smooth the edges of the mask slightly and ensure that each mask consists of a single closed polygon (before merging).", + description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).", default=True, ) mask_filter: Literal["all", "largest", "highest_box_score"] = InputField( description="The filtering to apply to the detected masks before merging them into a final output.", default="all", ) - detection_threshold: float = InputField( - description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be used.", - ge=0.0, - le=1.0, - default=0.3, - ) @torch.no_grad() def invoke(self, context: InvocationContext) -> ImageOutput: # The models expect a 3-channel RGB image. image_pil = context.images.get_pil(self.image.image_name, mode="RGB") - detections = self._detect( - context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold - ) - - if len(detections) == 0: + if len(self.bounding_boxes) == 0: combined_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8) else: - detections = self._segment(context=context, image=image_pil, detection_results=detections) - - detections = self._filter_detections(detections) - masks = [detection.mask for detection in detections] + masks = self._segment(context=context, image=image_pil) + masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes) # masks contains binary values of 0 or 1, so we merge them via max-reduce. combined_mask = np.maximum.reduce(masks) @@ -84,19 +65,6 @@ class GroundedSAMInvocation(BaseInvocation): image_dto = context.images.save(image=mask_pil) return ImageOutput.build(image_dto) - @staticmethod - def _load_grounding_dino(model_path: Path): - grounding_dino_pipeline = pipeline( - model=str(model_path), - task="zero-shot-object-detection", - local_files_only=True, - # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the - # model, and figure out how to make it work in the pipeline. 
- # torch_dtype=TorchDevice.choose_torch_dtype(), - ) - assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline) - return GroundingDinoPipeline(grounding_dino_pipeline) - @staticmethod def _load_sam_model(model_path: Path): sam_model = AutoModelForMaskGeneration.from_pretrained( @@ -112,47 +80,28 @@ class GroundedSAMInvocation(BaseInvocation): assert isinstance(sam_processor, SamProcessor) return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor) - def _detect( - self, - context: InvocationContext, - image: Image.Image, - labels: list[str], - threshold: float = 0.3, - ) -> list[DetectionResult]: - """Use Grounding DINO to detect bounding boxes for a set of labels in an image.""" - # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it - # actually makes a difference. - labels = [label if label.endswith(".") else label + "." for label in labels] - - with context.models.load_remote_model( - source=GROUNDING_DINO_MODEL_ID, loader=GroundedSAMInvocation._load_grounding_dino - ) as detector: - assert isinstance(detector, GroundingDinoPipeline) - return detector.detect(image=image, candidate_labels=labels, threshold=threshold) - def _segment( self, context: InvocationContext, image: Image.Image, - detection_results: list[DetectionResult], - ) -> list[DetectionResult]: + ) -> list[npt.NDArray[np.uint8]]: """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.""" + # Convert the bounding boxes to the SAM input format. + sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes] + with ( context.models.load_remote_model( - source=SEGMENT_ANYTHING_MODEL_ID, loader=GroundedSAMInvocation._load_sam_model + source=SEGMENT_ANYTHING_MODEL_ID, loader=SegmentAnythingModelInvocation._load_sam_model ) as sam_pipeline, ): assert isinstance(sam_pipeline, SegmentAnythingModel) - masks = sam_pipeline.segment(image=image, detection_results=detection_results) + masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes) masks = self._to_numpy_masks(masks) if self.apply_polygon_refinement: masks = self._apply_polygon_refinement(masks) - for detection_result, mask in zip(detection_results, masks, strict=True): - detection_result.mask = mask - - return detection_results + return masks def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]: """Convert the tensor output from the Segment Anything model to a list of numpy masks.""" @@ -181,15 +130,23 @@ class GroundedSAMInvocation(BaseInvocation): return masks - def _filter_detections(self, detections: list[DetectionResult]) -> list[DetectionResult]: + def _filter_masks( + self, masks: list[npt.NDArray[np.uint8]], bounding_boxes: list[BoundingBoxField] + ) -> list[npt.NDArray[np.uint8]]: """Filter the detected masks based on the specified mask filter.""" + assert len(masks) == len(bounding_boxes) + if self.mask_filter == "all": - return detections + return masks elif self.mask_filter == "largest": # Find the largest mask. - return [max(detections, key=lambda x: x.mask.sum())] + return [max(masks, key=lambda x: x.sum())] elif self.mask_filter == "highest_box_score": - # Find the detection with the highest box score. - return [max(detections, key=lambda x: x.score)] + # Find the index of the bounding box with the highest score. + # Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. 
In most + # cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a + # reasonable fallback since the expected score range is [0.0, 1.0]. + max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0) + return [masks[max_score_idx]] else: raise ValueError(f"Invalid mask filter: {self.mask_filter}") diff --git a/invokeai/backend/image_util/grounded_sam/detection_result.py b/invokeai/backend/image_util/grounded_sam/detection_result.py index a9f2bdd65f..2d0c78e681 100644 --- a/invokeai/backend/image_util/grounded_sam/detection_result.py +++ b/invokeai/backend/image_util/grounded_sam/detection_result.py @@ -1,6 +1,3 @@ -from typing import Any, Optional - -import numpy.typing as npt from pydantic import BaseModel, ConfigDict @@ -12,18 +9,13 @@ class BoundingBox(BaseModel): xmax: int ymax: int - def to_box(self) -> list[int]: - """Convert to the array notation expected by SAM.""" - return [self.xmin, self.ymin, self.xmax, self.ymax] - class DetectionResult(BaseModel): - """Detection result from Grounding DINO or Grounded SAM.""" + """Detection result from Grounding DINO.""" score: float label: str box: BoundingBox - mask: Optional[npt.NDArray[Any]] = None model_config = ConfigDict( # Allow arbitrary types for mask, since it will be a numpy array. arbitrary_types_allowed=True diff --git a/invokeai/backend/image_util/grounded_sam/segment_anything_model.py b/invokeai/backend/image_util/grounded_sam/segment_anything_model.py index 5147141303..d358db5c00 100644 --- a/invokeai/backend/image_util/grounded_sam/segment_anything_model.py +++ b/invokeai/backend/image_util/grounded_sam/segment_anything_model.py @@ -5,7 +5,6 @@ from PIL import Image from transformers.models.sam import SamModel from transformers.models.sam.processing_sam import SamProcessor -from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult from invokeai.backend.raw_model import RawModel @@ -28,8 +27,19 @@ class SegmentAnythingModel(RawModel): return calc_module_size(self._sam_model) - def segment(self, image: Image.Image, detection_results: list[DetectionResult]) -> torch.Tensor: - boxes = self._to_box_array(detection_results) + def segment(self, image: Image.Image, bounding_boxes: list[list[int]]) -> torch.Tensor: + """Run the SAM model. + + Args: + image (Image.Image): The image to segment. + bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format + [xmin, ymin, xmax, ymax]. + + Returns: + torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width]. + """ + # Add batch dimension of 1 to the bounding boxes. + boxes = [bounding_boxes] inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device) outputs = self._sam_model(**inputs) masks = self._sam_processor.post_process_masks( @@ -40,10 +50,4 @@ class SegmentAnythingModel(RawModel): # There should be only one batch. 
assert len(masks) == 1 - masks = masks[0] - return masks - - def _to_box_array(self, detection_results: list[DetectionResult]) -> list[list[list[int]]]: - """Convert a list of DetectionResults to the bbox format expected by the Segment Anything model.""" - boxes = [result.box.to_box() for result in detection_results] - return [boxes] + return masks[0] From fca119773b9ea2552debc1e41f50c1ba6ed5c0d9 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 31 Jul 2024 12:28:47 -0400 Subject: [PATCH 18/28] Split invokeai/backend/image_util/segment_anything/ dir into grounding_dino/ and segment_anything/ --- invokeai/app/invocations/grounding_dino.py | 4 ++-- invokeai/app/invocations/segment_anything_model.py | 4 ++-- .../image_util/{grounded_sam => grounding_dino}/__init__.py | 0 .../{grounded_sam => grounding_dino}/detection_result.py | 0 .../grounding_dino_pipeline.py | 2 +- invokeai/backend/image_util/segment_anything/__init__.py | 0 .../{grounded_sam => segment_anything}/mask_refinement.py | 0 .../segment_anything_model.py | 0 invokeai/backend/model_manager/load/model_util.py | 4 ++-- 9 files changed, 7 insertions(+), 7 deletions(-) rename invokeai/backend/image_util/{grounded_sam => grounding_dino}/__init__.py (100%) rename invokeai/backend/image_util/{grounded_sam => grounding_dino}/detection_result.py (100%) rename invokeai/backend/image_util/{grounded_sam => grounding_dino}/grounding_dino_pipeline.py (94%) create mode 100644 invokeai/backend/image_util/segment_anything/__init__.py rename invokeai/backend/image_util/{grounded_sam => segment_anything}/mask_refinement.py (100%) rename invokeai/backend/image_util/{grounded_sam => segment_anything}/segment_anything_model.py (100%) diff --git a/invokeai/app/invocations/grounding_dino.py b/invokeai/app/invocations/grounding_dino.py index ade5b1456e..779c946d40 100644 --- a/invokeai/app/invocations/grounding_dino.py +++ b/invokeai/app/invocations/grounding_dino.py @@ -9,8 +9,8 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField from invokeai.app.invocations.primitives import BoundingBoxCollectionOutput from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult -from invokeai.backend.image_util.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline +from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult +from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny" diff --git a/invokeai/app/invocations/segment_anything_model.py b/invokeai/app/invocations/segment_anything_model.py index 39e2ee6d7f..ad264c9584 100644 --- a/invokeai/app/invocations/segment_anything_model.py +++ b/invokeai/app/invocations/segment_anything_model.py @@ -13,8 +13,8 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.image_util.grounded_sam.mask_refinement import mask_to_polygon, polygon_to_mask -from invokeai.backend.image_util.grounded_sam.segment_anything_model import SegmentAnythingModel +from 
invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask +from invokeai.backend.image_util.segment_anything.segment_anything_model import SegmentAnythingModel SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base" diff --git a/invokeai/backend/image_util/grounded_sam/__init__.py b/invokeai/backend/image_util/grounding_dino/__init__.py similarity index 100% rename from invokeai/backend/image_util/grounded_sam/__init__.py rename to invokeai/backend/image_util/grounding_dino/__init__.py diff --git a/invokeai/backend/image_util/grounded_sam/detection_result.py b/invokeai/backend/image_util/grounding_dino/detection_result.py similarity index 100% rename from invokeai/backend/image_util/grounded_sam/detection_result.py rename to invokeai/backend/image_util/grounding_dino/detection_result.py diff --git a/invokeai/backend/image_util/grounded_sam/grounding_dino_pipeline.py b/invokeai/backend/image_util/grounding_dino/grounding_dino_pipeline.py similarity index 94% rename from invokeai/backend/image_util/grounded_sam/grounding_dino_pipeline.py rename to invokeai/backend/image_util/grounding_dino/grounding_dino_pipeline.py index d998eb1d57..68c2b8decc 100644 --- a/invokeai/backend/image_util/grounded_sam/grounding_dino_pipeline.py +++ b/invokeai/backend/image_util/grounding_dino/grounding_dino_pipeline.py @@ -4,7 +4,7 @@ import torch from PIL import Image from transformers.pipelines import ZeroShotObjectDetectionPipeline -from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult +from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult from invokeai.backend.raw_model import RawModel diff --git a/invokeai/backend/image_util/segment_anything/__init__.py b/invokeai/backend/image_util/segment_anything/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/invokeai/backend/image_util/grounded_sam/mask_refinement.py b/invokeai/backend/image_util/segment_anything/mask_refinement.py similarity index 100% rename from invokeai/backend/image_util/grounded_sam/mask_refinement.py rename to invokeai/backend/image_util/segment_anything/mask_refinement.py diff --git a/invokeai/backend/image_util/grounded_sam/segment_anything_model.py b/invokeai/backend/image_util/segment_anything/segment_anything_model.py similarity index 100% rename from invokeai/backend/image_util/grounded_sam/segment_anything_model.py rename to invokeai/backend/image_util/segment_anything/segment_anything_model.py diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py index 351331176c..9c2570978a 100644 --- a/invokeai/backend/model_manager/load/model_util.py +++ b/invokeai/backend/model_manager/load/model_util.py @@ -11,8 +11,8 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.schedulers.scheduling_utils import SchedulerMixin from transformers import CLIPTokenizer -from invokeai.backend.image_util.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline -from invokeai.backend.image_util.grounded_sam.segment_anything_model import SegmentAnythingModel +from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline +from invokeai.backend.image_util.segment_anything.segment_anything_model import SegmentAnythingModel from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager.config import AnyModel From 
b5832768dc84702b95bcefaa247b6382750b22f2 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 31 Jul 2024 17:15:48 -0400 Subject: [PATCH 19/28] Return a MaskOutput from SegmentAnythingModelInvocation. And add a MaskTensorToImageInvocation. --- invokeai/app/invocations/mask.py | 30 +++++++++- .../app/invocations/segment_anything_model.py | 59 ++++++++++--------- 2 files changed, 59 insertions(+), 30 deletions(-) diff --git a/invokeai/app/invocations/mask.py b/invokeai/app/invocations/mask.py index 6f54660847..2ebefeacff 100644 --- a/invokeai/app/invocations/mask.py +++ b/invokeai/app/invocations/mask.py @@ -1,9 +1,10 @@ import numpy as np import torch +from PIL import Image from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation -from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata -from invokeai.app.invocations.primitives import MaskOutput +from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata +from invokeai.app.invocations.primitives import ImageOutput, MaskOutput @invocation( @@ -118,3 +119,28 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata): height=mask.shape[1], width=mask.shape[2], ) + + +@invocation( + "tensor_mask_to_image", + title="Tensor Mask to Image", + tags=["mask"], + category="mask", + version="1.0.0", +) +class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard): + """Convert a mask tensor to an image.""" + + mask: TensorField = InputField(description="The mask tensor to convert.") + + def invoke(self, context: InvocationContext) -> ImageOutput: + mask = context.tensors.load(self.mask.tensor_name) + # Ensure that the mask is binary. + if mask.dtype != torch.bool: + mask = mask > 0.5 + mask_np = mask.float().cpu().detach().numpy() * 255 + mask_np = mask_np.astype(np.uint8) + + mask_pil = Image.fromarray(mask_np, mode="L") + image_dto = context.images.save(image=mask_pil) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/segment_anything_model.py b/invokeai/app/invocations/segment_anything_model.py index ad264c9584..a652e68338 100644 --- a/invokeai/app/invocations/segment_anything_model.py +++ b/invokeai/app/invocations/segment_anything_model.py @@ -2,7 +2,6 @@ from pathlib import Path from typing import Literal import numpy as np -import numpy.typing as npt import torch from PIL import Image from transformers import AutoModelForMaskGeneration, AutoProcessor @@ -10,8 +9,8 @@ from transformers.models.sam import SamModel from transformers.models.sam.processing_sam import SamProcessor from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation -from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField -from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField +from invokeai.app.invocations.primitives import MaskOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask from invokeai.backend.image_util.segment_anything.segment_anything_model import SegmentAnythingModel @@ -46,24 +45,22 @@ class SegmentAnythingModelInvocation(BaseInvocation): ) @torch.no_grad() - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context: InvocationContext) -> MaskOutput: # The models 
expect a 3-channel RGB image. image_pil = context.images.get_pil(self.image.image_name, mode="RGB") if len(self.bounding_boxes) == 0: - combined_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8) + combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool) else: masks = self._segment(context=context, image=image_pil) masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes) - # masks contains binary values of 0 or 1, so we merge them via max-reduce. - combined_mask = np.maximum.reduce(masks) - # Map [0, 1] to [0, 255]. - mask_np = combined_mask * 255 - mask_pil = Image.fromarray(mask_np) + # masks contains bool values, so we merge them via max-reduce. + combined_mask, _ = torch.stack(masks).max(dim=0) - image_dto = context.images.save(image=mask_pil) - return ImageOutput.build(image_dto) + mask_tensor_name = context.tensors.save(combined_mask) + height, width = combined_mask.shape + return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height) @staticmethod def _load_sam_model(model_path: Path): @@ -84,7 +81,7 @@ class SegmentAnythingModelInvocation(BaseInvocation): self, context: InvocationContext, image: Image.Image, - ) -> list[npt.NDArray[np.uint8]]: + ) -> list[torch.Tensor]: """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.""" # Convert the bounding boxes to the SAM input format. sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes] @@ -97,22 +94,23 @@ class SegmentAnythingModelInvocation(BaseInvocation): assert isinstance(sam_pipeline, SegmentAnythingModel) masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes) - masks = self._to_numpy_masks(masks) + masks = self._process_masks(masks) if self.apply_polygon_refinement: masks = self._apply_polygon_refinement(masks) return masks - def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]: - """Convert the tensor output from the Segment Anything model to a list of numpy masks.""" - eps = 0.0001 + def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]: + """Convert the tensor output from the Segment Anything model from a tensor of shape + [num_masks, channels, height, width] to a list of tensors of shape [height, width]. + """ + assert masks.dtype == torch.bool # [num_masks, channels, height, width] -> [num_masks, height, width] - masks = masks.permute(0, 2, 3, 1).float().mean(dim=-1) - masks = masks > eps - np_masks = masks.cpu().numpy().astype(np.uint8) - return list(np_masks) + masks, _ = masks.max(dim=1) + # Split the first dimension into a list of masks. + return list(masks.cpu().unbind(dim=0)) - def _apply_polygon_refinement(self, masks: list[npt.NDArray[np.uint8]]) -> list[npt.NDArray[np.uint8]]: + def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]: """Apply polygon refinement to the masks. Convert each mask to a polygon, then back to a mask. This has the following effect: @@ -121,18 +119,23 @@ class SegmentAnythingModelInvocation(BaseInvocation): - Removes small mask pieces. - Removes holes from the mask. """ - for idx, mask in enumerate(masks): + # Convert tensor masks to np masks. + np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks] + + # Apply polygon refinement. + for idx, mask in enumerate(np_masks): shape = mask.shape assert len(shape) == 2 # Assert length to satisfy type checker. 
polygon = mask_to_polygon(mask) mask = polygon_to_mask(polygon, shape) - masks[idx] = mask + np_masks[idx] = mask + + # Convert np masks back to tensor masks. + masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks] return masks - def _filter_masks( - self, masks: list[npt.NDArray[np.uint8]], bounding_boxes: list[BoundingBoxField] - ) -> list[npt.NDArray[np.uint8]]: + def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]: """Filter the detected masks based on the specified mask filter.""" assert len(masks) == len(bounding_boxes) @@ -140,7 +143,7 @@ class SegmentAnythingModelInvocation(BaseInvocation): return masks elif self.mask_filter == "largest": # Find the largest mask. - return [max(masks, key=lambda x: x.sum())] + return [max(masks, key=lambda x: float(x.sum()))] elif self.mask_filter == "highest_box_score": # Find the index of the bounding box with the highest score. # Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most From f5cfdcf32d3db1da0ef1bae92f852cfd8802b283 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Thu, 1 Aug 2024 04:09:08 +0530 Subject: [PATCH 20/28] feat: Add BoundingBox Primitive Node --- invokeai/app/invocations/primitives.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 4e1e00bfd8..c6ffc8c2b9 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -485,9 +485,27 @@ class BoundingBoxOutput(BaseInvocationOutput): class BoundingBoxCollectionOutput(BaseInvocationOutput): """Base class for nodes that output a collection of bounding boxes""" - collection: list[BoundingBoxField] = OutputField( - description="The output bounding boxes.", - ) + collection: list[BoundingBoxField] = OutputField(description="The output bounding boxes.", title="Bounding Boxes") + + +@invocation( + "bounding_box", + title="Bounding Box", + tags=["primitives", "segmentation", "collection", "bounding box"], + category="primitives", + version="1.0.0", +) +class BoundingBoxInvocation(BaseInvocation): + """Create a bounding box manually by supplying box coordinates""" + + x_min: int = InputField(default=0, description="x-coordinate of the bounding box's top left vertex", title="X1") + y_min: int = InputField(default=0, description="y-coordinate of the bounding box's top left vertex", title="Y1") + x_max: int = InputField(default=0, description="x-coordinate of the bounding box's bottom right vertex", title="X2") + y_max: int = InputField(default=0, description="y-coordinate of the bounding box's bottom right vertex", title="Y2") + + def invoke(self, context: InvocationContext) -> BoundingBoxOutput: + bounding_box = BoundingBoxField(x_min=self.x_min, y_min=self.y_min, x_max=self.x_max, y_max=self.y_max) + return BoundingBoxOutput(bounding_box=bounding_box) # endregion From 63581ec980da0e96b89dfad884c78b3597ed5388 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 1 Aug 2024 09:51:53 -0400 Subject: [PATCH 21/28] (minor) Add None check to fix static type checking error. 
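The pipeline's __call__ is presumably annotated as possibly returning None, which is what the type checker flags; asserting non-None narrows the type before the results are validated. A minimal standalone illustration of the pattern (hypothetical helper, not InvokeAI code):

```python
from typing import Optional


def keep_confident(results: Optional[list[dict]], threshold: float) -> list[dict]:
    # `results` stands in for the pipeline's Optional return value. The assert
    # narrows Optional[list[dict]] to list[dict] so the comprehension type-checks.
    assert results is not None
    return [r for r in results if r["score"] >= threshold]
```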
--- .../backend/image_util/grounding_dino/grounding_dino_pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/invokeai/backend/image_util/grounding_dino/grounding_dino_pipeline.py b/invokeai/backend/image_util/grounding_dino/grounding_dino_pipeline.py index 68c2b8decc..772e8c0dd8 100644 --- a/invokeai/backend/image_util/grounding_dino/grounding_dino_pipeline.py +++ b/invokeai/backend/image_util/grounding_dino/grounding_dino_pipeline.py @@ -18,6 +18,7 @@ class GroundingDinoPipeline(RawModel): def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1) -> list[DetectionResult]: results = self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold) + assert results is not None results = [DetectionResult.model_validate(result) for result in results] return results From b9dc3460ba502810ab283c9a548f923636979d68 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 1 Aug 2024 09:57:47 -0400 Subject: [PATCH 22/28] Rename SegmentAnythingModel -> SegmentAnythingPipeline. --- invokeai/app/invocations/segment_anything_model.py | 6 +++--- ...gment_anything_model.py => segment_anything_pipeline.py} | 2 +- invokeai/backend/model_manager/load/model_util.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) rename invokeai/backend/image_util/segment_anything/{segment_anything_model.py => segment_anything_pipeline.py} (98%) diff --git a/invokeai/app/invocations/segment_anything_model.py b/invokeai/app/invocations/segment_anything_model.py index a652e68338..13fc71f231 100644 --- a/invokeai/app/invocations/segment_anything_model.py +++ b/invokeai/app/invocations/segment_anything_model.py @@ -13,7 +13,7 @@ from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputF from invokeai.app.invocations.primitives import MaskOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask -from invokeai.backend.image_util.segment_anything.segment_anything_model import SegmentAnythingModel +from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base" @@ -75,7 +75,7 @@ class SegmentAnythingModelInvocation(BaseInvocation): sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True) assert isinstance(sam_processor, SamProcessor) - return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor) + return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor) def _segment( self, @@ -91,7 +91,7 @@ class SegmentAnythingModelInvocation(BaseInvocation): source=SEGMENT_ANYTHING_MODEL_ID, loader=SegmentAnythingModelInvocation._load_sam_model ) as sam_pipeline, ): - assert isinstance(sam_pipeline, SegmentAnythingModel) + assert isinstance(sam_pipeline, SegmentAnythingPipeline) masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes) masks = self._process_masks(masks) diff --git a/invokeai/backend/image_util/segment_anything/segment_anything_model.py b/invokeai/backend/image_util/segment_anything/segment_anything_pipeline.py similarity index 98% rename from invokeai/backend/image_util/segment_anything/segment_anything_model.py rename to invokeai/backend/image_util/segment_anything/segment_anything_pipeline.py index d358db5c00..0818fb2fbd 100644 --- a/invokeai/backend/image_util/segment_anything/segment_anything_model.py +++ 
b/invokeai/backend/image_util/segment_anything/segment_anything_pipeline.py @@ -8,7 +8,7 @@ from transformers.models.sam.processing_sam import SamProcessor from invokeai.backend.raw_model import RawModel -class SegmentAnythingModel(RawModel): +class SegmentAnythingPipeline(RawModel): """A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager.""" def __init__(self, sam_model: SamModel, sam_processor: SamProcessor): diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py index 9c2570978a..2656a8b6ef 100644 --- a/invokeai/backend/model_manager/load/model_util.py +++ b/invokeai/backend/model_manager/load/model_util.py @@ -12,7 +12,7 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin from transformers import CLIPTokenizer from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline -from invokeai.backend.image_util.segment_anything.segment_anything_model import SegmentAnythingModel +from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager.config import AnyModel @@ -44,7 +44,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int: LoRAModelRaw, SpandrelImageToImageModel, GroundingDinoPipeline, - SegmentAnythingModel, + SegmentAnythingPipeline, ), ): return model.calc_size() From c3a6a6fb22127e4add43e09691ba942032d1d922 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 1 Aug 2024 10:00:36 -0400 Subject: [PATCH 23/28] Rename SegmentAnythingModelInvocation -> SegmentAnythingInvocation. --- .../{segment_anything_model.py => segment_anything.py} | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) rename invokeai/app/invocations/{segment_anything_model.py => segment_anything.py} (97%) diff --git a/invokeai/app/invocations/segment_anything_model.py b/invokeai/app/invocations/segment_anything.py similarity index 97% rename from invokeai/app/invocations/segment_anything_model.py rename to invokeai/app/invocations/segment_anything.py index 13fc71f231..5b766f405a 100644 --- a/invokeai/app/invocations/segment_anything_model.py +++ b/invokeai/app/invocations/segment_anything.py @@ -19,13 +19,13 @@ SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base" @invocation( - "segment_anything_model", - title="Segment Anything Model", + "segment_anything", + title="Segment Anything", tags=["prompt", "segmentation"], category="segmentation", version="1.0.0", ) -class SegmentAnythingModelInvocation(BaseInvocation): +class SegmentAnythingInvocation(BaseInvocation): """Runs a Segment Anything Model (https://arxiv.org/pdf/2304.02643). Reference: @@ -88,7 +88,7 @@ class SegmentAnythingModelInvocation(BaseInvocation): with ( context.models.load_remote_model( - source=SEGMENT_ANYTHING_MODEL_ID, loader=SegmentAnythingModelInvocation._load_sam_model + source=SEGMENT_ANYTHING_MODEL_ID, loader=SegmentAnythingInvocation._load_sam_model ) as sam_pipeline, ): assert isinstance(sam_pipeline, SegmentAnythingPipeline) From e6a512aa867d88558df1dde1162c9f1fe60a49c2 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 1 Aug 2024 10:12:24 -0400 Subject: [PATCH 24/28] (minor) Tweak order of mask operations. 
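The old and new conversions produce identical uint8 arrays for a boolean mask; the reorder just scales and casts in torch before a single numpy conversion at the end. A standalone check of that equivalence (not part of the node code):

```python
import numpy as np
import torch

mask = torch.rand(8, 8) > 0.5  # stand-in for a boolean mask tensor

old = (mask.float().cpu().detach().numpy() * 255).astype(np.uint8)
new = (mask.float() * 255).byte().cpu().numpy()

assert np.array_equal(old, new)
```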
--- invokeai/app/invocations/mask.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/invokeai/app/invocations/mask.py b/invokeai/app/invocations/mask.py index 2ebefeacff..64d1b48e38 100644 --- a/invokeai/app/invocations/mask.py +++ b/invokeai/app/invocations/mask.py @@ -138,8 +138,7 @@ class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard): # Ensure that the mask is binary. if mask.dtype != torch.bool: mask = mask > 0.5 - mask_np = mask.float().cpu().detach().numpy() * 255 - mask_np = mask_np.astype(np.uint8) + mask_np = (mask.float() * 255).byte().cpu().numpy() mask_pil = Image.fromarray(mask_np, mode="L") image_dto = context.images.save(image=mask_pil) From c6d49e8b1fc2a4e1d41c128e3a5123719424b581 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 1 Aug 2024 10:17:42 -0400 Subject: [PATCH 25/28] Shorten SegmentAnythingInvocation and GroundingDinoInvocatino docstrings, since they are used as the invocation descriptions in the UI. --- invokeai/app/invocations/grounding_dino.py | 11 +++++------ invokeai/app/invocations/segment_anything.py | 10 +++++----- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/invokeai/app/invocations/grounding_dino.py b/invokeai/app/invocations/grounding_dino.py index 779c946d40..a7d522e784 100644 --- a/invokeai/app/invocations/grounding_dino.py +++ b/invokeai/app/invocations/grounding_dino.py @@ -23,13 +23,12 @@ GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny" version="1.0.0", ) class GroundingDinoInvocation(BaseInvocation): - """Runs a Grounding DINO model (https://arxiv.org/pdf/2303.05499). Performs zero-shot bounding-box object detection - from a text prompt. + """Runs a Grounding DINO model. Performs zero-shot bounding-box object detection from a text prompt.""" - Reference: - - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam - - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb - """ + # Reference: + # - https://arxiv.org/pdf/2303.05499 + # - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam + # - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb prompt: str = InputField(description="The prompt describing the object to segment.") image: ImageField = InputField(description="The image to segment.") diff --git a/invokeai/app/invocations/segment_anything.py b/invokeai/app/invocations/segment_anything.py index 5b766f405a..6240dd6946 100644 --- a/invokeai/app/invocations/segment_anything.py +++ b/invokeai/app/invocations/segment_anything.py @@ -26,12 +26,12 @@ SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base" version="1.0.0", ) class SegmentAnythingInvocation(BaseInvocation): - """Runs a Segment Anything Model (https://arxiv.org/pdf/2304.02643). 
+ """Runs a Segment Anything Model.""" - Reference: - - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam - - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb - """ + # Reference: + # - https://arxiv.org/pdf/2304.02643 + # - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam + # - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb image: ImageField = InputField(description="The image to segment.") bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.") From 44b21f10f153b176aa8af62bc3ccc33a1ec4c8bb Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 1 Aug 2024 14:00:57 -0400 Subject: [PATCH 26/28] Add a pydantic model_validator to BoundingBoxField to check the validity of the coords. --- invokeai/app/invocations/fields.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index fe97123fef..9efcf2148f 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Any, Callable, Optional, Tuple -from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter +from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, model_validator from pydantic.fields import _Unset from pydantic_core import PydanticUndefined @@ -258,6 +258,14 @@ class BoundingBoxField(BaseModel): "when the bounding box was produced by a detector and has an associated confidence score.", ) + @model_validator(mode="after") + def check_coords(self): + if self.x_min > self.x_max: + raise ValueError(f"x_min ({self.x_min}) is greater than x_max ({self.x_max}).") + if self.y_min > self.y_max: + raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).") + return self + class MetadataField(RootModel[dict[str, Any]]): """ From 675ffc2757ed5f3f0c615d8d898136a11c5bd894 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 1 Aug 2024 14:05:44 -0400 Subject: [PATCH 27/28] Remove BoundingBoxInvocation field name overrides. 
--- invokeai/app/invocations/primitives.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index c6ffc8c2b9..3655554f3b 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -498,10 +498,10 @@ class BoundingBoxCollectionOutput(BaseInvocationOutput): class BoundingBoxInvocation(BaseInvocation): """Create a bounding box manually by supplying box coordinates""" - x_min: int = InputField(default=0, description="x-coordinate of the bounding box's top left vertex", title="X1") - y_min: int = InputField(default=0, description="y-coordinate of the bounding box's top left vertex", title="Y1") - x_max: int = InputField(default=0, description="x-coordinate of the bounding box's bottom right vertex", title="X2") - y_max: int = InputField(default=0, description="y-coordinate of the bounding box's bottom right vertex", title="Y2") + x_min: int = InputField(default=0, description="x-coordinate of the bounding box's top left vertex") + y_min: int = InputField(default=0, description="y-coordinate of the bounding box's top left vertex") + x_max: int = InputField(default=0, description="x-coordinate of the bounding box's bottom right vertex") + y_max: int = InputField(default=0, description="y-coordinate of the bounding box's bottom right vertex") def invoke(self, context: InvocationContext) -> BoundingBoxOutput: bounding_box = BoundingBoxField(x_min=self.x_min, y_min=self.y_min, x_max=self.x_max, y_max=self.y_max) From 27ac61a4fb8ea7fb5a6f3482473ddf080fbb363f Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 1 Aug 2024 14:23:32 -0400 Subject: [PATCH 28/28] Expose all model options in the GroundingDinoInvocation and the SegmentAnythingInvocation. 
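Each invocation now exposes the model choice as a Literal-typed field, with a module-level dict mapping the user-facing key to the Hugging Face repo ID that is resolved at load time; pydantic rejects any value outside the Literal, so the field and the dict stay in sync. A condensed sketch of the pattern using the SAM variants from this patch (the resolve helper is illustrative only):

```python
from typing import Literal

SegmentAnythingModelKey = Literal["segment-anything-base", "segment-anything-large", "segment-anything-huge"]

# User-facing key -> Hugging Face repo ID, resolved when the model is loaded.
SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
    "segment-anything-base": "facebook/sam-vit-base",
    "segment-anything-large": "facebook/sam-vit-large",
    "segment-anything-huge": "facebook/sam-vit-huge",
}


def resolve_model_source(model: SegmentAnythingModelKey) -> str:
    # An unknown key is both a static type error and a KeyError at runtime.
    return SEGMENT_ANYTHING_MODEL_IDS[model]


assert resolve_model_source("segment-anything-large") == "facebook/sam-vit-large"
```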
--- invokeai/app/invocations/grounding_dino.py | 10 ++++++++-- invokeai/app/invocations/segment_anything.py | 10 ++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/invokeai/app/invocations/grounding_dino.py b/invokeai/app/invocations/grounding_dino.py index a7d522e784..1e3d5cea0c 100644 --- a/invokeai/app/invocations/grounding_dino.py +++ b/invokeai/app/invocations/grounding_dino.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Literal import torch from PIL import Image @@ -12,7 +13,11 @@ from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline -GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny" +GroundingDinoModelKey = Literal["grounding-dino-tiny", "grounding-dino-base"] +GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = { + "grounding-dino-tiny": "IDEA-Research/grounding-dino-tiny", + "grounding-dino-base": "IDEA-Research/grounding-dino-base", +} @invocation( @@ -30,6 +35,7 @@ class GroundingDinoInvocation(BaseInvocation): # - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam # - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb + model: GroundingDinoModelKey = InputField(description="The Grounding DINO model to use.") prompt: str = InputField(description="The prompt describing the object to segment.") image: ImageField = InputField(description="The image to segment.") detection_threshold: float = InputField( @@ -88,7 +94,7 @@ class GroundingDinoInvocation(BaseInvocation): labels = [label if label.endswith(".") else label + "." 
for label in labels] with context.models.load_remote_model( - source=GROUNDING_DINO_MODEL_ID, loader=GroundingDinoInvocation._load_grounding_dino + source=GROUNDING_DINO_MODEL_IDS[self.model], loader=GroundingDinoInvocation._load_grounding_dino ) as detector: assert isinstance(detector, GroundingDinoPipeline) return detector.detect(image=image, candidate_labels=labels, threshold=threshold) diff --git a/invokeai/app/invocations/segment_anything.py b/invokeai/app/invocations/segment_anything.py index 6240dd6946..b49b1a39e3 100644 --- a/invokeai/app/invocations/segment_anything.py +++ b/invokeai/app/invocations/segment_anything.py @@ -15,7 +15,12 @@ from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline -SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base" +SegmentAnythingModelKey = Literal["segment-anything-base", "segment-anything-large", "segment-anything-huge"] +SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = { + "segment-anything-base": "facebook/sam-vit-base", + "segment-anything-large": "facebook/sam-vit-large", + "segment-anything-huge": "facebook/sam-vit-huge", +} @invocation( @@ -33,6 +38,7 @@ class SegmentAnythingInvocation(BaseInvocation): # - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam # - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb + model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.") image: ImageField = InputField(description="The image to segment.") bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.") apply_polygon_refinement: bool = InputField( @@ -88,7 +94,7 @@ class SegmentAnythingInvocation(BaseInvocation): with ( context.models.load_remote_model( - source=SEGMENT_ANYTHING_MODEL_ID, loader=SegmentAnythingInvocation._load_sam_model + source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model ) as sam_pipeline, ): assert isinstance(sam_pipeline, SegmentAnythingPipeline)