(minor) Simplify GroundedSAMInvocation._filter_detections()

This commit is contained in:
Ryan Dick 2024-07-31 08:25:19 -04:00
parent 67c32f3d6c
commit bdae81e429

View File

@ -182,16 +182,10 @@ class GroundedSAMInvocation(BaseInvocation):
return detections
elif self.mask_filter == "largest":
# Find the largest mask.
mask_areas = [detection.mask.sum() for detection in detections]
largest_mask_idx = mask_areas.index(max(mask_areas))
return [detections[largest_mask_idx]]
return [max(detections, key=lambda x: x.mask.sum())]
elif self.mask_filter == "highest_box_score":
# Find the detection with the highest box score.
max_score_detection = detections[0]
for detection in detections:
if detection.score > max_score_detection.score:
max_score_detection = detection
return [max_score_detection]
return [max(detections, key=lambda x: x.score)]
else:
raise ValueError(f"Invalid mask filter: {self.mask_filter}")