Make Grounding DINO DetectionResult a Pydantic model.

This commit is contained in:
Ryan Dick 2024-07-31 08:47:00 -04:00
parent cec7399366
commit 33e8604b57
2 changed files with 8 additions and 19 deletions

View File

@ -1,11 +1,10 @@
from dataclasses import dataclass
from typing import Any, Optional from typing import Any, Optional
import numpy.typing as npt import numpy.typing as npt
from pydantic import BaseModel, ConfigDict
@dataclass class BoundingBox(BaseModel):
class BoundingBox:
"""Bounding box helper class.""" """Bounding box helper class."""
xmin: int xmin: int
@ -18,24 +17,14 @@ class BoundingBox:
return [self.xmin, self.ymin, self.xmax, self.ymax] return [self.xmin, self.ymin, self.xmax, self.ymax]
@dataclass class DetectionResult(BaseModel):
class DetectionResult:
"""Detection result from Grounding DINO or Grounded SAM.""" """Detection result from Grounding DINO or Grounded SAM."""
score: float score: float
label: str label: str
box: BoundingBox box: BoundingBox
mask: Optional[npt.NDArray[Any]] = None mask: Optional[npt.NDArray[Any]] = None
model_config = ConfigDict(
@classmethod # Allow arbitrary types for mask, since it will be a numpy array.
def from_dict(cls, detection_dict: dict[str, Any]): arbitrary_types_allowed=True
return cls( )
score=detection_dict["score"],
label=detection_dict["label"],
box=BoundingBox(
xmin=detection_dict["box"]["xmin"],
ymin=detection_dict["box"]["ymin"],
xmax=detection_dict["box"]["xmax"],
ymax=detection_dict["box"]["ymax"],
),
)

View File

@ -17,7 +17,7 @@ class GroundingDinoPipeline:
def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1) -> list[DetectionResult]: def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1) -> list[DetectionResult]:
results = self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold) results = self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
results = [DetectionResult.from_dict(result) for result in results] results = [DetectionResult.model_validate(result) for result in results]
return results return results
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline": def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":