mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix: pass device to grounding dino initializer
This commit is contained in:
parent
b20c70c588
commit
737356fd14
@ -55,7 +55,7 @@ class SegmentAnythingInvocation(BaseInvocation):
|
||||
raise RuntimeError("Unable to load segmentation models")
|
||||
|
||||
grounding_dino = GroundingSegmentAnythingDetector.build_grounding_dino(
|
||||
cast(Dict[str, torch.Tensor], grounding_dino_state_dict)
|
||||
cast(Dict[str, torch.Tensor], grounding_dino_state_dict), TorchDevice.choose_torch_device()
|
||||
)
|
||||
segment_anything = GroundingSegmentAnythingDetector.build_segment_anything(
|
||||
cast(Dict[str, torch.Tensor], segment_anything_state_dict), TorchDevice.choose_torch_device()
|
||||
|
@ -1,5 +1,5 @@
|
||||
import pathlib
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import supervision as sv
|
||||
@ -18,13 +18,14 @@ class GroundingSegmentAnythingDetector:
|
||||
self.segment_anything_model: Optional[SamPredictor] = segment_anything_model
|
||||
|
||||
@staticmethod
|
||||
def build_grounding_dino(grounding_dino_state_dict: Dict[str, torch.Tensor]):
|
||||
def build_grounding_dino(grounding_dino_state_dict: Dict[str, torch.Tensor], device: torch.device):
|
||||
grounding_dino_config = pathlib.Path(
|
||||
"./invokeai/backend/image_util/grounding_segment_anything/groundingdino/config/GroundingDINO_SwinT_OGC.py"
|
||||
)
|
||||
return Model(
|
||||
model_state_dict=grounding_dino_state_dict,
|
||||
model_config_path=grounding_dino_config.as_posix(),
|
||||
device=device.type,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -35,12 +36,15 @@ class GroundingSegmentAnythingDetector:
|
||||
|
||||
def detect_objects(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
image: np.ndarray[Any, Any],
|
||||
prompts: List[str],
|
||||
box_threshold: float = 0.5,
|
||||
text_threshold: float = 0.5,
|
||||
nms_threshold: float = 0.8,
|
||||
):
|
||||
if not self.grounding_dino_model:
|
||||
raise RuntimeError("GroundingDINO model could not load.")
|
||||
|
||||
detections = self.grounding_dino_model.predict_with_classes(
|
||||
image=image, classes=prompts, box_threshold=box_threshold, text_threshold=text_threshold
|
||||
)
|
||||
@ -52,15 +56,18 @@ class GroundingSegmentAnythingDetector:
|
||||
.numpy()
|
||||
.tolist()
|
||||
)
|
||||
|
||||
detections.xyxy = detections.xyxy[nms_idx]
|
||||
detections.confidence = detections.confidence[nms_idx]
|
||||
detections.class_id = detections.class_id[nms_idx]
|
||||
|
||||
return detections
|
||||
|
||||
def segment_detections(
|
||||
self, image: np.ndarray, detections: sv.Detections, prompts: List[str]
|
||||
) -> Dict[str, np.ndarray]:
|
||||
self, image: np.ndarray[Any, Any], detections: sv.Detections, prompts: List[str]
|
||||
) -> Dict[str, np.ndarray[Any, Any]]:
|
||||
if not self.segment_anything_model:
|
||||
raise RuntimeError("Segment Anything model could not be loaded")
|
||||
|
||||
self.segment_anything_model.set_image(image)
|
||||
result_masks = {}
|
||||
for box, class_id in zip(detections.xyxy, detections.class_id):
|
||||
|
Loading…
Reference in New Issue
Block a user