mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
77 lines
3.4 KiB
Python
77 lines
3.4 KiB
Python
from typing import Dict, cast
|
|
|
|
import torch
|
|
|
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
|
from invokeai.app.invocations.fields import ImageField, InputField
|
|
from invokeai.app.invocations.primitives import ImageOutput
|
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
|
from invokeai.backend.image_util.grounding_segment_anything.gsa import GroundingSegmentAnythingDetector
|
|
from invokeai.backend.util.devices import TorchDevice
|
|
|
|
GROUNDING_SEGMENT_ANYTHING_MODELS = {
|
|
"groundingdino_swint_ogc": "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
|
|
"segment_anything_vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
|
|
}
|
|
|
|
|
|
@invocation(
|
|
"segment_anything",
|
|
title="Segment Anything",
|
|
tags=["grounding_dino", "segment", "anything"],
|
|
category="image",
|
|
version="1.0.0",
|
|
)
|
|
class SegmentAnythingInvocation(BaseInvocation):
|
|
"""Automatically generate masks from an image using GroundingDINO & Segment Anything"""
|
|
|
|
image: ImageField = InputField(description="The image to process")
|
|
prompt: str = InputField(default="", description="Keywords to segment", title="Prompt")
|
|
box_threshold: float = InputField(
|
|
default=0.5, ge=0, le=1, description="Threshold of box detection", title="Box Threshold"
|
|
)
|
|
text_threshold: float = InputField(
|
|
default=0.5, ge=0, le=1, description="Threshold of text detection", title="Text Threshold"
|
|
)
|
|
nms_threshold: float = InputField(
|
|
default=0.8, ge=0, le=1, description="Threshold of nms detection", title="NMS Threshold"
|
|
)
|
|
|
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
input_image = context.images.get_pil(self.image.image_name)
|
|
|
|
grounding_dino_model = context.models.load_remote_model(
|
|
GROUNDING_SEGMENT_ANYTHING_MODELS["groundingdino_swint_ogc"]
|
|
)
|
|
segment_anything_model = context.models.load_remote_model(
|
|
GROUNDING_SEGMENT_ANYTHING_MODELS["segment_anything_vit_h"]
|
|
)
|
|
|
|
with (
|
|
grounding_dino_model.model_on_device() as (_, grounding_dino_state_dict),
|
|
segment_anything_model.model_on_device() as (_, segment_anything_state_dict),
|
|
):
|
|
if not grounding_dino_state_dict or not segment_anything_state_dict:
|
|
raise RuntimeError("Unable to load segmentation models")
|
|
|
|
grounding_dino = GroundingSegmentAnythingDetector.build_grounding_dino(
|
|
cast(Dict[str, torch.Tensor], grounding_dino_state_dict), TorchDevice.choose_torch_device()
|
|
)
|
|
segment_anything = GroundingSegmentAnythingDetector.build_segment_anything(
|
|
cast(Dict[str, torch.Tensor], segment_anything_state_dict), TorchDevice.choose_torch_device()
|
|
)
|
|
detector = GroundingSegmentAnythingDetector(grounding_dino, segment_anything)
|
|
|
|
mask = detector.predict(
|
|
input_image, self.prompt, self.box_threshold, self.text_threshold, self.nms_threshold
|
|
)
|
|
image_dto = context.images.save(mask)
|
|
|
|
"""Builds an ImageOutput and its ImageField"""
|
|
processed_image_field = ImageField(image_name=image_dto.image_name)
|
|
return ImageOutput(
|
|
image=processed_image_field,
|
|
width=input_image.width,
|
|
height=input_image.height,
|
|
)
|