mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
(minor) Simplify GroundedSAMInvocation._filter_detections()
This commit is contained in:
parent
67c32f3d6c
commit
bdae81e429
@ -182,16 +182,10 @@ class GroundedSAMInvocation(BaseInvocation):
|
|||||||
return detections
|
return detections
|
||||||
elif self.mask_filter == "largest":
|
elif self.mask_filter == "largest":
|
||||||
# Find the largest mask.
|
# Find the largest mask.
|
||||||
mask_areas = [detection.mask.sum() for detection in detections]
|
return [max(detections, key=lambda x: x.mask.sum())]
|
||||||
largest_mask_idx = mask_areas.index(max(mask_areas))
|
|
||||||
return [detections[largest_mask_idx]]
|
|
||||||
elif self.mask_filter == "highest_box_score":
|
elif self.mask_filter == "highest_box_score":
|
||||||
# Find the detection with the highest box score.
|
# Find the detection with the highest box score.
|
||||||
max_score_detection = detections[0]
|
return [max(detections, key=lambda x: x.score)]
|
||||||
for detection in detections:
|
|
||||||
if detection.score > max_score_detection.score:
|
|
||||||
max_score_detection = detection
|
|
||||||
return [max_score_detection]
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid mask filter: {self.mask_filter}")
|
raise ValueError(f"Invalid mask filter: {self.mask_filter}")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user