Return a MaskOutput from SegmentAnythingModelInvocation. And add a MaskTensorToImageInvocation.

Ryan Dick 2024-07-31 17:15:48 -04:00
parent fca119773b
commit b5832768dc
2 changed files with 59 additions and 30 deletions
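In graph terms, the change means the SAM node now hands downstream nodes a named bool tensor instead of a rendered image. A minimal sketch of that handoff outside the invocation machinery, using a plain dict as a toy stand-in for InvokeAI's `context.tensors` store:

    import torch
    from PIL import Image

    tensor_store: dict[str, torch.Tensor] = {}  # toy stand-in for context.tensors

    # SegmentAnythingModelInvocation side: save a [height, width] bool mask
    # and pass its name downstream inside a MaskOutput.
    combined_mask = torch.zeros((64, 64), dtype=torch.bool)
    combined_mask[16:48, 16:48] = True
    tensor_store["mask_1"] = combined_mask  # context.tensors.save(...) in the real node

    # MaskTensorToImageInvocation side: load by name, render as an 8-bit image.
    mask = tensor_store["mask_1"]  # context.tensors.load(...) in the real node
    mask_np = (mask.float().cpu().numpy() * 255).astype("uint8")
    mask_pil = Image.fromarray(mask_np, mode="L")
    assert mask_pil.size == (64, 64)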

View File

@@ -1,9 +1,10 @@
 import numpy as np
 import torch
+from PIL import Image
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
-from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata
-from invokeai.app.invocations.primitives import MaskOutput
+from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata
+from invokeai.app.invocations.primitives import ImageOutput, MaskOutput
 
 
 @invocation(
@@ -118,3 +119,28 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
             height=mask.shape[1],
             width=mask.shape[2],
         )
+
+
+@invocation(
+    "tensor_mask_to_image",
+    title="Tensor Mask to Image",
+    tags=["mask"],
+    category="mask",
+    version="1.0.0",
+)
+class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
+    """Convert a mask tensor to an image."""
+
+    mask: TensorField = InputField(description="The mask tensor to convert.")
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        mask = context.tensors.load(self.mask.tensor_name)
+        # Ensure that the mask is binary.
+        if mask.dtype != torch.bool:
+            mask = mask > 0.5
+        mask_np = mask.float().cpu().detach().numpy() * 255
+        mask_np = mask_np.astype(np.uint8)
+
+        mask_pil = Image.fromarray(mask_np, mode="L")
+        image_dto = context.images.save(image=mask_pil)
+        return ImageOutput.build(image_dto)
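The `mask > 0.5` branch exists because upstream nodes may emit soft (float) masks rather than `torch.bool` tensors. A standalone check of the binarize-then-scale path used above, with made-up values:

    import numpy as np
    import torch
    from PIL import Image

    soft = torch.tensor([[0.1, 0.9], [0.6, 0.4]])  # hypothetical soft mask in [0, 1]
    mask = soft > 0.5  # binarize, as in the dtype check above

    mask_np = mask.float().cpu().detach().numpy() * 255
    mask_np = mask_np.astype(np.uint8)
    mask_pil = Image.fromarray(mask_np, mode="L")
    assert list(mask_pil.getdata()) == [0, 255, 255, 0]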

View File

@@ -2,7 +2,6 @@ from pathlib import Path
 from typing import Literal
 
 import numpy as np
-import numpy.typing as npt
 import torch
 from PIL import Image
 from transformers import AutoModelForMaskGeneration, AutoProcessor
@ -10,8 +9,8 @@ from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
from invokeai.app.invocations.primitives import MaskOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
from invokeai.backend.image_util.segment_anything.segment_anything_model import SegmentAnythingModel
@@ -46,24 +45,22 @@ class SegmentAnythingModelInvocation(BaseInvocation):
     )
 
     @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> MaskOutput:
         # The models expect a 3-channel RGB image.
         image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
 
         if len(self.bounding_boxes) == 0:
-            combined_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8)
+            combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
         else:
             masks = self._segment(context=context, image=image_pil)
             masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
 
-            # masks contains binary values of 0 or 1, so we merge them via max-reduce.
-            combined_mask = np.maximum.reduce(masks)
-
-        # Map [0, 1] to [0, 255].
-        mask_np = combined_mask * 255
-        mask_pil = Image.fromarray(mask_np)
+            # masks contains bool values, so we merge them via max-reduce.
+            combined_mask, _ = torch.stack(masks).max(dim=0)
 
-        image_dto = context.images.save(image=mask_pil)
-        return ImageOutput.build(image_dto)
+        mask_tensor_name = context.tensors.save(combined_mask)
+        height, width = combined_mask.shape
+        return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)
 
     @staticmethod
     def _load_sam_model(model_path: Path):
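Both reductions in this refactor are max-reduces over bool tensors, where max is simply logical OR. A shape-only sketch (toy sizes, outside InvokeAI) of the merge above and of the channel collapse that `_process_masks` performs further down:

    import torch

    # Simulated SAM output for 2 boxes: [num_masks, channels, height, width], bool.
    sam_out = torch.rand(2, 3, 4, 4) > 0.5

    # Channel collapse, as in _process_masks:
    # [num_masks, channels, H, W] -> [num_masks, H, W].
    per_mask, _ = sam_out.max(dim=1)
    masks = list(per_mask.unbind(dim=0))  # list of [H, W] bool tensors

    # Cross-mask merge, as in invoke: a pixel is set if it is set in any mask.
    combined, _ = torch.stack(masks).max(dim=0)
    assert combined.shape == (4, 4) and combined.dtype == torch.bool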
@@ -84,7 +81,7 @@ class SegmentAnythingModelInvocation(BaseInvocation):
         self,
         context: InvocationContext,
         image: Image.Image,
-    ) -> list[npt.NDArray[np.uint8]]:
+    ) -> list[torch.Tensor]:
         """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
         # Convert the bounding boxes to the SAM input format.
         sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
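`SegmentAnythingModel.segment` is InvokeAI's wrapper, but the xyxy box format mirrors what the underlying Hugging Face SAM pipeline consumes. A sketch of the raw `transformers` calls beneath the wrapper; the checkpoint name and box values are illustrative assumptions, not taken from this diff:

    import torch
    from PIL import Image
    from transformers import AutoModelForMaskGeneration, AutoProcessor

    processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")  # assumed checkpoint
    model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-base")

    image = Image.new("RGB", (256, 256))
    input_boxes = [[[64.0, 64.0, 192.0, 192.0]]]  # [batch, num_boxes, 4], xyxy

    inputs = processor(images=image, input_boxes=input_boxes, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-processing upscales and binarizes, yielding one bool tensor per image,
    # shaped [num_boxes, channels, height, width] -- the shape _process_masks expects.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"], inputs["reshaped_input_sizes"]
    )[0]
    assert masks.dtype == torch.bool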
@@ -97,22 +94,23 @@ class SegmentAnythingModelInvocation(BaseInvocation):
         assert isinstance(sam_pipeline, SegmentAnythingModel)
         masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
 
-        masks = self._to_numpy_masks(masks)
+        masks = self._process_masks(masks)
         if self.apply_polygon_refinement:
             masks = self._apply_polygon_refinement(masks)
 
         return masks
 
-    def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
-        """Convert the tensor output from the Segment Anything model to a list of numpy masks."""
-        eps = 0.0001
+    def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]:
+        """Convert the tensor output from the Segment Anything model from a tensor of shape
+        [num_masks, channels, height, width] to a list of tensors of shape [height, width].
+        """
+        assert masks.dtype == torch.bool
         # [num_masks, channels, height, width] -> [num_masks, height, width]
-        masks = masks.permute(0, 2, 3, 1).float().mean(dim=-1)
-        masks = masks > eps
-        np_masks = masks.cpu().numpy().astype(np.uint8)
-        return list(np_masks)
+        masks, _ = masks.max(dim=1)
+        # Split the first dimension into a list of masks.
+        return list(masks.cpu().unbind(dim=0))
 
-    def _apply_polygon_refinement(self, masks: list[npt.NDArray[np.uint8]]) -> list[npt.NDArray[np.uint8]]:
+    def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
         """Apply polygon refinement to the masks.
 
         Convert each mask to a polygon, then back to a mask. This has the following effect:
@@ -121,18 +119,23 @@ class SegmentAnythingModelInvocation(BaseInvocation):
         - Removes small mask pieces.
         - Removes holes from the mask.
         """
-        for idx, mask in enumerate(masks):
+        # Convert tensor masks to np masks.
+        np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]
+
+        # Apply polygon refinement.
+        for idx, mask in enumerate(np_masks):
             shape = mask.shape
             assert len(shape) == 2  # Assert length to satisfy type checker.
             polygon = mask_to_polygon(mask)
             mask = polygon_to_mask(polygon, shape)
-            masks[idx] = mask
+            np_masks[idx] = mask
+
+        # Convert np masks back to tensor masks.
+        masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]
 
         return masks
 
-    def _filter_masks(
-        self, masks: list[npt.NDArray[np.uint8]], bounding_boxes: list[BoundingBoxField]
-    ) -> list[npt.NDArray[np.uint8]]:
+    def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
         """Filter the detected masks based on the specified mask filter."""
         assert len(masks) == len(bounding_boxes)
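The polygon helpers (`mask_to_polygon` / `polygon_to_mask`) keep their numpy interface and expect uint8 arrays, so `_apply_polygon_refinement` brackets them with dtype round-trips. The round trip in isolation, with the refinement step elided:

    import numpy as np
    import torch

    mask = torch.rand(8, 8) > 0.5  # bool tensor mask

    # To numpy uint8 (values 0/1) for the polygon helpers.
    np_mask = mask.cpu().numpy().astype(np.uint8)

    # ... mask_to_polygon / polygon_to_mask would run here ...

    # Back to a bool tensor for the rest of the pipeline.
    refined = torch.tensor(np_mask, dtype=torch.bool)
    assert torch.equal(refined, mask)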
@@ -140,7 +143,7 @@ class SegmentAnythingModelInvocation(BaseInvocation):
             return masks
         elif self.mask_filter == "largest":
             # Find the largest mask.
-            return [max(masks, key=lambda x: x.sum())]
+            return [max(masks, key=lambda x: float(x.sum()))]
         elif self.mask_filter == "highest_box_score":
             # Find the index of the bounding box with the highest score.
             # Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
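On the `float(x.sum())` change: for a bool tensor, `sum()` returns a 0-dim integer tensor, and wrapping it in `float()` gives `max()` a plain Python sort key (which is also what satisfies the type checker). A toy check:

    import torch

    masks = [torch.zeros(8, 8, dtype=torch.bool) for _ in range(3)]
    masks[1][2:6, 2:6] = True  # 16 True pixels
    masks[2][0:2, 0:2] = True  # 4 True pixels

    largest = max(masks, key=lambda x: float(x.sum()))
    assert largest is masks[1]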