Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Make Grounding DINO DetectionResult a Pydantic model.

commit 33e8604b57
parent cec7399366
@@ -1,11 +1,10 @@
-from dataclasses import dataclass
 from typing import Any, Optional
 
 import numpy.typing as npt
+from pydantic import BaseModel, ConfigDict
 
 
-@dataclass
-class BoundingBox:
+class BoundingBox(BaseModel):
     """Bounding box helper class."""
 
     xmin: int
@@ -18,24 +17,14 @@ class BoundingBox:
         return [self.xmin, self.ymin, self.xmax, self.ymax]
 
 
-@dataclass
-class DetectionResult:
+class DetectionResult(BaseModel):
     """Detection result from Grounding DINO or Grounded SAM."""
 
     score: float
     label: str
     box: BoundingBox
     mask: Optional[npt.NDArray[Any]] = None
-
-    @classmethod
-    def from_dict(cls, detection_dict: dict[str, Any]):
-        return cls(
-            score=detection_dict["score"],
-            label=detection_dict["label"],
-            box=BoundingBox(
-                xmin=detection_dict["box"]["xmin"],
-                ymin=detection_dict["box"]["ymin"],
-                xmax=detection_dict["box"]["xmax"],
-                ymax=detection_dict["box"]["ymax"],
-            ),
-        )
+    model_config = ConfigDict(
+        # Allow arbitrary types for mask, since it will be a numpy array.
+        arbitrary_types_allowed=True
+    )
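The hunk above replaces the dataclasses with Pydantic models. A minimal, self-contained sketch of the resulting behavior (not part of the commit; the sample values and variable names are made up): model_validate builds the nested BoundingBox from the same raw dict that the removed from_dict helper used to unpack by hand, and arbitrary_types_allowed=True is what lets pydantic accept the numpy mask field.

from typing import Any, Optional

import numpy as np
import numpy.typing as npt
from pydantic import BaseModel, ConfigDict


class BoundingBox(BaseModel):
    """Bounding box helper class."""

    xmin: int
    ymin: int
    xmax: int
    ymax: int


class DetectionResult(BaseModel):
    """Detection result from Grounding DINO or Grounded SAM."""

    score: float
    label: str
    box: BoundingBox
    mask: Optional[npt.NDArray[Any]] = None

    model_config = ConfigDict(
        # Allow arbitrary types for mask, since it will be a numpy array.
        arbitrary_types_allowed=True
    )


# Raw dict in the shape returned by the detection pipeline (sample values).
raw = {"score": 0.92, "label": "a cat", "box": {"xmin": 10, "ymin": 20, "xmax": 110, "ymax": 220}}

# Pydantic validates the nested "box" dict into a BoundingBox automatically,
# so the hand-written from_dict classmethod is no longer needed.
result = DetectionResult.model_validate(raw)
assert isinstance(result.box, BoundingBox)

# The mask field is declared as a numpy array type; arbitrary_types_allowed=True
# is what allows pydantic to build a schema for that field at all.
result.mask = np.zeros((4, 4), dtype=bool)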
@@ -17,7 +17,7 @@ class GroundingDinoPipeline:
 
     def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1) -> list[DetectionResult]:
         results = self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
-        results = [DetectionResult.from_dict(result) for result in results]
+        results = [DetectionResult.model_validate(result) for result in results]
         return results
 
     def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":
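The detect() change in context, continuing the sketch above (DetectionResult and BoundingBox as defined there): raw_results stands in for the list of dicts returned by self._pipeline(...), with made-up values and no model loaded, just to show that the new list comprehension produces the same nested models the old from_dict call did.

# Stand-in for the pipeline output (list of score/label/box dicts).
raw_results = [
    {"score": 0.87, "label": "a dog", "box": {"xmin": 5, "ymin": 9, "xmax": 320, "ymax": 240}},
    {"score": 0.41, "label": "a ball", "box": {"xmin": 200, "ymin": 180, "xmax": 260, "ymax": 236}},
]

# Before: [DetectionResult.from_dict(r) for r in raw_results]
# After:  model_validate performs the same nested construction.
detections = [DetectionResult.model_validate(r) for r in raw_results]
assert all(isinstance(d.box, BoundingBox) for d in detections)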