Split GroundedSamInvocation into GroundingDinoInvocation and SegmentAnythingModelInvocation.
@@ -1,6 +1,3 @@
-from typing import Any, Optional
-
-import numpy.typing as npt
 from pydantic import BaseModel, ConfigDict
@@ -12,18 +9,13 @@ class BoundingBox(BaseModel):
     xmax: int
     ymax: int
 
     def to_box(self) -> list[int]:
         """Convert to the array notation expected by SAM."""
         return [self.xmin, self.ymin, self.xmax, self.ymax]
 
 
 class DetectionResult(BaseModel):
-    """Detection result from Grounding DINO or Grounded SAM."""
+    """Detection result from Grounding DINO."""
 
     score: float
     label: str
     box: BoundingBox
-    mask: Optional[npt.NDArray[Any]] = None
     model_config = ConfigDict(
         # Allow arbitrary types for mask, since it will be a numpy array.
         arbitrary_types_allowed=True
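With the mask field gone from DetectionResult (and the numpy typing imports with it), detection results now carry only a score, a label, and a bounding box; mask generation is the job of the separate Segment Anything invocation. As a rough illustration of the hand-off, the sketch below (hypothetical caller code, not part of this commit) converts a list of DetectionResult objects into the plain [xmin, ymin, xmax, ymax] lists that the reworked segment() method accepts, which is the same conversion the removed _to_box_array() helper used to perform:

# Hypothetical glue code showing the new hand-off between the two invocations.
from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult


def detections_to_boxes(detections: list[DetectionResult]) -> list[list[int]]:
    # BoundingBox.to_box() returns [xmin, ymin, xmax, ymax], the format SAM expects.
    return [d.box.to_box() for d in detections]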
@@ -5,7 +5,6 @@ from PIL import Image
 from transformers.models.sam import SamModel
 from transformers.models.sam.processing_sam import SamProcessor
 
-from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult
 from invokeai.backend.raw_model import RawModel
@@ -28,8 +27,19 @@ class SegmentAnythingModel(RawModel):
 
         return calc_module_size(self._sam_model)
 
-    def segment(self, image: Image.Image, detection_results: list[DetectionResult]) -> torch.Tensor:
-        boxes = self._to_box_array(detection_results)
+    def segment(self, image: Image.Image, bounding_boxes: list[list[int]]) -> torch.Tensor:
+        """Run the SAM model.
+
+        Args:
+            image (Image.Image): The image to segment.
+            bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
+                [xmin, ymin, xmax, ymax].
+
+        Returns:
+            torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
+        """
+        # Add batch dimension of 1 to the bounding boxes.
+        boxes = [bounding_boxes]
         inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device)
         outputs = self._sam_model(**inputs)
         masks = self._sam_processor.post_process_masks(
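The segment() method no longer knows about DetectionResult at all: it takes bare bounding-box prompts and returns a boolean mask tensor, as documented in the new docstring. A minimal usage sketch, assuming a SegmentAnythingModel instance has already been loaded (model loading is outside this diff):

# Hypothetical usage of the new segment() signature.
from PIL import Image

image = Image.open("photo.png").convert("RGB")
boxes = [[10, 20, 200, 240], [300, 50, 420, 180]]  # one [xmin, ymin, xmax, ymax] prompt per object

# `sam_model` is assumed to be an already-loaded SegmentAnythingModel.
masks = sam_model.segment(image=image, bounding_boxes=boxes)
print(masks.shape, masks.dtype)  # [num_masks, channels, height, width], torch.bool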
@@ -40,10 +50,4 @@ class SegmentAnythingModel(RawModel):
 
         # There should be only one batch.
         assert len(masks) == 1
-        masks = masks[0]
-        return masks
-
-    def _to_box_array(self, detection_results: list[DetectionResult]) -> list[list[list[int]]]:
-        """Convert a list of DetectionResults to the bbox format expected by the Segment Anything model."""
-        boxes = [result.box.to_box() for result in detection_results]
-        return [boxes]
+        return masks[0]
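Because a single image is passed to the processor, post_process_masks() returns a one-element list, so the method asserts that and returns masks[0] directly; the _to_box_array() helper is dropped since box conversion now happens before segment() is called. A hedged sketch of what a caller might do with the returned [num_masks, channels, height, width] boolean tensor, for example collapsing the channel dimension and turning one mask into a grayscale PIL image (illustrative only, not code from this commit):

# Hypothetical post-processing of the tensor returned by segment().
import numpy as np
import torch
from PIL import Image


def mask_to_pil(masks: torch.Tensor, index: int = 0) -> Image.Image:
    # A pixel belongs to the mask if any channel marks it.
    mask = masks[index].any(dim=0)  # [height, width], torch.bool
    # Convert to an 8-bit grayscale image: 255 where masked, 0 elsewhere.
    return Image.fromarray(mask.cpu().numpy().astype(np.uint8) * 255, mode="L")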