Re-order GroundedSAMInvocation._to_numpy_masks(...) to do slightly more work on the GPU.

This commit is contained in:
Ryan Dick 2024-07-31 09:51:14 -04:00
parent e206890e25
commit bcd1483a14

View File

@ -156,11 +156,11 @@ class GroundedSAMInvocation(BaseInvocation):
def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
"""Convert the tensor output from the Segment Anything model to a list of numpy masks."""
masks = masks.cpu().float()
masks = masks.permute(0, 2, 3, 1)
masks = masks.mean(dim=-1)
masks = (masks > 0).int()
np_masks = masks.numpy().astype(np.uint8)
eps = 0.0001
# [num_masks, channels, height, width] -> [num_masks, height, width]
masks = masks.permute(0, 2, 3, 1).float().mean(dim=-1)
masks = masks > eps
np_masks = masks.cpu().numpy().astype(np.uint8)
return list(np_masks)
def _apply_polygon_refinement(self, masks: list[npt.NDArray[np.uint8]]) -> list[npt.NDArray[np.uint8]]: