InvokeAI/invokeai/app/invocations/segment_anything.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

77 lines
3.4 KiB
Python
Raw Normal View History

2024-07-25 19:13:26 +00:00
from typing import Dict, cast
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.grounding_segment_anything.gsa import GroundingSegmentAnythingDetector
from invokeai.backend.util.devices import TorchDevice
GROUNDING_SEGMENT_ANYTHING_MODELS = {
"groundingdino_swint_ogc": "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
"segment_anything_vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
}
@invocation(
"segment_anything",
title="Segment Anything",
tags=["grounding_dino", "segment", "anything"],
category="image",
version="1.0.0",
)
class SegmentAnythingInvocation(BaseInvocation):
"""Automatically generate masks from an image using GroundingDINO & Segment Anything"""
image: ImageField = InputField(description="The image to process")
prompt: str = InputField(default="", description="Keywords to segment", title="Prompt")
box_threshold: float = InputField(
default=0.5, ge=0, le=1, description="Threshold of box detection", title="Box Threshold"
)
text_threshold: float = InputField(
default=0.5, ge=0, le=1, description="Threshold of text detection", title="Text Threshold"
)
nms_threshold: float = InputField(
default=0.8, ge=0, le=1, description="Threshold of nms detection", title="NMS Threshold"
)
def invoke(self, context: InvocationContext) -> ImageOutput:
input_image = context.images.get_pil(self.image.image_name)
grounding_dino_model = context.models.load_remote_model(
GROUNDING_SEGMENT_ANYTHING_MODELS["groundingdino_swint_ogc"]
)
segment_anything_model = context.models.load_remote_model(
GROUNDING_SEGMENT_ANYTHING_MODELS["segment_anything_vit_h"]
)
with (
grounding_dino_model.model_on_device() as (_, grounding_dino_state_dict),
segment_anything_model.model_on_device() as (_, segment_anything_state_dict),
):
if not grounding_dino_state_dict or not segment_anything_state_dict:
raise RuntimeError("Unable to load segmentation models")
grounding_dino = GroundingSegmentAnythingDetector.build_grounding_dino(
cast(Dict[str, torch.Tensor], grounding_dino_state_dict), TorchDevice.choose_torch_device()
2024-07-25 19:13:26 +00:00
)
segment_anything = GroundingSegmentAnythingDetector.build_segment_anything(
cast(Dict[str, torch.Tensor], segment_anything_state_dict), TorchDevice.choose_torch_device()
)
detector = GroundingSegmentAnythingDetector(grounding_dino, segment_anything)
mask = detector.predict(
input_image, self.prompt, self.box_threshold, self.text_threshold, self.nms_threshold
)
image_dto = context.images.save(mask)
"""Builds an ImageOutput and its ImageField"""
processed_image_field = ImageField(image_name=image_dto.image_name)
return ImageOutput(
image=processed_image_field,
width=input_image.width,
height=input_image.height,
)