diff --git a/invokeai/backend/grounded_sam/detection_result.py b/invokeai/backend/grounded_sam/detection_result.py index 40e4254385..a9f2bdd65f 100644 --- a/invokeai/backend/grounded_sam/detection_result.py +++ b/invokeai/backend/grounded_sam/detection_result.py @@ -1,11 +1,10 @@ -from dataclasses import dataclass from typing import Any, Optional import numpy.typing as npt +from pydantic import BaseModel, ConfigDict -@dataclass -class BoundingBox: +class BoundingBox(BaseModel): """Bounding box helper class.""" xmin: int @@ -18,24 +17,14 @@ class BoundingBox: return [self.xmin, self.ymin, self.xmax, self.ymax] -@dataclass -class DetectionResult: +class DetectionResult(BaseModel): """Detection result from Grounding DINO or Grounded SAM.""" score: float label: str box: BoundingBox mask: Optional[npt.NDArray[Any]] = None - - @classmethod - def from_dict(cls, detection_dict: dict[str, Any]): - return cls( - score=detection_dict["score"], - label=detection_dict["label"], - box=BoundingBox( - xmin=detection_dict["box"]["xmin"], - ymin=detection_dict["box"]["ymin"], - xmax=detection_dict["box"]["xmax"], - ymax=detection_dict["box"]["ymax"], - ), - ) + model_config = ConfigDict( + # Allow arbitrary types for mask, since it will be a numpy array. + arbitrary_types_allowed=True + ) diff --git a/invokeai/backend/grounded_sam/grounding_dino_pipeline.py b/invokeai/backend/grounded_sam/grounding_dino_pipeline.py index 1fc92b5e12..97c92f9249 100644 --- a/invokeai/backend/grounded_sam/grounding_dino_pipeline.py +++ b/invokeai/backend/grounded_sam/grounding_dino_pipeline.py @@ -17,7 +17,7 @@ class GroundingDinoPipeline: def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1) -> list[DetectionResult]: results = self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold) - results = [DetectionResult.from_dict(result) for result in results] + results = [DetectionResult.model_validate(result) for result in results] return results def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":