Move some logic from GroundedSAMInvocation to the backend classes.

Ryan Dick
2024-07-30 15:34:33 -04:00
parent aca2a2fa13
commit 918f77bce0
3 changed files with 22 additions and 59 deletions


@@ -1,6 +1,5 @@
from dataclasses import dataclass
from pathlib import Path
-from typing import Any, Literal, Optional
+from typing import Literal
import numpy as np
import numpy.typing as npt
@@ -15,6 +14,7 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.grounded_sam.detection_result import DetectionResult
from invokeai.backend.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.grounded_sam.mask_refinement import mask_to_polygon, polygon_to_mask
from invokeai.backend.grounded_sam.segment_anything_model import SegmentAnythingModel
@@ -23,43 +23,6 @@ GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base"
-@dataclass
-class BoundingBox:
-    """Bounding box helper class used locally for the Grounding DINO outputs."""
-
-    xmin: int
-    ymin: int
-    xmax: int
-    ymax: int
-
-    def to_box(self) -> list[int]:
-        """Convert to the array notation expected by SAM."""
-        return [self.xmin, self.ymin, self.xmax, self.ymax]
-
-
-@dataclass
-class DetectionResult:
-    """Detection result from Grounding DINO or Grounded SAM."""
-
-    score: float
-    label: str
-    box: BoundingBox
-    mask: Optional[npt.NDArray[Any]] = None
-
-    @classmethod
-    def from_dict(cls, detection_dict: dict[str, Any]):
-        return cls(
-            score=detection_dict["score"],
-            label=detection_dict["label"],
-            box=BoundingBox(
-                xmin=detection_dict["box"]["xmin"],
-                ymin=detection_dict["box"]["ymin"],
-                xmax=detection_dict["box"]["xmax"],
-                ymax=detection_dict["box"]["ymax"],
-            ),
-        )
@invocation(
"grounded_segment_anything",
title="Segment Anything (Text Prompt)",
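Note that the BoundingBox and DetectionResult definitions deleted above are not lost: the invocation now imports DetectionResult from invokeai.backend.grounded_sam.detection_result (see the import added in the previous hunk). The sketch below shows the interface the invocation still relies on, inferred from the deleted local copy; it is not the actual backend file, and its details may differ:

```python
# Sketch of the detection-result interface used by GroundedSAMInvocation.
# Inferred from the local dataclasses deleted above; the real
# invokeai.backend.grounded_sam.detection_result module may differ.
from dataclasses import dataclass
from typing import Any, Optional

import numpy.typing as npt


@dataclass
class BoundingBox:
    xmin: int
    ymin: int
    xmax: int
    ymax: int

    def to_box(self) -> list[int]:
        # [xmin, ymin, xmax, ymax]: the box-prompt format expected by SAM.
        return [self.xmin, self.ymin, self.xmax, self.ymax]


@dataclass
class DetectionResult:
    score: float
    label: str
    box: BoundingBox
    mask: Optional[npt.NDArray[Any]] = None

    @classmethod
    def from_dict(cls, detection_dict: dict[str, Any]) -> "DetectionResult":
        # Build a DetectionResult from the raw dict emitted by the detection pipeline.
        return cls(
            score=detection_dict["score"],
            label=detection_dict["label"],
            box=BoundingBox(**detection_dict["box"]),
        )
```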
@@ -92,9 +55,10 @@ class GroundedSAMInvocation(BaseInvocation):
description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be used.",
ge=0.0,
le=1.0,
-        default=0.1,
+        default=0.3,
)
+    @torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
image_pil = context.images.get_pil(self.image.image_name)
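Two behavioral tweaks ride along in the hunk above: the default detection_threshold rises from 0.1 to 0.3, and invoke() gains a @torch.no_grad() decorator so no autograd graph is built during inference. A tiny self-contained illustration of what the decorator does (the names below are illustrative, not from the InvokeAI codebase):

```python
import torch


@torch.no_grad()
def run_inference(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Gradient tracking is disabled for everything executed inside this call,
    # which avoids storing intermediate activations during pure inference.
    return model(x)


model = torch.nn.Linear(4, 2)
y = run_inference(model, torch.randn(1, 4))
assert not y.requires_grad  # no autograd graph was recorded
```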
@@ -118,15 +82,6 @@ class GroundedSAMInvocation(BaseInvocation):
image_dto = context.images.save(image=mask_pil)
return ImageOutput.build(image_dto)
-    def _to_box_array(self, detection_results: list[DetectionResult]) -> list[list[list[int]]]:
-        """Convert a list of DetectionResults to the format expected by the Segment Anything model.
-
-        Args:
-            detection_results (list[DetectionResult]): The Grounding DINO detection results.
-        """
-        boxes = [result.box.to_box() for result in detection_results]
-        return [boxes]
def _detect(
self,
context: InvocationContext,
@@ -153,10 +108,7 @@ class GroundedSAMInvocation(BaseInvocation):
with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector:
assert isinstance(detector, GroundingDinoPipeline)
-            results = detector.detect(image=image, candidate_labels=labels, threshold=threshold)
-            results = [DetectionResult.from_dict(result) for result in results]
-
-        return results
+            return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
def _segment(
self,
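With the DetectionResult.from_dict conversion deleted from _detect, that wrapping presumably now happens inside GroundingDinoPipeline.detect, which returns list[DetectionResult] directly. A rough sketch of how the backend method might look, assuming it wraps a Hugging Face zero-shot object detection pipeline (the _pipeline attribute name is an assumption):

```python
from PIL import Image
from transformers import ZeroShotObjectDetectionPipeline

from invokeai.backend.grounded_sam.detection_result import DetectionResult


class GroundingDinoPipeline:
    """Sketch of the backend wrapper around a Grounding DINO detection pipeline."""

    def __init__(self, pipeline: ZeroShotObjectDetectionPipeline) -> None:
        self._pipeline = pipeline  # hypothetical attribute name

    def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float) -> list[DetectionResult]:
        # The raw pipeline yields list[dict]; convert here so callers such as
        # GroundedSAMInvocation never see the dict schema.
        results = self._pipeline(image, candidate_labels=candidate_labels, threshold=threshold)
        return [DetectionResult.from_dict(result) for result in results]
```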
@@ -185,8 +137,7 @@ class GroundedSAMInvocation(BaseInvocation):
):
assert isinstance(sam_pipeline, SegmentAnythingModel)
-            boxes = self._to_box_array(detection_results)
-            masks = sam_pipeline.segment(image=image, boxes=boxes)
+            masks = sam_pipeline.segment(image=image, detection_results=detection_results)
masks = self._to_numpy_masks(masks)
masks = self._apply_polygon_refinement(masks)
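Likewise, the deleted _to_box_array helper implies that the box-format conversion moved into SegmentAnythingModel.segment, which now accepts the detection results directly. A sketch of how the backend method might absorb that logic, assuming a transformers SamModel/SamProcessor pair underneath (attribute names and the exact post-processing are assumptions):

```python
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

from invokeai.backend.grounded_sam.detection_result import DetectionResult


class SegmentAnythingModel:
    """Sketch of the backend wrapper around a SAM model and its processor."""

    def __init__(self, sam_model: SamModel, sam_processor: SamProcessor) -> None:
        self._sam_model = sam_model  # hypothetical attribute names
        self._sam_processor = sam_processor

    def segment(self, image: Image.Image, detection_results: list[DetectionResult]) -> torch.Tensor:
        # Former GroundedSAMInvocation._to_box_array logic: a single batch
        # containing one [xmin, ymin, xmax, ymax] box per detection.
        boxes = [[result.box.to_box() for result in detection_results]]

        inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device)
        outputs = self._sam_model(**inputs)
        masks = self._sam_processor.post_process_masks(
            masks=outputs.pred_masks,
            original_sizes=inputs["original_sizes"],
            reshaped_input_sizes=inputs["reshaped_input_sizes"],
        )
        # One image in the batch -> return its per-box masks.
        return masks[0]
```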
@@ -200,7 +151,7 @@ class GroundedSAMInvocation(BaseInvocation):
"""Convert the tensor output from the Segment Anything model to a list of numpy masks."""
masks = masks.cpu().float()
masks = masks.permute(0, 2, 3, 1)
-        masks = masks.mean(axis=-1)
+        masks = masks.mean(dim=-1)
masks = (masks > 0).int()
masks = masks.numpy().astype(np.uint8)
masks = list(masks)
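The final hunk only swaps NumPy's axis keyword for PyTorch's canonical dim when averaging the per-box mask channels; recent PyTorch versions accept axis as a NumPy-compatibility alias for reductions such as mean, so this reads as a consistency fix rather than a behavior change. A quick check:

```python
import torch

x = torch.rand(2, 64, 64, 3)

# `dim` is the canonical PyTorch keyword; `axis` is accepted as a NumPy-style alias.
assert torch.equal(x.mean(dim=-1), x.mean(axis=-1))
```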