diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py
index f9a483f84c..fe97123fef 100644
--- a/invokeai/app/invocations/fields.py
+++ b/invokeai/app/invocations/fields.py
@@ -242,6 +242,23 @@ class ConditioningField(BaseModel):
     )
 
 
+class BoundingBoxField(BaseModel):
+    """A bounding box primitive value."""
+
+    x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
+    x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
+    y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
+    y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
+
+    score: Optional[float] = Field(
+        default=None,
+        ge=0.0,
+        le=1.0,
+        description="The score associated with the bounding box. In the range [0, 1]. This value is typically set "
+        "when the bounding box was produced by a detector and has an associated confidence score.",
+    )
+
+
 class MetadataField(RootModel[dict[str, Any]]):
     """
     Pydantic model for metadata with custom root of type dict[str, Any].
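A note on the convention encoded above: `x_min`/`y_min` are inclusive while `x_max`/`y_max` are exclusive, which matches PIL's half-open crop box. A minimal sketch of how the two line up (the file name and coordinate values are hypothetical, not part of this diff):

```python
from PIL import Image

from invokeai.app.invocations.fields import BoundingBoxField

bb = BoundingBoxField(x_min=10, x_max=110, y_min=20, y_max=220, score=0.87)

image = Image.open("example.png")  # hypothetical input image
# PIL's crop box also treats the right/lower edges as exclusive, so the
# (inclusive min, exclusive max) convention maps onto it directly.
crop = image.crop((bb.x_min, bb.y_min, bb.x_max, bb.y_max))
assert crop.size == (bb.x_max - bb.x_min, bb.y_max - bb.y_min)  # (100, 200)
```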
diff --git a/invokeai/app/invocations/grounding_dino.py b/invokeai/app/invocations/grounding_dino.py
new file mode 100644
index 0000000000..ade5b1456e
--- /dev/null
+++ b/invokeai/app/invocations/grounding_dino.py
@@ -0,0 +1,95 @@
+from pathlib import Path
+
+import torch
+from PIL import Image
+from transformers import pipeline
+from transformers.pipelines import ZeroShotObjectDetectionPipeline
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
+from invokeai.app.invocations.primitives import BoundingBoxCollectionOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult
+from invokeai.backend.image_util.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline
+
+GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
+
+
+@invocation(
+    "grounding_dino",
+    title="Grounding DINO (Text Prompt Object Detection)",
+    tags=["prompt", "object detection"],
+    category="image",
+    version="1.0.0",
+)
+class GroundingDinoInvocation(BaseInvocation):
+    """Runs a Grounding DINO model (https://arxiv.org/pdf/2303.05499). Performs zero-shot bounding-box object detection
+    from a text prompt.
+
+    Reference:
+    - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
+    - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
+    """
+
+    prompt: str = InputField(description="The prompt describing the object to detect.")
+    image: ImageField = InputField(description="The image to detect objects in.")
+    detection_threshold: float = InputField(
+        description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be returned.",
+        ge=0.0,
+        le=1.0,
+        default=0.3,
+    )
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> BoundingBoxCollectionOutput:
+        # The model expects a 3-channel RGB image.
+        image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
+
+        detections = self._detect(
+            context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
+        )
+
+        # Convert detections to BoundingBoxCollectionOutput.
+        bounding_boxes: list[BoundingBoxField] = []
+        for detection in detections:
+            bounding_boxes.append(
+                BoundingBoxField(
+                    x_min=detection.box.xmin,
+                    x_max=detection.box.xmax,
+                    y_min=detection.box.ymin,
+                    y_max=detection.box.ymax,
+                    score=detection.score,
+                )
+            )
+        return BoundingBoxCollectionOutput(collection=bounding_boxes)
+
+    @staticmethod
+    def _load_grounding_dino(model_path: Path):
+        grounding_dino_pipeline = pipeline(
+            model=str(model_path),
+            task="zero-shot-object-detection",
+            local_files_only=True,
+            # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
+            # model, and figure out how to make it work in the pipeline.
+            # torch_dtype=TorchDevice.choose_torch_dtype(),
+        )
+        assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
+        return GroundingDinoPipeline(grounding_dino_pipeline)
+
+    def _detect(
+        self,
+        context: InvocationContext,
+        image: Image.Image,
+        labels: list[str],
+        threshold: float = 0.3,
+    ) -> list[DetectionResult]:
+        """Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
+        # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
+        # actually makes a difference.
+        labels = [label if label.endswith(".") else label + "." for label in labels]
+
+        with context.models.load_remote_model(
+            source=GROUNDING_DINO_MODEL_ID, loader=GroundingDinoInvocation._load_grounding_dino
+        ) as detector:
+            assert isinstance(detector, GroundingDinoPipeline)
+            return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
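For context on what `_detect` wraps: the same zero-shot detection can be reproduced with the `transformers` pipeline API directly, outside InvokeAI's model-loading machinery. A minimal sketch (assumes Hub access; the image path and label are hypothetical):

```python
from PIL import Image
from transformers import pipeline

# Same model ID and task as the invocation above, loaded straight from the Hub.
detector = pipeline(task="zero-shot-object-detection", model="IDEA-Research/grounding-dino-tiny")

image = Image.open("example.png").convert("RGB")  # hypothetical input image
# Grounding DINO's example code appends a "." to each label (see the TODO above).
results = detector(image, candidate_labels=["a cat."], threshold=0.3)
# Each result is a dict of the form:
# {"score": 0.45, "label": "a cat.", "box": {"xmin": 38, "ymin": 70, "xmax": 311, "ymax": 419}}
```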
diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py
index e5056e3775..4e1e00bfd8 100644
--- a/invokeai/app/invocations/primitives.py
+++ b/invokeai/app/invocations/primitives.py
@@ -7,6 +7,7 @@ import torch
 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
 from invokeai.app.invocations.fields import (
+    BoundingBoxField,
     ColorField,
     ConditioningField,
     DenoiseMaskField,
@@ -469,3 +470,24 @@ class ConditioningCollectionInvocation(BaseInvocation):
 
 
 # endregion
+
+# region BoundingBox
+
+
+@invocation_output("bounding_box_output")
+class BoundingBoxOutput(BaseInvocationOutput):
+    """Base class for nodes that output a single bounding box"""
+
+    bounding_box: BoundingBoxField = OutputField(description="The output bounding box.")
+
+
+@invocation_output("bounding_box_collection_output")
+class BoundingBoxCollectionOutput(BaseInvocationOutput):
+    """Base class for nodes that output a collection of bounding boxes"""
+
+    collection: list[BoundingBoxField] = OutputField(
+        description="The output bounding boxes.",
+    )
+
+
+# endregion
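To illustrate how these output types are consumed by a node, here is a hypothetical pass-through invocation (not part of this diff) that emits a single caller-specified box via `BoundingBoxOutput`:

```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import BoundingBoxField, InputField
from invokeai.app.invocations.primitives import BoundingBoxOutput
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation("bounding_box", title="Bounding Box", tags=["primitives"], category="primitives", version="1.0.0")
class BoundingBoxInvocation(BaseInvocation):
    """Hypothetical example node: emits a caller-specified bounding box."""

    x_min: int = InputField(default=0, ge=0, description="The left edge (inclusive).")
    y_min: int = InputField(default=0, ge=0, description="The top edge (inclusive).")
    x_max: int = InputField(default=512, ge=0, description="The right edge (exclusive).")
    y_max: int = InputField(default=512, ge=0, description="The bottom edge (exclusive).")

    def invoke(self, context: InvocationContext) -> BoundingBoxOutput:
        bounding_box = BoundingBoxField(x_min=self.x_min, x_max=self.x_max, y_min=self.y_min, y_max=self.y_max)
        return BoundingBoxOutput(bounding_box=bounding_box)
```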
diff --git a/invokeai/app/invocations/grounded_sam.py b/invokeai/app/invocations/segment_anything_model.py
similarity index 56%
rename from invokeai/app/invocations/grounded_sam.py
rename to invokeai/app/invocations/segment_anything_model.py
index d741c3d6ad..39e2ee6d7f 100644
--- a/invokeai/app/invocations/grounded_sam.py
+++ b/invokeai/app/invocations/segment_anything_model.py
@@ -5,75 +5,56 @@ import numpy as np
 import numpy.typing as npt
 import torch
 from PIL import Image
-from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
+from transformers import AutoModelForMaskGeneration, AutoProcessor
 from transformers.models.sam import SamModel
 from transformers.models.sam.processing_sam import SamProcessor
-from transformers.pipelines import ZeroShotObjectDetectionPipeline
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import ImageField, InputField
+from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult
-from invokeai.backend.image_util.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline
 from invokeai.backend.image_util.grounded_sam.mask_refinement import mask_to_polygon, polygon_to_mask
 from invokeai.backend.image_util.grounded_sam.segment_anything_model import SegmentAnythingModel
 
-GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
 SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base"
 
 
 @invocation(
-    "grounded_segment_anything",
-    title="Segment Anything (Text Prompt)",
+    "segment_anything_model",
+    title="Segment Anything Model",
     tags=["prompt", "segmentation"],
     category="segmentation",
     version="1.0.0",
 )
-class GroundedSAMInvocation(BaseInvocation):
-    """Runs Grounded-SAM, as proposed in https://arxiv.org/pdf/2401.14159.
-
-    More specifically, a Grounding DINO model is run to obtain bounding boxes for a text prompt, then the bounding boxes
-    are passed as a prompt to a Segment Anything model to obtain a segmentation mask.
+class SegmentAnythingModelInvocation(BaseInvocation):
+    """Runs a Segment Anything Model (https://arxiv.org/pdf/2304.02643).
 
     Reference:
     - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
     - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
     """
 
-    prompt: str = InputField(description="The prompt describing the object to segment.")
     image: ImageField = InputField(description="The image to segment.")
+    bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
     apply_polygon_refinement: bool = InputField(
-        description="Whether to apply polygon refinement to the masks. This will smooth the edges of the mask slightly and ensure that each mask consists of a single closed polygon (before merging).",
+        description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
        default=True,
     )
     mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
         description="The filtering to apply to the detected masks before merging them into a final output.",
         default="all",
     )
-    detection_threshold: float = InputField(
-        description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be used.",
-        ge=0.0,
-        le=1.0,
-        default=0.3,
-    )
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
         # The models expect a 3-channel RGB image.
         image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
 
-        detections = self._detect(
-            context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
-        )
-
-        if len(detections) == 0:
+        if len(self.bounding_boxes) == 0:
             combined_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8)
         else:
-            detections = self._segment(context=context, image=image_pil, detection_results=detections)
-
-            detections = self._filter_detections(detections)
-            masks = [detection.mask for detection in detections]
+            masks = self._segment(context=context, image=image_pil)
+            masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
 
             # masks contains binary values of 0 or 1, so we merge them via max-reduce.
             combined_mask = np.maximum.reduce(masks)
@@ -84,19 +65,6 @@ class GroundedSAMInvocation(BaseInvocation):
         image_dto = context.images.save(image=mask_pil)
         return ImageOutput.build(image_dto)
 
-    @staticmethod
-    def _load_grounding_dino(model_path: Path):
-        grounding_dino_pipeline = pipeline(
-            model=str(model_path),
-            task="zero-shot-object-detection",
-            local_files_only=True,
-            # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
-            # model, and figure out how to make it work in the pipeline.
-            # torch_dtype=TorchDevice.choose_torch_dtype(),
-        )
-        assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
-        return GroundingDinoPipeline(grounding_dino_pipeline)
-
     @staticmethod
     def _load_sam_model(model_path: Path):
         sam_model = AutoModelForMaskGeneration.from_pretrained(
@@ -112,47 +80,28 @@ class GroundedSAMInvocation(BaseInvocation):
         assert isinstance(sam_processor, SamProcessor)
         return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
 
-    def _detect(
-        self,
-        context: InvocationContext,
-        image: Image.Image,
-        labels: list[str],
-        threshold: float = 0.3,
-    ) -> list[DetectionResult]:
-        """Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
-        # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
-        # actually makes a difference.
-        labels = [label if label.endswith(".") else label + "." for label in labels]
-
-        with context.models.load_remote_model(
-            source=GROUNDING_DINO_MODEL_ID, loader=GroundedSAMInvocation._load_grounding_dino
-        ) as detector:
-            assert isinstance(detector, GroundingDinoPipeline)
-            return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
-
     def _segment(
         self,
         context: InvocationContext,
         image: Image.Image,
-        detection_results: list[DetectionResult],
-    ) -> list[DetectionResult]:
+    ) -> list[npt.NDArray[np.uint8]]:
         """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
+        # Convert the bounding boxes to the SAM input format.
+        sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
+
         with (
             context.models.load_remote_model(
-                source=SEGMENT_ANYTHING_MODEL_ID, loader=GroundedSAMInvocation._load_sam_model
+                source=SEGMENT_ANYTHING_MODEL_ID, loader=SegmentAnythingModelInvocation._load_sam_model
             ) as sam_pipeline,
         ):
             assert isinstance(sam_pipeline, SegmentAnythingModel)
-            masks = sam_pipeline.segment(image=image, detection_results=detection_results)
+            masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
 
         masks = self._to_numpy_masks(masks)
         if self.apply_polygon_refinement:
             masks = self._apply_polygon_refinement(masks)
 
-        for detection_result, mask in zip(detection_results, masks, strict=True):
-            detection_result.mask = mask
-
-        return detection_results
+        return masks
 
     def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
         """Convert the tensor output from the Segment Anything model to a list of numpy masks."""
@@ -181,15 +130,23 @@ class GroundedSAMInvocation(BaseInvocation):
 
         return masks
 
-    def _filter_detections(self, detections: list[DetectionResult]) -> list[DetectionResult]:
+    def _filter_masks(
+        self, masks: list[npt.NDArray[np.uint8]], bounding_boxes: list[BoundingBoxField]
+    ) -> list[npt.NDArray[np.uint8]]:
         """Filter the detected masks based on the specified mask filter."""
+        assert len(masks) == len(bounding_boxes)
+
         if self.mask_filter == "all":
-            return detections
+            return masks
         elif self.mask_filter == "largest":
             # Find the largest mask.
-            return [max(detections, key=lambda x: x.mask.sum())]
+            return [max(masks, key=lambda x: x.sum())]
         elif self.mask_filter == "highest_box_score":
-            # Find the detection with the highest box score.
-            return [max(detections, key=lambda x: x.score)]
+            # Find the index of the bounding box with the highest score.
+            # Note that we fall back to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
+            # cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
+            # reasonable fallback since the expected score range is [0.0, 1.0].
+            max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
+            return [masks[max_score_idx]]
         else:
             raise ValueError(f"Invalid mask filter: {self.mask_filter}")
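The `np.maximum.reduce(masks)` merge in `invoke()` relies on the masks holding only 0/1 values: an elementwise max over the stack is then a pixelwise union. A toy illustration (values are made up):

```python
import numpy as np

mask_a = np.array([[1, 0], [0, 0]], dtype=np.uint8)
mask_b = np.array([[0, 0], [0, 1]], dtype=np.uint8)

# Elementwise max over binary masks == pixelwise union.
combined = np.maximum.reduce([mask_a, mask_b])
assert (combined == np.array([[1, 0], [0, 1]], dtype=np.uint8)).all()
```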
diff --git a/invokeai/backend/image_util/grounded_sam/detection_result.py b/invokeai/backend/image_util/grounded_sam/detection_result.py
index a9f2bdd65f..2d0c78e681 100644
--- a/invokeai/backend/image_util/grounded_sam/detection_result.py
+++ b/invokeai/backend/image_util/grounded_sam/detection_result.py
@@ -1,6 +1,3 @@
-from typing import Any, Optional
-
-import numpy.typing as npt
 from pydantic import BaseModel, ConfigDict
 
 
@@ -12,18 +9,13 @@ class BoundingBox(BaseModel):
     xmax: int
     ymax: int
 
-    def to_box(self) -> list[int]:
-        """Convert to the array notation expected by SAM."""
-        return [self.xmin, self.ymin, self.xmax, self.ymax]
-
 
 class DetectionResult(BaseModel):
-    """Detection result from Grounding DINO or Grounded SAM."""
+    """Detection result from Grounding DINO."""
 
     score: float
     label: str
     box: BoundingBox
-    mask: Optional[npt.NDArray[Any]] = None
 
     model_config = ConfigDict(
         # Allow arbitrary types for mask, since it will be a numpy array.
         arbitrary_types_allowed=True
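Since `DetectionResult` is a plain pydantic model with a nested `BoundingBox`, raw pipeline output in this shape validates directly. A small sketch (the values are illustrative):

```python
from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult

raw = {"score": 0.45, "label": "a cat.", "box": {"xmin": 38, "ymin": 70, "xmax": 311, "ymax": 419}}

# Pydantic v2 recursively validates the nested BoundingBox from the plain dict.
detection = DetectionResult.model_validate(raw)
assert detection.box.xmax == 311
```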
diff --git a/invokeai/backend/image_util/grounded_sam/segment_anything_model.py b/invokeai/backend/image_util/grounded_sam/segment_anything_model.py
index 5147141303..d358db5c00 100644
--- a/invokeai/backend/image_util/grounded_sam/segment_anything_model.py
+++ b/invokeai/backend/image_util/grounded_sam/segment_anything_model.py
@@ -5,7 +5,6 @@ from PIL import Image
 from transformers.models.sam import SamModel
 from transformers.models.sam.processing_sam import SamProcessor
 
-from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult
 from invokeai.backend.raw_model import RawModel
 
 
@@ -28,8 +27,19 @@ class SegmentAnythingModel(RawModel):
         return calc_module_size(self._sam_model)
 
-    def segment(self, image: Image.Image, detection_results: list[DetectionResult]) -> torch.Tensor:
-        boxes = self._to_box_array(detection_results)
+    def segment(self, image: Image.Image, bounding_boxes: list[list[int]]) -> torch.Tensor:
+        """Run the SAM model.
+
+        Args:
+            image (Image.Image): The image to segment.
+            bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
+                [xmin, ymin, xmax, ymax].
+
+        Returns:
+            torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
+        """
+        # Add batch dimension of 1 to the bounding boxes.
+        boxes = [bounding_boxes]
         inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device)
         outputs = self._sam_model(**inputs)
         masks = self._sam_processor.post_process_masks(
@@ -40,10 +50,4 @@ class SegmentAnythingModel(RawModel):
 
         # There should be only one batch.
         assert len(masks) == 1
-        masks = masks[0]
-        return masks
-
-    def _to_box_array(self, detection_results: list[DetectionResult]) -> list[list[list[int]]]:
-        """Convert a list of DetectionResults to the bbox format expected by the Segment Anything model."""
-        boxes = [result.box.to_box() for result in detection_results]
-        return [boxes]
+        return masks[0]
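For reference, the calls that `SegmentAnythingModel.segment()` wraps can be reproduced against the `transformers` SAM API on its own. A standalone sketch (assumes the model can be fetched from the Hub; the image path and box coordinates are hypothetical):

```python
import torch
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor

processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-base")

image = Image.open("example.png").convert("RGB")  # hypothetical input image
# One batch entry containing one [xmin, ymin, xmax, ymax] box prompt.
boxes = [[[10, 20, 110, 220]]]

inputs = processor(images=image, input_boxes=boxes, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

masks = processor.post_process_masks(
    masks=outputs.pred_masks,
    original_sizes=inputs["original_sizes"],
    reshaped_input_sizes=inputs["reshaped_input_sizes"],
)
# masks has one entry per batch item; masks[0] is a bool tensor of shape
# [num_boxes, num_mask_channels, height, width].
```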