From b5832768dc84702b95bcefaa247b6382750b22f2 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 31 Jul 2024 17:15:48 -0400 Subject: [PATCH] Return a MaskOutput from SegmentAnythingModelInvocation. And add a MaskTensorToImageInvocation. --- invokeai/app/invocations/mask.py | 30 +++++++++- .../app/invocations/segment_anything_model.py | 59 ++++++++++--------- 2 files changed, 59 insertions(+), 30 deletions(-) diff --git a/invokeai/app/invocations/mask.py b/invokeai/app/invocations/mask.py index 6f54660847..2ebefeacff 100644 --- a/invokeai/app/invocations/mask.py +++ b/invokeai/app/invocations/mask.py @@ -1,9 +1,10 @@ import numpy as np import torch +from PIL import Image from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation -from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata -from invokeai.app.invocations.primitives import MaskOutput +from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata +from invokeai.app.invocations.primitives import ImageOutput, MaskOutput @invocation( @@ -118,3 +119,28 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata): height=mask.shape[1], width=mask.shape[2], ) + + +@invocation( + "tensor_mask_to_image", + title="Tensor Mask to Image", + tags=["mask"], + category="mask", + version="1.0.0", +) +class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard): + """Convert a mask tensor to an image.""" + + mask: TensorField = InputField(description="The mask tensor to convert.") + + def invoke(self, context: InvocationContext) -> ImageOutput: + mask = context.tensors.load(self.mask.tensor_name) + # Ensure that the mask is binary. + if mask.dtype != torch.bool: + mask = mask > 0.5 + mask_np = mask.float().cpu().detach().numpy() * 255 + mask_np = mask_np.astype(np.uint8) + + mask_pil = Image.fromarray(mask_np, mode="L") + image_dto = context.images.save(image=mask_pil) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/segment_anything_model.py b/invokeai/app/invocations/segment_anything_model.py index ad264c9584..a652e68338 100644 --- a/invokeai/app/invocations/segment_anything_model.py +++ b/invokeai/app/invocations/segment_anything_model.py @@ -2,7 +2,6 @@ from pathlib import Path from typing import Literal import numpy as np -import numpy.typing as npt import torch from PIL import Image from transformers import AutoModelForMaskGeneration, AutoProcessor @@ -10,8 +9,8 @@ from transformers.models.sam import SamModel from transformers.models.sam.processing_sam import SamProcessor from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation -from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField -from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField +from invokeai.app.invocations.primitives import MaskOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask from invokeai.backend.image_util.segment_anything.segment_anything_model import SegmentAnythingModel @@ -46,24 +45,22 @@ class SegmentAnythingModelInvocation(BaseInvocation): ) @torch.no_grad() - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context: InvocationContext) -> MaskOutput: # The models expect a 3-channel RGB image. image_pil = context.images.get_pil(self.image.image_name, mode="RGB") if len(self.bounding_boxes) == 0: - combined_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8) + combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool) else: masks = self._segment(context=context, image=image_pil) masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes) - # masks contains binary values of 0 or 1, so we merge them via max-reduce. - combined_mask = np.maximum.reduce(masks) - # Map [0, 1] to [0, 255]. - mask_np = combined_mask * 255 - mask_pil = Image.fromarray(mask_np) + # masks contains bool values, so we merge them via max-reduce. + combined_mask, _ = torch.stack(masks).max(dim=0) - image_dto = context.images.save(image=mask_pil) - return ImageOutput.build(image_dto) + mask_tensor_name = context.tensors.save(combined_mask) + height, width = combined_mask.shape + return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height) @staticmethod def _load_sam_model(model_path: Path): @@ -84,7 +81,7 @@ class SegmentAnythingModelInvocation(BaseInvocation): self, context: InvocationContext, image: Image.Image, - ) -> list[npt.NDArray[np.uint8]]: + ) -> list[torch.Tensor]: """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.""" # Convert the bounding boxes to the SAM input format. sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes] @@ -97,22 +94,23 @@ class SegmentAnythingModelInvocation(BaseInvocation): assert isinstance(sam_pipeline, SegmentAnythingModel) masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes) - masks = self._to_numpy_masks(masks) + masks = self._process_masks(masks) if self.apply_polygon_refinement: masks = self._apply_polygon_refinement(masks) return masks - def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]: - """Convert the tensor output from the Segment Anything model to a list of numpy masks.""" - eps = 0.0001 + def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]: + """Convert the tensor output from the Segment Anything model from a tensor of shape + [num_masks, channels, height, width] to a list of tensors of shape [height, width]. + """ + assert masks.dtype == torch.bool # [num_masks, channels, height, width] -> [num_masks, height, width] - masks = masks.permute(0, 2, 3, 1).float().mean(dim=-1) - masks = masks > eps - np_masks = masks.cpu().numpy().astype(np.uint8) - return list(np_masks) + masks, _ = masks.max(dim=1) + # Split the first dimension into a list of masks. + return list(masks.cpu().unbind(dim=0)) - def _apply_polygon_refinement(self, masks: list[npt.NDArray[np.uint8]]) -> list[npt.NDArray[np.uint8]]: + def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]: """Apply polygon refinement to the masks. Convert each mask to a polygon, then back to a mask. This has the following effect: @@ -121,18 +119,23 @@ class SegmentAnythingModelInvocation(BaseInvocation): - Removes small mask pieces. - Removes holes from the mask. """ - for idx, mask in enumerate(masks): + # Convert tensor masks to np masks. + np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks] + + # Apply polygon refinement. + for idx, mask in enumerate(np_masks): shape = mask.shape assert len(shape) == 2 # Assert length to satisfy type checker. polygon = mask_to_polygon(mask) mask = polygon_to_mask(polygon, shape) - masks[idx] = mask + np_masks[idx] = mask + + # Convert np masks back to tensor masks. + masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks] return masks - def _filter_masks( - self, masks: list[npt.NDArray[np.uint8]], bounding_boxes: list[BoundingBoxField] - ) -> list[npt.NDArray[np.uint8]]: + def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]: """Filter the detected masks based on the specified mask filter.""" assert len(masks) == len(bounding_boxes) @@ -140,7 +143,7 @@ class SegmentAnythingModelInvocation(BaseInvocation): return masks elif self.mask_filter == "largest": # Find the largest mask. - return [max(masks, key=lambda x: x.sum())] + return [max(masks, key=lambda x: float(x.sum()))] elif self.mask_filter == "highest_box_score": # Find the index of the bounding box with the highest score. # Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most