Make Grounding DINO DetectionResult a Pydantic model.

This commit is contained in:
Ryan Dick 2024-07-31 08:47:00 -04:00
parent cec7399366
commit 33e8604b57
2 changed files with 8 additions and 19 deletions

View File

@@ -1,11 +1,10 @@
from dataclasses import dataclass
from typing import Any, Optional
import numpy.typing as npt
from pydantic import BaseModel, ConfigDict
@dataclass
class BoundingBox:
class BoundingBox(BaseModel):
"""Bounding box helper class."""
xmin: int
@@ -18,24 +17,14 @@ class BoundingBox:
return [self.xmin, self.ymin, self.xmax, self.ymax]
@dataclass
class DetectionResult:
class DetectionResult(BaseModel):
"""Detection result from Grounding DINO or Grounded SAM."""
score: float
label: str
box: BoundingBox
mask: Optional[npt.NDArray[Any]] = None
@classmethod
def from_dict(cls, detection_dict: dict[str, Any]):
return cls(
score=detection_dict["score"],
label=detection_dict["label"],
box=BoundingBox(
xmin=detection_dict["box"]["xmin"],
ymin=detection_dict["box"]["ymin"],
xmax=detection_dict["box"]["xmax"],
ymax=detection_dict["box"]["ymax"],
),
)
model_config = ConfigDict(
# Allow arbitrary types for mask, since it will be a numpy array.
arbitrary_types_allowed=True
)

View File

@@ -17,7 +17,7 @@ class GroundingDinoPipeline:
def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1) -> list[DetectionResult]:
results = self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
results = [DetectionResult.from_dict(result) for result in results]
results = [DetectionResult.model_validate(result) for result in results]
return results
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":