Return a MaskOutput from SegmentAnythingModelInvocation, and add a MaskTensorToImageInvocation.

Ryan Dick 2024-07-31 17:15:48 -04:00
parent fca119773b
commit b5832768dc
2 changed files with 59 additions and 30 deletions

Changed file 1 of 2:

@@ -1,9 +1,10 @@
 import numpy as np
 import torch
+from PIL import Image
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
-from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata
-from invokeai.app.invocations.primitives import MaskOutput
+from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata
+from invokeai.app.invocations.primitives import ImageOutput, MaskOutput
 
 
 @invocation(
@@ -118,3 +119,28 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
         height=mask.shape[1],
         width=mask.shape[2],
     )
+
+
+@invocation(
+    "tensor_mask_to_image",
+    title="Tensor Mask to Image",
+    tags=["mask"],
+    category="mask",
+    version="1.0.0",
+)
+class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
+    """Convert a mask tensor to an image."""
+
+    mask: TensorField = InputField(description="The mask tensor to convert.")
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        mask = context.tensors.load(self.mask.tensor_name)
+        # Ensure that the mask is binary.
+        if mask.dtype != torch.bool:
+            mask = mask > 0.5
+        mask_np = mask.float().cpu().detach().numpy() * 255
+        mask_np = mask_np.astype(np.uint8)
+
+        mask_pil = Image.fromarray(mask_np, mode="L")
+
+        image_dto = context.images.save(image=mask_pil)
+        return ImageOutput.build(image_dto)
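
The conversion at the heart of the new invocation can be exercised on its own. A minimal sketch using only PyTorch, NumPy, and Pillow (the helper name is mine; there is no InvocationContext here):

import numpy as np
import torch
from PIL import Image


def mask_tensor_to_pil(mask: torch.Tensor) -> Image.Image:
    # Ensure that the mask is binary before scaling it to [0, 255].
    if mask.dtype != torch.bool:
        mask = mask > 0.5
    mask_np = (mask.float().cpu().numpy() * 255).astype(np.uint8)
    # Mode "L" = single-channel 8-bit grayscale.
    return Image.fromarray(mask_np, mode="L")


# A 4x4 mask with a filled 2x2 block.
mask = torch.zeros(4, 4)
mask[1:3, 1:3] = 1.0
mask_tensor_to_pil(mask).save("mask.png")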

Changed file 2 of 2:

@@ -2,7 +2,6 @@ from pathlib import Path
 from typing import Literal
 
 import numpy as np
-import numpy.typing as npt
 import torch
 from PIL import Image
 from transformers import AutoModelForMaskGeneration, AutoProcessor
@@ -10,8 +9,8 @@ from transformers.models.sam import SamModel
 from transformers.models.sam.processing_sam import SamProcessor
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
-from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
+from invokeai.app.invocations.primitives import MaskOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
 from invokeai.backend.image_util.segment_anything.segment_anything_model import SegmentAnythingModel
@@ -46,24 +45,22 @@ class SegmentAnythingModelInvocation(BaseInvocation):
     )
 
     @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> MaskOutput:
         # The models expect a 3-channel RGB image.
         image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
 
         if len(self.bounding_boxes) == 0:
-            combined_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8)
+            combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
         else:
             masks = self._segment(context=context, image=image_pil)
             masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
 
-            # masks contains binary values of 0 or 1, so we merge them via max-reduce.
-            combined_mask = np.maximum.reduce(masks)
+            # masks contains bool values, so we merge them via max-reduce.
+            combined_mask, _ = torch.stack(masks).max(dim=0)
 
-        # Map [0, 1] to [0, 255].
-        mask_np = combined_mask * 255
-        mask_pil = Image.fromarray(mask_np)
-        image_dto = context.images.save(image=mask_pil)
-        return ImageOutput.build(image_dto)
+        mask_tensor_name = context.tensors.save(combined_mask)
+        height, width = combined_mask.shape
+        return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)
 
     @staticmethod
     def _load_sam_model(model_path: Path):
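
Because the masks are boolean, the stack-and-max merge in invoke is an elementwise logical OR across masks. A minimal sketch of the pattern, with assumed [H, W] boolean inputs:

import torch

# Three [H, W] boolean masks, standing in for the output of _filter_masks.
masks = [torch.rand(8, 8) > 0.5 for _ in range(3)]

# Stack to [num_masks, H, W], then max-reduce over the mask dimension.
combined_mask, _ = torch.stack(masks).max(dim=0)

# For boolean tensors, the max-reduce is equivalent to a logical OR.
assert torch.equal(combined_mask, masks[0] | masks[1] | masks[2])
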
@@ -84,7 +81,7 @@ class SegmentAnythingModelInvocation(BaseInvocation):
         self,
         context: InvocationContext,
         image: Image.Image,
-    ) -> list[npt.NDArray[np.uint8]]:
+    ) -> list[torch.Tensor]:
         """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
         # Convert the bounding boxes to the SAM input format.
         sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
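
For context, these flat [x_min, y_min, x_max, y_max] lists match the box format of the Hugging Face SAM processor, nested once per image. A hedged sketch of the underlying transformers call, bypassing InvokeAI's SegmentAnythingModel wrapper (the checkpoint ID and blank test image are illustrative):

import torch
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor

processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-base")

image = Image.new("RGB", (512, 512))
# One image, one box: [x_min, y_min, x_max, y_max].
input_boxes = [[[100.0, 100.0, 400.0, 400.0]]]

inputs = processor(images=image, input_boxes=input_boxes, return_tensors="pt")
with torch.no_grad():
    outputs = model(pixel_values=inputs["pixel_values"], input_boxes=inputs["input_boxes"])

# Upscale and binarize the low-res predictions back to the original image size.
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
)
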
@@ -97,22 +94,23 @@ class SegmentAnythingModelInvocation(BaseInvocation):
         assert isinstance(sam_pipeline, SegmentAnythingModel)
         masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
 
-        masks = self._to_numpy_masks(masks)
+        masks = self._process_masks(masks)
         if self.apply_polygon_refinement:
             masks = self._apply_polygon_refinement(masks)
 
         return masks
 
-    def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
-        """Convert the tensor output from the Segment Anything model to a list of numpy masks."""
-        eps = 0.0001
+    def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]:
+        """Convert the tensor output from the Segment Anything model from a tensor of shape
+        [num_masks, channels, height, width] to a list of tensors of shape [height, width].
+        """
+        assert masks.dtype == torch.bool
         # [num_masks, channels, height, width] -> [num_masks, height, width]
-        masks = masks.permute(0, 2, 3, 1).float().mean(dim=-1)
-        masks = masks > eps
-        np_masks = masks.cpu().numpy().astype(np.uint8)
-        return list(np_masks)
+        masks, _ = masks.max(dim=1)
+        # Split the first dimension into a list of masks.
+        return list(masks.cpu().unbind(dim=0))
 
-    def _apply_polygon_refinement(self, masks: list[npt.NDArray[np.uint8]]) -> list[npt.NDArray[np.uint8]]:
+    def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
         """Apply polygon refinement to the masks.
 
         Convert each mask to a polygon, then back to a mask. This has the following effect:
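
The shape handling in _process_masks is easy to sanity-check in isolation: a max over the channel dimension collapses [num_masks, channels, height, width] to [num_masks, height, width], and unbind splits the result into per-mask tensors. A quick sketch with dummy data:

import torch

# Dummy SAM output: 2 masks, 3 candidate channels, 4x4 pixels.
masks = torch.rand(2, 3, 4, 4) > 0.5

# [num_masks, channels, height, width] -> [num_masks, height, width]
merged, _ = masks.max(dim=1)

# Split the first dimension into a list of [height, width] masks.
mask_list = list(merged.cpu().unbind(dim=0))
assert len(mask_list) == 2 and mask_list[0].shape == (4, 4)
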
@@ -121,18 +119,23 @@ class SegmentAnythingModelInvocation(BaseInvocation):
         - Removes small mask pieces.
         - Removes holes from the mask.
         """
 
-        for idx, mask in enumerate(masks):
+        # Convert tensor masks to np masks.
+        np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]
+
+        # Apply polygon refinement.
+        for idx, mask in enumerate(np_masks):
             shape = mask.shape
             assert len(shape) == 2  # Assert length to satisfy type checker.
             polygon = mask_to_polygon(mask)
             mask = polygon_to_mask(polygon, shape)
-            masks[idx] = mask
+            np_masks[idx] = mask
+
+        # Convert np masks back to tensor masks.
+        masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]
 
         return masks
 
-    def _filter_masks(
-        self, masks: list[npt.NDArray[np.uint8]], bounding_boxes: list[BoundingBoxField]
-    ) -> list[npt.NDArray[np.uint8]]:
+    def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
         """Filter the detected masks based on the specified mask filter."""
         assert len(masks) == len(bounding_boxes)
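
The diff does not show mask_to_polygon and polygon_to_mask themselves. As an illustration of what such a round-trip typically does, here is one common OpenCV-based sketch; this is an assumption about the technique, not InvokeAI's mask_refinement implementation (uint8 binary masks in, uint8 binary masks out):

import cv2
import numpy as np


def refine_mask(mask: np.ndarray) -> np.ndarray:
    # Trace the external contours, then refill only the largest one. Filling the
    # outer contour drops small stray pieces and closes interior holes.
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return np.zeros_like(mask)
    largest = max(contours, key=cv2.contourArea)
    refined = np.zeros_like(mask)
    cv2.fillPoly(refined, [largest], 1)
    return refined


mask = np.zeros((64, 64), dtype=np.uint8)
mask[8:40, 8:40] = 1    # main region
mask[16:24, 16:24] = 0  # a hole
mask[50:52, 50:52] = 1  # a small stray piece
refined = refine_mask(mask)
assert refined[20, 20] == 1 and refined[51, 51] == 0
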
@@ -140,7 +143,7 @@ class SegmentAnythingModelInvocation(BaseInvocation):
             return masks
         elif self.mask_filter == "largest":
             # Find the largest mask.
-            return [max(masks, key=lambda x: x.sum())]
+            return [max(masks, key=lambda x: float(x.sum()))]
         elif self.mask_filter == "highest_box_score":
             # Find the index of the bounding box with the highest score.
             # Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
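
In the "largest" branch, x.sum() on a boolean mask counts its True pixels but returns a 0-dim tensor; wrapping it in float() gives max a plain scalar key, which is what satisfies the type checker. A minimal sketch:

import torch

masks = [torch.rand(8, 8) > 0.9, torch.rand(8, 8) > 0.1]

# Pick the mask covering the most pixels; float() unwraps the 0-dim sum tensor.
largest = max(masks, key=lambda x: float(x.sum()))
assert float(largest.sum()) == max(float(m.sum()) for m in masks)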