(minor) Simplify GroundedSAMInvocation._filter_detections()

This commit is contained in:
Ryan Dick 2024-07-31 08:25:19 -04:00
parent 67c32f3d6c
commit bdae81e429

View File

@ -182,16 +182,10 @@ class GroundedSAMInvocation(BaseInvocation):
return detections return detections
elif self.mask_filter == "largest": elif self.mask_filter == "largest":
# Find the largest mask. # Find the largest mask.
mask_areas = [detection.mask.sum() for detection in detections] return [max(detections, key=lambda x: x.mask.sum())]
largest_mask_idx = mask_areas.index(max(mask_areas))
return [detections[largest_mask_idx]]
elif self.mask_filter == "highest_box_score": elif self.mask_filter == "highest_box_score":
# Find the detection with the highest box score. # Find the detection with the highest box score.
max_score_detection = detections[0] return [max(detections, key=lambda x: x.score)]
for detection in detections:
if detection.score > max_score_detection.score:
max_score_detection = detection
return [max_score_detection]
else: else:
raise ValueError(f"Invalid mask filter: {self.mask_filter}") raise ValueError(f"Invalid mask filter: {self.mask_filter}")