(minor) Simplify GroundedSAMInvocation._merge_masks(...).

This commit is contained in:
Ryan Dick 2024-07-31 08:58:51 -04:00
parent e8ecf5e155
commit 0a7048f650

View File

@ -74,7 +74,8 @@ class GroundedSAMInvocation(BaseInvocation):
detections = self._filter_detections(detections)
masks = [detection.mask for detection in detections]
combined_mask = self._merge_masks(masks)
# masks contains binary values of 0 or 1, so we merge them via max-reduce.
combined_mask = np.maximum.reduce(masks)
# Map [0, 1] to [0, 255].
mask_np = combined_mask * 255
@ -188,10 +189,3 @@ class GroundedSAMInvocation(BaseInvocation):
return [max(detections, key=lambda x: x.score)]
else:
raise ValueError(f"Invalid mask filter: {self.mask_filter}")
def _merge_masks(self, masks: list[npt.NDArray[np.uint8]]) -> npt.NDArray[np.uint8]:
"""Merge multiple masks into a single mask."""
# Merge all masks together.
stacked_mask = np.stack(masks, axis=0)
combined_mask = np.max(stacked_mask, axis=0)
return combined_mask