fix: pass device to grounding dino initializer

This commit is contained in:
blessedcoolant 2024-07-26 00:55:22 +05:30
parent b20c70c588
commit 737356fd14
2 changed files with 14 additions and 7 deletions

View File

@ -55,7 +55,7 @@ class SegmentAnythingInvocation(BaseInvocation):
raise RuntimeError("Unable to load segmentation models")
grounding_dino = GroundingSegmentAnythingDetector.build_grounding_dino(
cast(Dict[str, torch.Tensor], grounding_dino_state_dict)
cast(Dict[str, torch.Tensor], grounding_dino_state_dict), TorchDevice.choose_torch_device()
)
segment_anything = GroundingSegmentAnythingDetector.build_segment_anything(
cast(Dict[str, torch.Tensor], segment_anything_state_dict), TorchDevice.choose_torch_device()

View File

@ -1,5 +1,5 @@
import pathlib
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional
import numpy as np
import supervision as sv
@ -18,13 +18,14 @@ class GroundingSegmentAnythingDetector:
self.segment_anything_model: Optional[SamPredictor] = segment_anything_model
@staticmethod
def build_grounding_dino(grounding_dino_state_dict: Dict[str, torch.Tensor]):
def build_grounding_dino(grounding_dino_state_dict: Dict[str, torch.Tensor], device: torch.device):
grounding_dino_config = pathlib.Path(
"./invokeai/backend/image_util/grounding_segment_anything/groundingdino/config/GroundingDINO_SwinT_OGC.py"
)
return Model(
model_state_dict=grounding_dino_state_dict,
model_config_path=grounding_dino_config.as_posix(),
device=device.type,
)
@staticmethod
@ -35,12 +36,15 @@ class GroundingSegmentAnythingDetector:
def detect_objects(
self,
image: np.ndarray,
image: np.ndarray[Any, Any],
prompts: List[str],
box_threshold: float = 0.5,
text_threshold: float = 0.5,
nms_threshold: float = 0.8,
):
if not self.grounding_dino_model:
raise RuntimeError("GroundingDINO model could not load.")
detections = self.grounding_dino_model.predict_with_classes(
image=image, classes=prompts, box_threshold=box_threshold, text_threshold=text_threshold
)
@ -52,15 +56,18 @@ class GroundingSegmentAnythingDetector:
.numpy()
.tolist()
)
detections.xyxy = detections.xyxy[nms_idx]
detections.confidence = detections.confidence[nms_idx]
detections.class_id = detections.class_id[nms_idx]
return detections
def segment_detections(
self, image: np.ndarray, detections: sv.Detections, prompts: List[str]
) -> Dict[str, np.ndarray]:
self, image: np.ndarray[Any, Any], detections: sv.Detections, prompts: List[str]
) -> Dict[str, np.ndarray[Any, Any]]:
if not self.segment_anything_model:
raise RuntimeError("Segment Anything model could not be loaded")
self.segment_anything_model.set_image(image)
result_masks = {}
for box, class_id in zip(detections.xyxy, detections.class_id):