Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Split GroundedSamInvocation into GroundingDinoInvocation and SegmentAnythingModelInvocation.
This commit is contained in:
parent 73386826d6
commit 0193267a53
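Before this change, a single Grounded-SAM node went straight from a text prompt to a mask. After the split, text-prompted detection (Grounding DINO) and box-prompted segmentation (SAM) are separate nodes whose outputs can be wired together. A minimal chaining sketch, assuming the invocations are instantiated directly and that `context` is the InvocationContext supplied by the workflow executor (in practice the nodes are connected in the workflow graph rather than called like this):

# Hypothetical chaining sketch -- illustrative only, not part of this diff.
from invokeai.app.invocations.fields import ImageField
from invokeai.app.invocations.grounding_dino import GroundingDinoInvocation
# SegmentAnythingModelInvocation lives in the renamed invocation module edited further below.

# 1) Text prompt -> bounding boxes.
detect = GroundingDinoInvocation(
    id="detect", prompt="a cat", image=ImageField(image_name="my_image"), detection_threshold=0.3
)
boxes = detect.invoke(context).collection  # list[BoundingBoxField]

# 2) Bounding boxes -> merged segmentation mask.
segment = SegmentAnythingModelInvocation(
    id="segment", image=ImageField(image_name="my_image"), bounding_boxes=boxes, mask_filter="largest"
)
mask = segment.invoke(context)  # ImageOutput holding the mask image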
@@ -242,6 +242,23 @@ class ConditioningField(BaseModel):
     )
 
 
+class BoundingBoxField(BaseModel):
+    """A bounding box primitive value."""
+
+    x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
+    x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
+    y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
+    y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
+
+    score: Optional[float] = Field(
+        default=None,
+        ge=0.0,
+        le=1.0,
+        description="The score associated with the bounding box. In the range [0, 1]. This value is typically set "
+        "when the bounding box was produced by a detector and has an associated confidence score.",
+    )
+
+
 class MetadataField(RootModel[dict[str, Any]]):
     """
     Pydantic model for metadata with custom root of type dict[str, Any].
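As the field descriptions above state, the min coordinates are inclusive and the max coordinates exclusive, so box width and height are simple differences. A small illustrative example (values are hypothetical):

from invokeai.app.invocations.fields import BoundingBoxField

bbox = BoundingBoxField(x_min=10, y_min=20, x_max=110, y_max=220, score=0.87)
width = bbox.x_max - bbox.x_min   # 100, since x_max is exclusive
height = bbox.y_max - bbox.y_min  # 200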
invokeai/app/invocations/grounding_dino.py (new file, 95 lines)
@@ -0,0 +1,95 @@
+from pathlib import Path
+
+import torch
+from PIL import Image
+from transformers import pipeline
+from transformers.pipelines import ZeroShotObjectDetectionPipeline
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
+from invokeai.app.invocations.primitives import BoundingBoxCollectionOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult
+from invokeai.backend.image_util.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline
+
+GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
+
+
+@invocation(
+    "grounding_dino",
+    title="Grounding DINO (Text Prompt Object Detection)",
+    tags=["prompt", "object detection"],
+    category="image",
+    version="1.0.0",
+)
+class GroundingDinoInvocation(BaseInvocation):
+    """Runs a Grounding DINO model (https://arxiv.org/pdf/2303.05499). Performs zero-shot bounding-box object detection
+    from a text prompt.
+
+    Reference:
+    - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
+    - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
+    """
+
+    prompt: str = InputField(description="The prompt describing the object to segment.")
+    image: ImageField = InputField(description="The image to segment.")
+    detection_threshold: float = InputField(
+        description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be returned.",
+        ge=0.0,
+        le=1.0,
+        default=0.3,
+    )
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> BoundingBoxCollectionOutput:
+        # The model expects a 3-channel RGB image.
+        image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
+
+        detections = self._detect(
+            context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
+        )
+
+        # Convert detections to BoundingBoxCollectionOutput.
+        bounding_boxes: list[BoundingBoxField] = []
+        for detection in detections:
+            bounding_boxes.append(
+                BoundingBoxField(
+                    x_min=detection.box.xmin,
+                    x_max=detection.box.xmax,
+                    y_min=detection.box.ymin,
+                    y_max=detection.box.ymax,
+                    score=detection.score,
+                )
+            )
+        return BoundingBoxCollectionOutput(collection=bounding_boxes)
+
+    @staticmethod
+    def _load_grounding_dino(model_path: Path):
+        grounding_dino_pipeline = pipeline(
+            model=str(model_path),
+            task="zero-shot-object-detection",
+            local_files_only=True,
+            # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
+            # model, and figure out how to make it work in the pipeline.
+            # torch_dtype=TorchDevice.choose_torch_dtype(),
+        )
+        assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
+        return GroundingDinoPipeline(grounding_dino_pipeline)
+
+    def _detect(
+        self,
+        context: InvocationContext,
+        image: Image.Image,
+        labels: list[str],
+        threshold: float = 0.3,
+    ) -> list[DetectionResult]:
+        """Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
+        # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
+        # actually makes a difference.
+        labels = [label if label.endswith(".") else label + "." for label in labels]
+
+        with context.models.load_remote_model(
+            source=GROUNDING_DINO_MODEL_ID, loader=GroundingDinoInvocation._load_grounding_dino
+        ) as detector:
+            assert isinstance(detector, GroundingDinoPipeline)
+            return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
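For reference, `_detect` is a thin wrapper over the stock transformers zero-shot object detection pipeline; outside of InvokeAI's model cache, a roughly equivalent standalone call would look like the sketch below (file name, prompt, and result values are illustrative):

# Standalone sketch of the underlying transformers call (illustrative, not InvokeAI code).
from PIL import Image
from transformers import pipeline

detector = pipeline(task="zero-shot-object-detection", model="IDEA-Research/grounding-dino-tiny")
results = detector(Image.open("photo.png").convert("RGB"), candidate_labels=["a cat."], threshold=0.3)
# Each result is a dict such as:
#   {"score": 0.92, "label": "a cat.", "box": {"xmin": 34, "ymin": 58, "xmax": 311, "ymax": 402}}
# GroundingDinoPipeline.detect() converts these into DetectionResult objects, which the invocation
# then maps onto BoundingBoxField values.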
@@ -7,6 +7,7 @@ import torch
 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
 from invokeai.app.invocations.fields import (
+    BoundingBoxField,
     ColorField,
     ConditioningField,
     DenoiseMaskField,
@@ -469,3 +470,24 @@ class ConditioningCollectionInvocation(BaseInvocation):
 
 
 # endregion
+
+# region BoundingBox
+
+
+@invocation_output("bounding_box_output")
+class BoundingBoxOutput(BaseInvocationOutput):
+    """Base class for nodes that output a single bounding box"""
+
+    bounding_box: BoundingBoxField = OutputField(description="The output bounding box.")
+
+
+@invocation_output("bounding_box_collection_output")
+class BoundingBoxCollectionOutput(BaseInvocationOutput):
+    """Base class for nodes that output a collection of bounding boxes"""
+
+    collection: list[BoundingBoxField] = OutputField(
+        description="The output bounding boxes.",
+    )
+
+
+# endregion
@@ -5,75 +5,56 @@ import numpy as np
 import numpy.typing as npt
 import torch
 from PIL import Image
-from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
+from transformers import AutoModelForMaskGeneration, AutoProcessor
 from transformers.models.sam import SamModel
 from transformers.models.sam.processing_sam import SamProcessor
-from transformers.pipelines import ZeroShotObjectDetectionPipeline
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import ImageField, InputField
+from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult
-from invokeai.backend.image_util.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline
 from invokeai.backend.image_util.grounded_sam.mask_refinement import mask_to_polygon, polygon_to_mask
 from invokeai.backend.image_util.grounded_sam.segment_anything_model import SegmentAnythingModel
 
-GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
 SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base"
 
 
 @invocation(
-    "grounded_segment_anything",
-    title="Segment Anything (Text Prompt)",
+    "segment_anything_model",
+    title="Segment Anything Model",
     tags=["prompt", "segmentation"],
     category="segmentation",
     version="1.0.0",
 )
-class GroundedSAMInvocation(BaseInvocation):
-    """Runs Grounded-SAM, as proposed in https://arxiv.org/pdf/2401.14159.
+class SegmentAnythingModelInvocation(BaseInvocation):
+    """Runs a Segment Anything Model (https://arxiv.org/pdf/2304.02643).
 
-    More specifically, a Grounding DINO model is run to obtain bounding boxes for a text prompt, then the bounding boxes
-    are passed as a prompt to a Segment Anything model to obtain a segmentation mask.
-
     Reference:
     - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
     - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
     """
 
-    prompt: str = InputField(description="The prompt describing the object to segment.")
     image: ImageField = InputField(description="The image to segment.")
+    bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
     apply_polygon_refinement: bool = InputField(
-        description="Whether to apply polygon refinement to the masks. This will smooth the edges of the mask slightly and ensure that each mask consists of a single closed polygon (before merging).",
+        description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
         default=True,
     )
     mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
         description="The filtering to apply to the detected masks before merging them into a final output.",
         default="all",
     )
-    detection_threshold: float = InputField(
-        description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be used.",
-        ge=0.0,
-        le=1.0,
-        default=0.3,
-    )
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
         # The models expect a 3-channel RGB image.
         image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
 
-        detections = self._detect(
-            context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
-        )
-
-        if len(detections) == 0:
+        if len(self.bounding_boxes) == 0:
             combined_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8)
         else:
-            detections = self._segment(context=context, image=image_pil, detection_results=detections)
-
-            detections = self._filter_detections(detections)
-            masks = [detection.mask for detection in detections]
+            masks = self._segment(context=context, image=image_pil)
+            masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
 
             # masks contains binary values of 0 or 1, so we merge them via max-reduce.
             combined_mask = np.maximum.reduce(masks)
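The merge step relies on the masks being binary (0/1), so the max-reduce is simply a per-pixel union. A toy numpy illustration (arrays are made up):

import numpy as np

mask_a = np.array([[1, 0], [0, 0]], dtype=np.uint8)
mask_b = np.array([[0, 0], [0, 1]], dtype=np.uint8)
combined = np.maximum.reduce([mask_a, mask_b])
# combined == [[1, 0],
#              [0, 1]]  -- the union of the two masks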
@@ -84,19 +65,6 @@ class GroundedSAMInvocation(BaseInvocation):
         image_dto = context.images.save(image=mask_pil)
         return ImageOutput.build(image_dto)
 
-    @staticmethod
-    def _load_grounding_dino(model_path: Path):
-        grounding_dino_pipeline = pipeline(
-            model=str(model_path),
-            task="zero-shot-object-detection",
-            local_files_only=True,
-            # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
-            # model, and figure out how to make it work in the pipeline.
-            # torch_dtype=TorchDevice.choose_torch_dtype(),
-        )
-        assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
-        return GroundingDinoPipeline(grounding_dino_pipeline)
-
     @staticmethod
     def _load_sam_model(model_path: Path):
         sam_model = AutoModelForMaskGeneration.from_pretrained(
@@ -112,47 +80,28 @@ class GroundedSAMInvocation(BaseInvocation):
         assert isinstance(sam_processor, SamProcessor)
         return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
 
-    def _detect(
-        self,
-        context: InvocationContext,
-        image: Image.Image,
-        labels: list[str],
-        threshold: float = 0.3,
-    ) -> list[DetectionResult]:
-        """Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
-        # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
-        # actually makes a difference.
-        labels = [label if label.endswith(".") else label + "." for label in labels]
-
-        with context.models.load_remote_model(
-            source=GROUNDING_DINO_MODEL_ID, loader=GroundedSAMInvocation._load_grounding_dino
-        ) as detector:
-            assert isinstance(detector, GroundingDinoPipeline)
-            return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
-
     def _segment(
         self,
         context: InvocationContext,
         image: Image.Image,
-        detection_results: list[DetectionResult],
-    ) -> list[DetectionResult]:
+    ) -> list[npt.NDArray[np.uint8]]:
         """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
+        # Convert the bounding boxes to the SAM input format.
+        sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
+
         with (
             context.models.load_remote_model(
-                source=SEGMENT_ANYTHING_MODEL_ID, loader=GroundedSAMInvocation._load_sam_model
+                source=SEGMENT_ANYTHING_MODEL_ID, loader=SegmentAnythingModelInvocation._load_sam_model
             ) as sam_pipeline,
         ):
             assert isinstance(sam_pipeline, SegmentAnythingModel)
-            masks = sam_pipeline.segment(image=image, detection_results=detection_results)
+            masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
 
         masks = self._to_numpy_masks(masks)
         if self.apply_polygon_refinement:
             masks = self._apply_polygon_refinement(masks)
 
-        for detection_result, mask in zip(detection_results, masks, strict=True):
-            detection_result.mask = mask
-
-        return detection_results
+        return masks
 
     def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
         """Convert the tensor output from the Segment Anything model to a list of numpy masks."""
@@ -181,15 +130,23 @@ class GroundedSAMInvocation(BaseInvocation):
 
         return masks
 
-    def _filter_detections(self, detections: list[DetectionResult]) -> list[DetectionResult]:
+    def _filter_masks(
+        self, masks: list[npt.NDArray[np.uint8]], bounding_boxes: list[BoundingBoxField]
+    ) -> list[npt.NDArray[np.uint8]]:
         """Filter the detected masks based on the specified mask filter."""
+        assert len(masks) == len(bounding_boxes)
+
         if self.mask_filter == "all":
-            return detections
+            return masks
         elif self.mask_filter == "largest":
             # Find the largest mask.
-            return [max(detections, key=lambda x: x.mask.sum())]
+            return [max(masks, key=lambda x: x.sum())]
         elif self.mask_filter == "highest_box_score":
-            # Find the detection with the highest box score.
-            return [max(detections, key=lambda x: x.score)]
+            # Find the index of the bounding box with the highest score.
+            # Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
+            # cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
+            # reasonable fallback since the expected score range is [0.0, 1.0].
+            max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
+            return [masks[max_score_idx]]
         else:
             raise ValueError(f"Invalid mask filter: {self.mask_filter}")
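The `highest_box_score` branch now selects by bounding-box score rather than detection score, with None scores falling back to -1.0. A toy illustration of that selection (scores are made up):

scores = [0.42, None, 0.91]
max_score_idx = max(range(len(scores)), key=lambda i: scores[i] or -1.0)
# max_score_idx == 2 -- a None score falls back to -1.0, below the expected [0, 1] range.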
@@ -1,6 +1,3 @@
-from typing import Any, Optional
-
-import numpy.typing as npt
 from pydantic import BaseModel, ConfigDict
 
 
@@ -12,18 +9,13 @@ class BoundingBox(BaseModel):
     xmax: int
     ymax: int
 
-    def to_box(self) -> list[int]:
-        """Convert to the array notation expected by SAM."""
-        return [self.xmin, self.ymin, self.xmax, self.ymax]
-
 
 class DetectionResult(BaseModel):
-    """Detection result from Grounding DINO or Grounded SAM."""
+    """Detection result from Grounding DINO."""
 
     score: float
     label: str
     box: BoundingBox
-    mask: Optional[npt.NDArray[Any]] = None
     model_config = ConfigDict(
         # Allow arbitrary types for mask, since it will be a numpy array.
         arbitrary_types_allowed=True
@@ -5,7 +5,6 @@ from PIL import Image
 from transformers.models.sam import SamModel
 from transformers.models.sam.processing_sam import SamProcessor
 
-from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult
 from invokeai.backend.raw_model import RawModel
 
 
@@ -28,8 +27,19 @@ class SegmentAnythingModel(RawModel):
 
         return calc_module_size(self._sam_model)
 
-    def segment(self, image: Image.Image, detection_results: list[DetectionResult]) -> torch.Tensor:
-        boxes = self._to_box_array(detection_results)
+    def segment(self, image: Image.Image, bounding_boxes: list[list[int]]) -> torch.Tensor:
+        """Run the SAM model.
+
+        Args:
+            image (Image.Image): The image to segment.
+            bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
+                [xmin, ymin, xmax, ymax].
+
+        Returns:
+            torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
+        """
+        # Add batch dimension of 1 to the bounding boxes.
+        boxes = [bounding_boxes]
         inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device)
         outputs = self._sam_model(**inputs)
         masks = self._sam_processor.post_process_masks(
@@ -40,10 +50,4 @@ class SegmentAnythingModel(RawModel):
 
         # There should be only one batch.
         assert len(masks) == 1
-        masks = masks[0]
-        return masks
-
-    def _to_box_array(self, detection_results: list[DetectionResult]) -> list[list[list[int]]]:
-        """Convert a list of DetectionResults to the bbox format expected by the Segment Anything model."""
-        boxes = [result.box.to_box() for result in detection_results]
-        return [boxes]
+        return masks[0]
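The extra list wrapping (`boxes = [bounding_boxes]`) matches the transformers SAM processor convention of one box list per image in the batch. A hedged standalone sketch of the equivalent raw call (box coordinates and file name are illustrative):

# Standalone sketch of the underlying transformers SAM usage (illustrative, not InvokeAI code).
import torch
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor

processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-base")

image = Image.open("photo.png").convert("RGB")
input_boxes = [[[34, 58, 311, 402]]]  # batch of 1 image -> list of boxes -> [xmin, ymin, xmax, ymax]
inputs = processor(images=image, input_boxes=input_boxes, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
masks = processor.post_process_masks(outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"])
# masks[0] corresponds to the single image; per the docstring above, its shape is
# [num_masks, channels, height, width] with dtype torch.bool.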