Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Move some logic from GroundedSAMInvocation to the backend classes.
parent aca2a2fa13
commit 918f77bce0
@@ -1,6 +1,5 @@
-from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Literal, Optional
+from typing import Literal
 
 import numpy as np
 import numpy.typing as npt
@@ -15,6 +14,7 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.fields import ImageField, InputField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.grounded_sam.detection_result import DetectionResult
 from invokeai.backend.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline
 from invokeai.backend.grounded_sam.mask_refinement import mask_to_polygon, polygon_to_mask
 from invokeai.backend.grounded_sam.segment_anything_model import SegmentAnythingModel
@@ -23,43 +23,6 @@ GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
 SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base"
 
 
-@dataclass
-class BoundingBox:
-    """Bounding box helper class used locally for the Grounding DINO outputs."""
-
-    xmin: int
-    ymin: int
-    xmax: int
-    ymax: int
-
-    def to_box(self) -> list[int]:
-        """Convert to the array notation expected by SAM."""
-        return [self.xmin, self.ymin, self.xmax, self.ymax]
-
-
-@dataclass
-class DetectionResult:
-    """Detection result from Grounding DINO or Grounded SAM."""
-
-    score: float
-    label: str
-    box: BoundingBox
-    mask: Optional[npt.NDArray[Any]] = None
-
-    @classmethod
-    def from_dict(cls, detection_dict: dict[str, Any]):
-        return cls(
-            score=detection_dict["score"],
-            label=detection_dict["label"],
-            box=BoundingBox(
-                xmin=detection_dict["box"]["xmin"],
-                ymin=detection_dict["box"]["ymin"],
-                xmax=detection_dict["box"]["xmax"],
-                ymax=detection_dict["box"]["ymax"],
-            ),
-        )
-
-
 @invocation(
     "grounded_segment_anything",
     title="Segment Anything (Text Prompt)",
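For reference, the BoundingBox and DetectionResult dataclasses deleted above were not dropped; they moved into the new backend module imported earlier in this diff. A minimal sketch of invokeai/backend/grounded_sam/detection_result.py, assuming the class bodies moved over unchanged:

    # invokeai/backend/grounded_sam/detection_result.py
    # Sketch only: assumes the dataclasses were moved verbatim from the invocation file.
    from dataclasses import dataclass
    from typing import Any, Optional

    import numpy.typing as npt


    @dataclass
    class BoundingBox:
        """Bounding box helper class used locally for the Grounding DINO outputs."""

        xmin: int
        ymin: int
        xmax: int
        ymax: int

        def to_box(self) -> list[int]:
            """Convert to the array notation expected by SAM."""
            return [self.xmin, self.ymin, self.xmax, self.ymax]


    @dataclass
    class DetectionResult:
        """Detection result from Grounding DINO or Grounded SAM."""

        score: float
        label: str
        box: BoundingBox
        mask: Optional[npt.NDArray[Any]] = None

        @classmethod
        def from_dict(cls, detection_dict: dict[str, Any]) -> "DetectionResult":
            # Build a typed result from the raw dict emitted by the transformers
            # zero-shot object detection pipeline.
            return cls(
                score=detection_dict["score"],
                label=detection_dict["label"],
                box=BoundingBox(
                    xmin=detection_dict["box"]["xmin"],
                    ymin=detection_dict["box"]["ymin"],
                    xmax=detection_dict["box"]["xmax"],
                    ymax=detection_dict["box"]["ymax"],
                ),
            )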
@@ -92,9 +55,10 @@ class GroundedSAMInvocation(BaseInvocation):
         description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be used.",
         ge=0.0,
         le=1.0,
-        default=0.1,
+        default=0.3,
     )
 
+    @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image_pil = context.images.get_pil(self.image.image_name)
 
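Two behavioral tweaks ride along in this hunk: the default detection threshold rises from 0.1 to 0.3, and invoke() is now wrapped in @torch.no_grad(), which turns off autograd bookkeeping for the whole call. Since nothing is trained here, skipping gradient tracking saves memory and a little compute. A tiny self-contained illustration of the decorator's effect (the run_inference name is hypothetical):

    import torch

    @torch.no_grad()  # every tensor op inside the call runs without gradient tracking
    def run_inference(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        y = model(x)
        assert not y.requires_grad  # holds even though the model's weights require grad
        return y

    y = run_inference(torch.nn.Linear(4, 2), torch.randn(1, 4))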
@@ -118,15 +82,6 @@ class GroundedSAMInvocation(BaseInvocation):
         image_dto = context.images.save(image=mask_pil)
         return ImageOutput.build(image_dto)
 
-    def _to_box_array(self, detection_results: list[DetectionResult]) -> list[list[list[int]]]:
-        """Convert a list of DetectionResults to the format expected by the Segment Anything model.
-
-        Args:
-            detection_results (list[DetectionResult]): The Grounding DINO detection results.
-        """
-        boxes = [result.box.to_box() for result in detection_results]
-        return [boxes]
-
     def _detect(
         self,
         context: InvocationContext,
@@ -153,10 +108,7 @@ class GroundedSAMInvocation(BaseInvocation):
 
         with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector:
             assert isinstance(detector, GroundingDinoPipeline)
-            results = detector.detect(image=image, candidate_labels=labels, threshold=threshold)
-
-        results = [DetectionResult.from_dict(result) for result in results]
-        return results
+            return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
 
     def _segment(
         self,
@@ -185,8 +137,7 @@ class GroundedSAMInvocation(BaseInvocation):
         ):
             assert isinstance(sam_pipeline, SegmentAnythingModel)
 
-            boxes = self._to_box_array(detection_results)
-            masks = sam_pipeline.segment(image=image, boxes=boxes)
+            masks = sam_pipeline.segment(image=image, detection_results=detection_results)
 
             masks = self._to_numpy_masks(masks)
             masks = self._apply_polygon_refinement(masks)
@@ -200,7 +151,7 @@ class GroundedSAMInvocation(BaseInvocation):
         """Convert the tensor output from the Segment Anything model to a list of numpy masks."""
         masks = masks.cpu().float()
         masks = masks.permute(0, 2, 3, 1)
-        masks = masks.mean(axis=-1)
+        masks = masks.mean(dim=-1)
         masks = (masks > 0).int()
         masks = masks.numpy().astype(np.uint8)
         masks = list(masks)
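The axis= to dim= change in _to_numpy_masks is a naming fix: dim is the documented argument of torch.Tensor.mean (recent PyTorch also accepts axis as a NumPy-compatibility alias, which is why the old spelling worked). A standalone sketch of the same postprocessing on a dummy batch, assuming SAM-style output of shape (batch, channels, height, width):

    import numpy as np
    import torch

    masks = torch.randn(2, 3, 4, 4)             # dummy (batch, channels, H, W) output
    masks = masks.cpu().float()
    masks = masks.permute(0, 2, 3, 1)           # -> (batch, H, W, channels)
    masks = masks.mean(dim=-1)                  # average over channels; `dim`, not NumPy's `axis`
    masks = (masks > 0).int()                   # threshold to a binary mask
    masks = list(masks.numpy().astype(np.uint8))
    print([m.shape for m in masks])             # two (4, 4) uint8 masks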
invokeai/backend/grounded_sam/grounding_dino_pipeline.py
@@ -4,6 +4,8 @@ import torch
 from PIL import Image
 from transformers.pipelines import ZeroShotObjectDetectionPipeline
 
+from invokeai.backend.grounded_sam.detection_result import DetectionResult
+
 
 class GroundingDinoPipeline:
     """A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory
@@ -13,8 +15,10 @@ class GroundingDinoPipeline:
     def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
         self._pipeline = pipeline
 
-    def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1):
-        return self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
+    def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1) -> list[DetectionResult]:
+        results = self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
+        results = [DetectionResult.from_dict(result) for result in results]
+        return results
 
     def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":
         self._pipeline.model.to(device=device, dtype=dtype)
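After this change, GroundingDinoPipeline.detect() returns typed DetectionResult objects rather than the raw dicts produced by the underlying pipeline, so the dict-to-dataclass conversion no longer leaks into the invocation. A hypothetical usage sketch (in InvokeAI the pipeline is built by the model manager's loader, not constructed inline like this; photo.png is a placeholder):

    from PIL import Image
    from transformers import pipeline

    # Illustrative construction only; the model id matches GROUNDING_DINO_MODEL_ID above.
    detector = GroundingDinoPipeline(
        pipeline("zero-shot-object-detection", model="IDEA-Research/grounding-dino-tiny")
    )
    image = Image.open("photo.png")
    for result in detector.detect(image=image, candidate_labels=["a cat"], threshold=0.3):
        # Each result is a DetectionResult, not a raw dict.
        print(result.label, result.score, result.box.to_box())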
invokeai/backend/grounded_sam/segment_anything_model.py
@@ -5,6 +5,8 @@ from PIL import Image
 from transformers.models.sam import SamModel
 from transformers.models.sam.processing_sam import SamProcessor
 
+from invokeai.backend.grounded_sam.detection_result import DetectionResult
+
 
 class SegmentAnythingModel:
     """A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
@@ -23,7 +25,8 @@ class SegmentAnythingModel:
 
         return calc_module_size(self._sam_model)
 
-    def segment(self, image: Image.Image, boxes: list[list[list[int]]]) -> torch.Tensor:
+    def segment(self, image: Image.Image, detection_results: list[DetectionResult]) -> torch.Tensor:
+        boxes = self._to_box_array(detection_results)
         inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device)
         outputs = self._sam_model(**inputs)
         masks = self._sam_processor.post_process_masks(
@@ -36,3 +39,8 @@ class SegmentAnythingModel:
         assert len(masks) == 1
         masks = masks[0]
         return masks
+
+    def _to_box_array(self, detection_results: list[DetectionResult]) -> list[list[list[int]]]:
+        """Convert a list of DetectionResults to the bbox format expected by the Segment Anything model."""
+        boxes = [result.box.to_box() for result in detection_results]
+        return [boxes]
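With _to_box_array() relocated here, SegmentAnythingModel.segment() owns the conversion from DetectionResults to SAM's nested input_boxes format (the outer [boxes] wraps the per-image box list in a batch dimension of one). End to end, the invocation's flow now reduces to roughly this hedged sketch (text_prompted_mask is a hypothetical helper; model loading is elided):

    import torch
    from PIL import Image

    def text_prompted_mask(
        image: Image.Image,
        prompt: str,
        detector: GroundingDinoPipeline,
        sam: SegmentAnythingModel,
        threshold: float = 0.3,
    ) -> torch.Tensor:
        # Grounding DINO: text prompt -> typed detections with bounding boxes.
        detections = detector.detect(image=image, candidate_labels=[prompt], threshold=threshold)
        # SAM: detections -> mask tensor; _to_box_array() now runs inside segment().
        return sam.segment(image=image, detection_results=detections)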