mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add mask_filter and detection_threshold options to the GroundedSAMInvocation.
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers.pipelines import ZeroShotObjectDetectionPipeline
|
||||
|
||||
|
||||
@ -12,8 +13,8 @@ class GroundingDinoPipeline:
|
||||
def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
|
||||
self._pipeline = pipeline
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self._pipeline(*args, **kwargs)
|
||||
def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1):
|
||||
return self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":
|
||||
self._pipeline.model.to(device=device, dtype=dtype)
|
||||
|
@ -30,6 +30,9 @@ class SegmentAnythingModel:
|
||||
masks=outputs.pred_masks,
|
||||
original_sizes=inputs.original_sizes,
|
||||
reshaped_input_sizes=inputs.reshaped_input_sizes,
|
||||
)[0]
|
||||
)
|
||||
|
||||
# There should be only one batch.
|
||||
assert len(masks) == 1
|
||||
masks = masks[0]
|
||||
return masks
|
||||
|
Reference in New Issue
Block a user