Move some logic from GroundedSAMInvocation to the backend classes.

Ryan Dick 2024-07-30 15:34:33 -04:00
parent aca2a2fa13
commit 918f77bce0
3 changed files with 22 additions and 59 deletions
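
The three diffs below are, in order: the module that defines GroundedSAMInvocation, the GroundingDinoPipeline wrapper, and the SegmentAnythingModel wrapper. The BoundingBox and DetectionResult dataclasses removed from the invocation are not re-added in any of the three changed files; they presumably already live in invokeai.backend.grounded_sam.detection_result, which all three files now import. As a rough sketch reconstructed from the deleted code (the backend module itself is not part of this diff), that module would look something like:

from dataclasses import dataclass
from typing import Any, Optional

import numpy.typing as npt


@dataclass
class BoundingBox:
    """Bounding box helper class for the Grounding DINO detections."""

    xmin: int
    ymin: int
    xmax: int
    ymax: int

    def to_box(self) -> list[int]:
        """Convert to the array notation expected by SAM."""
        return [self.xmin, self.ymin, self.xmax, self.ymax]


@dataclass
class DetectionResult:
    """Detection result from Grounding DINO or Grounded SAM."""

    score: float
    label: str
    box: BoundingBox
    mask: Optional[npt.NDArray[Any]] = None

    @classmethod
    def from_dict(cls, detection_dict: dict[str, Any]) -> "DetectionResult":
        # The raw dicts come from transformers' zero-shot object detection pipeline:
        # {"score": ..., "label": ..., "box": {"xmin": ..., "ymin": ..., "xmax": ..., "ymax": ...}}
        return cls(
            score=detection_dict["score"],
            label=detection_dict["label"],
            box=BoundingBox(
                xmin=detection_dict["box"]["xmin"],
                ymin=detection_dict["box"]["ymin"],
                xmax=detection_dict["box"]["xmax"],
                ymax=detection_dict["box"]["ymax"],
            ),
        )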


@@ -1,6 +1,5 @@
-from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Literal, Optional
+from typing import Literal
 
 import numpy as np
 import numpy.typing as npt
@@ -15,6 +14,7 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.fields import ImageField, InputField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.grounded_sam.detection_result import DetectionResult
 from invokeai.backend.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline
 from invokeai.backend.grounded_sam.mask_refinement import mask_to_polygon, polygon_to_mask
 from invokeai.backend.grounded_sam.segment_anything_model import SegmentAnythingModel
@@ -23,43 +23,6 @@ GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
 SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base"
 
 
-@dataclass
-class BoundingBox:
-    """Bounding box helper class used locally for the Grounding DINO outputs."""
-
-    xmin: int
-    ymin: int
-    xmax: int
-    ymax: int
-
-    def to_box(self) -> list[int]:
-        """Convert to the array notation expected by SAM."""
-        return [self.xmin, self.ymin, self.xmax, self.ymax]
-
-
-@dataclass
-class DetectionResult:
-    """Detection result from Grounding DINO or Grounded SAM."""
-
-    score: float
-    label: str
-    box: BoundingBox
-    mask: Optional[npt.NDArray[Any]] = None
-
-    @classmethod
-    def from_dict(cls, detection_dict: dict[str, Any]):
-        return cls(
-            score=detection_dict["score"],
-            label=detection_dict["label"],
-            box=BoundingBox(
-                xmin=detection_dict["box"]["xmin"],
-                ymin=detection_dict["box"]["ymin"],
-                xmax=detection_dict["box"]["xmax"],
-                ymax=detection_dict["box"]["ymax"],
-            ),
-        )
-
-
 @invocation(
     "grounded_segment_anything",
     title="Segment Anything (Text Prompt)",
@@ -92,9 +55,10 @@ class GroundedSAMInvocation(BaseInvocation):
         description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be used.",
         ge=0.0,
         le=1.0,
-        default=0.1,
+        default=0.3,
     )
 
+    @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image_pil = context.images.get_pil(self.image.image_name)
@@ -118,15 +82,6 @@ class GroundedSAMInvocation(BaseInvocation):
         image_dto = context.images.save(image=mask_pil)
         return ImageOutput.build(image_dto)
 
-    def _to_box_array(self, detection_results: list[DetectionResult]) -> list[list[list[int]]]:
-        """Convert a list of DetectionResults to the format expected by the Segment Anything model.
-
-        Args:
-            detection_results (list[DetectionResult]): The Grounding DINO detection results.
-        """
-        boxes = [result.box.to_box() for result in detection_results]
-        return [boxes]
-
     def _detect(
         self,
         context: InvocationContext,
@@ -153,10 +108,7 @@ class GroundedSAMInvocation(BaseInvocation):
         with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector:
            assert isinstance(detector, GroundingDinoPipeline)
-            results = detector.detect(image=image, candidate_labels=labels, threshold=threshold)
-            results = [DetectionResult.from_dict(result) for result in results]
-            return results
+            return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
 
     def _segment(
         self,
@@ -185,8 +137,7 @@ class GroundedSAMInvocation(BaseInvocation):
         ):
             assert isinstance(sam_pipeline, SegmentAnythingModel)
-            boxes = self._to_box_array(detection_results)
-            masks = sam_pipeline.segment(image=image, boxes=boxes)
+            masks = sam_pipeline.segment(image=image, detection_results=detection_results)
             masks = self._to_numpy_masks(masks)
             masks = self._apply_polygon_refinement(masks)
@@ -200,7 +151,7 @@ class GroundedSAMInvocation(BaseInvocation):
         """Convert the tensor output from the Segment Anything model to a list of numpy masks."""
         masks = masks.cpu().float()
         masks = masks.permute(0, 2, 3, 1)
-        masks = masks.mean(axis=-1)
+        masks = masks.mean(dim=-1)
         masks = (masks > 0).int()
         masks = masks.numpy().astype(np.uint8)
         masks = list(masks)
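Besides moving the box conversion into the backend, this file picks up three smaller changes: the Grounding DINO detection threshold default rises from 0.1 to 0.3, invoke() now runs under @torch.no_grad(), and the mask reduction passes dim=-1 instead of the NumPy-style axis=-1 to Tensor.mean(). With the helpers gone, the invocation simply chains the two backend wrappers. Below is a minimal usage sketch of that chained API; masks_for_prompt is a hypothetical helper (not part of the commit) and assumes a GroundingDinoPipeline and SegmentAnythingModel have already been loaded, e.g. via the model manager as in _detect() and _segment() above.

import torch
from PIL import Image

from invokeai.backend.grounded_sam.detection_result import DetectionResult
from invokeai.backend.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.grounded_sam.segment_anything_model import SegmentAnythingModel


@torch.no_grad()
def masks_for_prompt(
    image: Image.Image,
    prompt: str,
    detector: GroundingDinoPipeline,
    sam: SegmentAnythingModel,
    threshold: float = 0.3,
) -> torch.Tensor:
    # detect() now returns list[DetectionResult] directly...
    detections: list[DetectionResult] = detector.detect(image=image, candidate_labels=[prompt], threshold=threshold)
    # ...and segment() accepts it, handling the box-format conversion internally.
    return sam.segment(image=image, detection_results=detections)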

invokeai/backend/grounded_sam/grounding_dino_pipeline.py

@@ -4,6 +4,8 @@ import torch
 from PIL import Image
 from transformers.pipelines import ZeroShotObjectDetectionPipeline
 
+from invokeai.backend.grounded_sam.detection_result import DetectionResult
+
 
 class GroundingDinoPipeline:
     """A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory
@@ -13,8 +15,10 @@ class GroundingDinoPipeline:
     def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
         self._pipeline = pipeline
 
-    def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1):
-        return self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
+    def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1) -> list[DetectionResult]:
+        results = self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
+        results = [DetectionResult.from_dict(result) for result in results]
+        return results
 
     def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":
         self._pipeline.model.to(device=device, dtype=dtype)
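For reference, transformers' ZeroShotObjectDetectionPipeline yields plain dicts of the form {"score": float, "label": str, "box": {"xmin", "ymin", "xmax", "ymax"}}, which is exactly what DetectionResult.from_dict() parses, so the dict-to-dataclass conversion now sits at the boundary where the raw output is produced. A sketch of building such a wrapper directly with transformers follows; this is illustrative only (build_grounding_dino is a hypothetical helper, distinct from the load_grounding_dino loader the invocation passes to the model manager) and assumes a transformers version whose zero-shot-object-detection pipeline supports Grounding DINO.

from transformers import pipeline
from transformers.pipelines import ZeroShotObjectDetectionPipeline

from invokeai.backend.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline


def build_grounding_dino(model_id: str = "IDEA-Research/grounding-dino-tiny") -> GroundingDinoPipeline:
    # pipeline() downloads/loads the checkpoint and returns a ZeroShotObjectDetectionPipeline.
    hf_pipeline = pipeline(task="zero-shot-object-detection", model=model_id)
    assert isinstance(hf_pipeline, ZeroShotObjectDetectionPipeline)
    return GroundingDinoPipeline(hf_pipeline)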

invokeai/backend/grounded_sam/segment_anything_model.py

@@ -5,6 +5,8 @@ from PIL import Image
 from transformers.models.sam import SamModel
 from transformers.models.sam.processing_sam import SamProcessor
 
+from invokeai.backend.grounded_sam.detection_result import DetectionResult
+
 
 class SegmentAnythingModel:
     """A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
@@ -23,7 +25,8 @@ class SegmentAnythingModel:
         return calc_module_size(self._sam_model)
 
-    def segment(self, image: Image.Image, boxes: list[list[list[int]]]) -> torch.Tensor:
+    def segment(self, image: Image.Image, detection_results: list[DetectionResult]) -> torch.Tensor:
+        boxes = self._to_box_array(detection_results)
         inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device)
         outputs = self._sam_model(**inputs)
         masks = self._sam_processor.post_process_masks(
@@ -36,3 +39,8 @@ class SegmentAnythingModel:
         assert len(masks) == 1
         masks = masks[0]
         return masks
+
+    def _to_box_array(self, detection_results: list[DetectionResult]) -> list[list[list[int]]]:
+        """Convert a list of DetectionResults to the bbox format expected by the Segment Anything model."""
+        boxes = [result.box.to_box() for result in detection_results]
+        return [boxes]
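The new private _to_box_array() helper is the one deleted from the invocation above. Its nesting matches what SamProcessor expects for input_boxes: one outer list per image in the batch (always a single image here), then one [xmin, ymin, xmax, ymax] entry per detection. A small illustrative example with made-up values follows; it assumes BoundingBox is importable from the same detection_result module as DetectionResult, which this diff does not show.

from invokeai.backend.grounded_sam.detection_result import BoundingBox, DetectionResult  # BoundingBox location assumed

detections = [
    DetectionResult(score=0.92, label="dog", box=BoundingBox(xmin=10, ymin=20, xmax=200, ymax=240)),
    DetectionResult(score=0.81, label="cat", box=BoundingBox(xmin=220, ymin=35, xmax=400, ymax=260)),
]

# Same nesting that SegmentAnythingModel._to_box_array(detections) produces:
boxes = [[d.box.to_box() for d in detections]]
assert boxes == [[[10, 20, 200, 240], [220, 35, 400, 260]]]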