Split GroundedSamInvocation into GroundingDinoInvocation and SegmentAnythingModelInvocation.

Author: Ryan Dick
Date:   2024-07-31 12:20:23 -04:00
parent 73386826d6
commit 0193267a53
6 changed files with 180 additions and 93 deletions

View File

@@ -1,6 +1,3 @@
-from typing import Any, Optional
-
-import numpy.typing as npt
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel
@@ -12,18 +9,13 @@ class BoundingBox(BaseModel):
     xmax: int
     ymax: int

     def to_box(self) -> list[int]:
         """Convert to the array notation expected by SAM."""
         return [self.xmin, self.ymin, self.xmax, self.ymax]


 class DetectionResult(BaseModel):
-    """Detection result from Grounding DINO or Grounded SAM."""
+    """Detection result from Grounding DINO."""

     score: float
     label: str
     box: BoundingBox
-    mask: Optional[npt.NDArray[Any]] = None
-
-    model_config = ConfigDict(
-        # Allow arbitrary types for mask, since it will be a numpy array.
-        arbitrary_types_allowed=True,
-    )
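
With the mask field and its ConfigDict gone, DetectionResult is now a plain pydantic model carrying only what Grounding DINO returns; masks are produced by SAM instead. A minimal usage sketch of the slimmed-down model (the score, label, and coordinates below are hypothetical, not taken from the commit):

from invokeai.backend.image_util.grounded_sam.detection_result import BoundingBox, DetectionResult

# Hypothetical detection, e.g. one entry from a Grounding DINO pass.
detection = DetectionResult(
    score=0.92,
    label="dog",
    box=BoundingBox(xmin=10, ymin=20, xmax=110, ymax=220),
)

# to_box() flattens the box into the [xmin, ymin, xmax, ymax] prompt format expected by SAM.
assert detection.box.to_box() == [10, 20, 110, 220]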

View File

@@ -5,7 +5,6 @@ from PIL import Image
 from transformers.models.sam import SamModel
 from transformers.models.sam.processing_sam import SamProcessor

-from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult
 from invokeai.backend.raw_model import RawModel
@@ -28,8 +27,19 @@ class SegmentAnythingModel(RawModel):
         return calc_module_size(self._sam_model)

-    def segment(self, image: Image.Image, detection_results: list[DetectionResult]) -> torch.Tensor:
-        boxes = self._to_box_array(detection_results)
+    def segment(self, image: Image.Image, bounding_boxes: list[list[int]]) -> torch.Tensor:
+        """Run the SAM model.
+
+        Args:
+            image (Image.Image): The image to segment.
+            bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
+                [xmin, ymin, xmax, ymax].
+
+        Returns:
+            torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
+        """
+        # Add batch dimension of 1 to the bounding boxes.
+        boxes = [bounding_boxes]
+
         inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device)
         outputs = self._sam_model(**inputs)
         masks = self._sam_processor.post_process_masks(
@@ -40,10 +50,4 @@
         # There should be only one batch.
         assert len(masks) == 1
-        masks = masks[0]
-        return masks
-
-    def _to_box_array(self, detection_results: list[DetectionResult]) -> list[list[list[int]]]:
-        """Convert a list of DetectionResults to the bbox format expected by the Segment Anything model."""
-        boxes = [result.box.to_box() for result in detection_results]
-        return [boxes]
+        return masks[0]
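
Because the _to_box_array() helper was removed, converting detections into plain box lists now happens at the call site. A rough sketch of the new call pattern, assuming sam_model (a loaded SegmentAnythingModel) and detection_results (Grounding DINO output) are already in scope; neither is constructed anywhere in this diff:

from PIL import Image

image = Image.open("photo.jpg")  # hypothetical input image

# The conversion previously hidden inside _to_box_array(): one flat
# [xmin, ymin, xmax, ymax] list per detection. segment() adds the batch
# dimension itself.
bounding_boxes = [result.box.to_box() for result in detection_results]

# masks is a torch.bool tensor of shape [num_masks, channels, height, width].
masks = sam_model.segment(image=image, bounding_boxes=bounding_boxes)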