diff --git a/invokeai/app/invocations/grounded_sam.py b/invokeai/app/invocations/grounded_sam.py
index 211016fac4..8eb8770e47 100644
--- a/invokeai/app/invocations/grounded_sam.py
+++ b/invokeai/app/invocations/grounded_sam.py
@@ -1,6 +1,5 @@
-from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Literal, Optional
+from typing import Literal
 
 import numpy as np
 import numpy.typing as npt
@@ -15,6 +14,7 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.fields import ImageField, InputField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.grounded_sam.detection_result import DetectionResult
 from invokeai.backend.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline
 from invokeai.backend.grounded_sam.mask_refinement import mask_to_polygon, polygon_to_mask
 from invokeai.backend.grounded_sam.segment_anything_model import SegmentAnythingModel
@@ -23,43 +23,6 @@
 GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
 SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base"
 
-@dataclass
-class BoundingBox:
-    """Bounding box helper class used locally for the Grounding DINO outputs."""
-
-    xmin: int
-    ymin: int
-    xmax: int
-    ymax: int
-
-    def to_box(self) -> list[int]:
-        """Convert to the array notation expected by SAM."""
-        return [self.xmin, self.ymin, self.xmax, self.ymax]
-
-
-@dataclass
-class DetectionResult:
-    """Detection result from Grounding DINO or Grounded SAM."""
-
-    score: float
-    label: str
-    box: BoundingBox
-    mask: Optional[npt.NDArray[Any]] = None
-
-    @classmethod
-    def from_dict(cls, detection_dict: dict[str, Any]):
-        return cls(
-            score=detection_dict["score"],
-            label=detection_dict["label"],
-            box=BoundingBox(
-                xmin=detection_dict["box"]["xmin"],
-                ymin=detection_dict["box"]["ymin"],
-                xmax=detection_dict["box"]["xmax"],
-                ymax=detection_dict["box"]["ymax"],
-            ),
-        )
-
-
 @invocation(
     "grounded_segment_anything",
     title="Segment Anything (Text Prompt)",
@@ -92,9 +55,10 @@ class GroundedSAMInvocation(BaseInvocation):
         description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be used.",
         ge=0.0,
         le=1.0,
-        default=0.1,
+        default=0.3,
     )
 
+    @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image_pil = context.images.get_pil(self.image.image_name)
 
@@ -118,15 +82,6 @@
         image_dto = context.images.save(image=mask_pil)
         return ImageOutput.build(image_dto)
 
-    def _to_box_array(self, detection_results: list[DetectionResult]) -> list[list[list[int]]]:
-        """Convert a list of DetectionResults to the format expected by the Segment Anything model.
-
-        Args:
-            detection_results (list[DetectionResult]): The Grounding DINO detection results.
-        """
-        boxes = [result.box.to_box() for result in detection_results]
-        return [boxes]
-
     def _detect(
         self,
         context: InvocationContext,
@@ -153,10 +108,7 @@
         with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector:
             assert isinstance(detector, GroundingDinoPipeline)
-
-            results = detector.detect(image=image, candidate_labels=labels, threshold=threshold)
-            results = [DetectionResult.from_dict(result) for result in results]
-            return results
+            return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
 
     def _segment(
         self,
@@ -185,8 +137,7 @@
         ):
             assert isinstance(sam_pipeline, SegmentAnythingModel)
-
-            boxes = self._to_box_array(detection_results)
-            masks = sam_pipeline.segment(image=image, boxes=boxes)
+            masks = sam_pipeline.segment(image=image, detection_results=detection_results)
 
             masks = self._to_numpy_masks(masks)
             masks = self._apply_polygon_refinement(masks)
@@ -200,7 +151,7 @@
         """Convert the tensor output from the Segment Anything model to a list of numpy masks."""
         masks = masks.cpu().float()
         masks = masks.permute(0, 2, 3, 1)
-        masks = masks.mean(axis=-1)
+        masks = masks.mean(dim=-1)
         masks = (masks > 0).int()
         masks = masks.numpy().astype(np.uint8)
         masks = list(masks)
diff --git a/invokeai/backend/grounded_sam/grounding_dino_pipeline.py b/invokeai/backend/grounded_sam/grounding_dino_pipeline.py
index 0143e300bc..028aeb283d 100644
--- a/invokeai/backend/grounded_sam/grounding_dino_pipeline.py
+++ b/invokeai/backend/grounded_sam/grounding_dino_pipeline.py
@@ -4,6 +4,8 @@
 import torch
 from PIL import Image
 from transformers.pipelines import ZeroShotObjectDetectionPipeline
+from invokeai.backend.grounded_sam.detection_result import DetectionResult
+
 
 class GroundingDinoPipeline:
     """A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory
@@ -13,8 +15,10 @@
     def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
         self._pipeline = pipeline
 
-    def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1):
-        return self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
+    def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1) -> list[DetectionResult]:
+        results = self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
+        results = [DetectionResult.from_dict(result) for result in results]
+        return results
 
     def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":
         self._pipeline.model.to(device=device, dtype=dtype)
diff --git a/invokeai/backend/grounded_sam/segment_anything_model.py b/invokeai/backend/grounded_sam/segment_anything_model.py
index 1c5b1cf456..86106b869c 100644
--- a/invokeai/backend/grounded_sam/segment_anything_model.py
+++ b/invokeai/backend/grounded_sam/segment_anything_model.py
@@ -5,6 +5,8 @@
 from PIL import Image
 from transformers.models.sam import SamModel
 from transformers.models.sam.processing_sam import SamProcessor
+from invokeai.backend.grounded_sam.detection_result import DetectionResult
+
 
 class SegmentAnythingModel:
     """A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
@@ -23,7 +25,8 @@
         return calc_module_size(self._sam_model)
 
-    def segment(self, image: Image.Image, boxes: list[list[list[int]]]) -> torch.Tensor:
+    def segment(self, image: Image.Image, detection_results: list[DetectionResult]) -> torch.Tensor:
+        boxes = self._to_box_array(detection_results)
         inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device)
         outputs = self._sam_model(**inputs)
         masks = self._sam_processor.post_process_masks(
@@ -36,3 +39,8 @@
         assert len(masks) == 1
         masks = masks[0]
         return masks
+
+    def _to_box_array(self, detection_results: list[DetectionResult]) -> list[list[list[int]]]:
+        """Convert a list of DetectionResults to the bbox format expected by the Segment Anything model."""
+        boxes = [result.box.to_box() for result in detection_results]
+        return [boxes]
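
Note (not part of the diff): all three files now import invokeai/backend/grounded_sam/detection_result.py, but that new module is not shown here. Presumably it holds the BoundingBox and DetectionResult dataclasses removed from grounded_sam.py above. A minimal sketch of what it likely contains, reconstructed from that removed code; the actual module may differ:

```python
# Hypothetical reconstruction of invokeai/backend/grounded_sam/detection_result.py,
# based on the dataclasses deleted from grounded_sam.py in this diff.
from dataclasses import dataclass
from typing import Any, Optional

import numpy.typing as npt


@dataclass
class BoundingBox:
    """Bounding box helper class for Grounding DINO detections."""

    xmin: int
    ymin: int
    xmax: int
    ymax: int

    def to_box(self) -> list[int]:
        """Convert to the [xmin, ymin, xmax, ymax] notation expected by SAM."""
        return [self.xmin, self.ymin, self.xmax, self.ymax]


@dataclass
class DetectionResult:
    """Detection result from Grounding DINO or Grounded SAM."""

    score: float
    label: str
    box: BoundingBox
    mask: Optional[npt.NDArray[Any]] = None

    @classmethod
    def from_dict(cls, detection_dict: dict[str, Any]) -> "DetectionResult":
        """Build a DetectionResult from the raw dict returned by the transformers pipeline."""
        return cls(
            score=detection_dict["score"],
            label=detection_dict["label"],
            box=BoundingBox(
                xmin=detection_dict["box"]["xmin"],
                ymin=detection_dict["box"]["ymin"],
                xmax=detection_dict["box"]["xmax"],
                ymax=detection_dict["box"]["ymax"],
            ),
        )
```

Moving DetectionResult.from_dict into the backend module is what lets GroundingDinoPipeline.detect return typed results directly and allows the invocation's _detect conversion and _to_box_array plumbing to be deleted.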
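For review context, a rough sketch of how the refactored call path fits together outside the invocation graph and model manager. The model IDs match the constants in grounded_sam.py; everything else (the input image, the prompt, and especially the SegmentAnythingModel constructor argument names, whose __init__ is not shown in this diff) is an assumption for illustration only:

```python
# Illustrative sketch, not code from this PR. It exercises the refactored
# GroundingDinoPipeline.detect() -> SegmentAnythingModel.segment() path directly,
# bypassing context.models.load_remote_model().
from PIL import Image
from transformers import SamModel, SamProcessor, pipeline

from invokeai.backend.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.grounded_sam.segment_anything_model import SegmentAnythingModel

image = Image.open("example.png").convert("RGB")  # hypothetical input image

# Grounding DINO: text-prompted detection now returns list[DetectionResult] directly.
detector = GroundingDinoPipeline(
    pipeline("zero-shot-object-detection", model="IDEA-Research/grounding-dino-tiny")
)
detections = detector.detect(image=image, candidate_labels=["a cat"], threshold=0.3)

# SAM: box extraction now happens inside segment() via the private _to_box_array().
# NOTE: the constructor argument names below are guesses based on the private
# attributes (_sam_model, _sam_processor); the real __init__ is not in this diff.
sam = SegmentAnythingModel(
    sam_model=SamModel.from_pretrained("facebook/sam-vit-base"),
    sam_processor=SamProcessor.from_pretrained("facebook/sam-vit-base"),
)
masks = sam.segment(image=image, detection_results=detections)  # per-box mask tensor from SAM post-processing
print(len(detections), masks.shape)
```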