Return a MaskOutput from SegmentAnythingModelInvocation. And add a MaskTensorToImageInvocation.

commit b5832768dc (parent fca119773b)
@@ -1,9 +1,10 @@
 import numpy as np
 import torch
+from PIL import Image
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
-from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata
+from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata
-from invokeai.app.invocations.primitives import MaskOutput
+from invokeai.app.invocations.primitives import ImageOutput, MaskOutput
 
 
 @invocation(
@@ -118,3 +119,28 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
             height=mask.shape[1],
             width=mask.shape[2],
         )
+
+
+@invocation(
+    "tensor_mask_to_image",
+    title="Tensor Mask to Image",
+    tags=["mask"],
+    category="mask",
+    version="1.0.0",
+)
+class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
+    """Convert a mask tensor to an image."""
+
+    mask: TensorField = InputField(description="The mask tensor to convert.")
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        mask = context.tensors.load(self.mask.tensor_name)
+        # Ensure that the mask is binary.
+        if mask.dtype != torch.bool:
+            mask = mask > 0.5
+        mask_np = mask.float().cpu().detach().numpy() * 255
+        mask_np = mask_np.astype(np.uint8)
+
+        mask_pil = Image.fromarray(mask_np, mode="L")
+        image_dto = context.images.save(image=mask_pil)
+        return ImageOutput.build(image_dto)
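Outside the invocation plumbing, the tensor-to-image conversion above boils down to a few lines of PyTorch and PIL. A minimal standalone sketch (the random mask is a made-up stand-in for a tensor loaded via context.tensors.load):

    import numpy as np
    import torch
    from PIL import Image

    # Hypothetical stand-in for a mask tensor from the tensor store.
    mask = torch.rand(512, 512)

    # Binarize non-bool masks at 0.5, as the invocation does.
    if mask.dtype != torch.bool:
        mask = mask > 0.5

    # Map {False, True} to {0, 255} and build a single-channel ("L") image.
    mask_np = (mask.float().cpu().numpy() * 255).astype(np.uint8)
    Image.fromarray(mask_np, mode="L").save("mask.png")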
@@ -2,7 +2,6 @@ from pathlib import Path
 from typing import Literal
 
 import numpy as np
-import numpy.typing as npt
 import torch
 from PIL import Image
 from transformers import AutoModelForMaskGeneration, AutoProcessor
@@ -10,8 +9,8 @@ from transformers.models.sam import SamModel
 from transformers.models.sam.processing_sam import SamProcessor
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
+from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
-from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.invocations.primitives import MaskOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
 from invokeai.backend.image_util.segment_anything.segment_anything_model import SegmentAnythingModel
@@ -46,24 +45,22 @@ class SegmentAnythingModelInvocation(BaseInvocation):
     )
 
     @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> MaskOutput:
         # The models expect a 3-channel RGB image.
         image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
 
         if len(self.bounding_boxes) == 0:
-            combined_mask = np.zeros(image_pil.size[::-1], dtype=np.uint8)
+            combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
         else:
             masks = self._segment(context=context, image=image_pil)
             masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
-            # masks contains binary values of 0 or 1, so we merge them via max-reduce.
-            combined_mask = np.maximum.reduce(masks)
 
-        # Map [0, 1] to [0, 255].
-        mask_np = combined_mask * 255
-        mask_pil = Image.fromarray(mask_np)
+            # masks contains bool values, so we merge them via max-reduce.
+            combined_mask, _ = torch.stack(masks).max(dim=0)
 
-        image_dto = context.images.save(image=mask_pil)
-        return ImageOutput.build(image_dto)
+        mask_tensor_name = context.tensors.save(combined_mask)
+        height, width = combined_mask.shape
+        return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)
 
     @staticmethod
     def _load_sam_model(model_path: Path):
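The new merge step relies on max over a stacked boolean tensor acting as a logical OR. A toy sketch of just that reduction (not InvokeAI code):

    import torch

    # Two toy [H, W] boolean masks.
    mask_a = torch.tensor([[True, False], [False, False]])
    mask_b = torch.tensor([[False, False], [False, True]])

    # Stack to [num_masks, H, W]; max over dim 0 ORs the masks together.
    combined, _ = torch.stack([mask_a, mask_b]).max(dim=0)
    print(combined)  # ORs to [[True, False], [False, True]]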
@@ -84,7 +81,7 @@ class SegmentAnythingModelInvocation(BaseInvocation):
         self,
         context: InvocationContext,
         image: Image.Image,
-    ) -> list[npt.NDArray[np.uint8]]:
+    ) -> list[torch.Tensor]:
         """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
         # Convert the bounding boxes to the SAM input format.
         sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
@@ -97,22 +94,23 @@ class SegmentAnythingModelInvocation(BaseInvocation):
         assert isinstance(sam_pipeline, SegmentAnythingModel)
         masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
 
-        masks = self._to_numpy_masks(masks)
+        masks = self._process_masks(masks)
         if self.apply_polygon_refinement:
             masks = self._apply_polygon_refinement(masks)
 
         return masks
 
-    def _to_numpy_masks(self, masks: torch.Tensor) -> list[npt.NDArray[np.uint8]]:
-        """Convert the tensor output from the Segment Anything model to a list of numpy masks."""
-        eps = 0.0001
+    def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]:
+        """Convert the tensor output from the Segment Anything model from a tensor of shape
+        [num_masks, channels, height, width] to a list of tensors of shape [height, width].
+        """
+        assert masks.dtype == torch.bool
         # [num_masks, channels, height, width] -> [num_masks, height, width]
-        masks = masks.permute(0, 2, 3, 1).float().mean(dim=-1)
-        masks = masks > eps
-        np_masks = masks.cpu().numpy().astype(np.uint8)
-        return list(np_masks)
+        masks, _ = masks.max(dim=1)
+        # Split the first dimension into a list of masks.
+        return list(masks.cpu().unbind(dim=0))
 
-    def _apply_polygon_refinement(self, masks: list[npt.NDArray[np.uint8]]) -> list[npt.NDArray[np.uint8]]:
+    def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
         """Apply polygon refinement to the masks.
 
         Convert each mask to a polygon, then back to a mask. This has the following effect:
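The shape handling in _process_masks, in isolation: collapse the channel axis with a boolean max, then split the batch axis into a Python list. A toy sketch with made-up dimensions:

    import torch

    # Toy stand-in for SAM output: [num_masks=3, channels=1, H=4, W=4], bool.
    masks = torch.zeros(3, 1, 4, 4, dtype=torch.bool)
    masks[0, 0, :2, :2] = True

    # [num_masks, channels, H, W] -> [num_masks, H, W]; max ORs the channels.
    masks, _ = masks.max(dim=1)

    # unbind(dim=0) yields one [H, W] tensor per mask.
    mask_list = list(masks.cpu().unbind(dim=0))
    print(len(mask_list), mask_list[0].shape)  # 3 torch.Size([4, 4])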
@@ -121,18 +119,23 @@ class SegmentAnythingModelInvocation(BaseInvocation):
         - Removes small mask pieces.
         - Removes holes from the mask.
         """
-        for idx, mask in enumerate(masks):
+        # Convert tensor masks to np masks.
+        np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]
+
+        # Apply polygon refinement.
+        for idx, mask in enumerate(np_masks):
             shape = mask.shape
             assert len(shape) == 2  # Assert length to satisfy type checker.
             polygon = mask_to_polygon(mask)
             mask = polygon_to_mask(polygon, shape)
-            masks[idx] = mask
+            np_masks[idx] = mask
 
+        # Convert np masks back to tensor masks.
+        masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]
+
         return masks
 
-    def _filter_masks(
-        self, masks: list[npt.NDArray[np.uint8]], bounding_boxes: list[BoundingBoxField]
-    ) -> list[npt.NDArray[np.uint8]]:
+    def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
         """Filter the detected masks based on the specified mask filter."""
         assert len(masks) == len(bounding_boxes)
 
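Because mask_to_polygon and polygon_to_mask operate on uint8 numpy arrays, the refined method round-trips each tensor through numpy. The conversion pattern by itself, with the polygon step elided (toy data, not InvokeAI code):

    import numpy as np
    import torch

    masks = [torch.rand(4, 4) > 0.5]  # toy boolean masks

    # torch.bool -> uint8 numpy, the format the refinement helpers expect.
    np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]

    # ... mask_to_polygon / polygon_to_mask would rewrite np_masks here ...

    # uint8 numpy -> torch.bool, restoring the pipeline's mask dtype.
    masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]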
@@ -140,7 +143,7 @@ class SegmentAnythingModelInvocation(BaseInvocation):
             return masks
         elif self.mask_filter == "largest":
             # Find the largest mask.
-            return [max(masks, key=lambda x: x.sum())]
+            return [max(masks, key=lambda x: float(x.sum()))]
         elif self.mask_filter == "highest_box_score":
             # Find the index of the bounding box with the highest score.
             # Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
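The float(...) wrapper in the "largest" branch is likely there because Tensor.sum() returns a 0-dim tensor rather than a plain number; unwrapping it gives max() an ordinary float key (and keeps the type checker happy). A quick check:

    import torch

    masks = [torch.zeros(4, 4, dtype=torch.bool), torch.ones(4, 4, dtype=torch.bool)]

    # sum() on a bool tensor counts True pixels as a 0-dim int64 tensor;
    # float() unwraps it so max() compares ordinary numbers.
    largest = max(masks, key=lambda x: float(x.sum()))
    print(float(largest.sum()))  # 16.0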