Add mask_filter and detection_threshold options to the GroundedSAMInvocation.

This commit is contained in:
Ryan Dick
2024-07-30 14:22:40 -04:00
parent ff6398f7d8
commit aca2a2fa13
3 changed files with 83 additions and 25 deletions

View File

@ -1,6 +1,7 @@
from typing import Optional
import torch
from PIL import Image
from transformers.pipelines import ZeroShotObjectDetectionPipeline
@ -12,8 +13,8 @@ class GroundingDinoPipeline:
# Construct the wrapper around an existing transformers ZeroShotObjectDetectionPipeline.
def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
# Keep a reference to the wrapped pipeline; the other methods of this class delegate to it.
self._pipeline = pipeline
# Make the wrapper callable: every positional and keyword argument is forwarded
# unchanged to the wrapped pipeline, and the pipeline's result is returned as-is.
# NOTE(review): this looks like the pre-commit form that `detect` replaces - confirm
# against the actual file before relying on it.
def __call__(self, *args, **kwargs):
return self._pipeline(*args, **kwargs)
# Run the wrapped pipeline on a single image and return its result unchanged.
#
# Parameters:
#   image: the PIL image passed to the pipeline as `image`.
#   candidate_labels: the label strings passed to the pipeline as `candidate_labels`.
#   threshold: passed to the pipeline as `threshold`; defaults to 0.1.
#     NOTE(review): presumably the minimum score a detection needs to be kept -
#     confirm against the transformers pipeline documentation.
#
# Returns whatever the wrapped pipeline returns; nothing is post-processed here.
def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1):
# Arguments are passed by keyword so that only these three options reach the pipeline.
return self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":
self._pipeline.model.to(device=device, dtype=dtype)

View File

@ -30,6 +30,9 @@ class SegmentAnythingModel:
masks=outputs.pred_masks,
original_sizes=inputs.original_sizes,
reshaped_input_sizes=inputs.reshaped_input_sizes,
)[0]
)
# There should be only one batch.
assert len(masks) == 1
masks = masks[0]
return masks