diff --git a/invokeai/app/invocations/segment_anything.py b/invokeai/app/invocations/segment_anything.py index 3516f22687..7b9adf584a 100644 --- a/invokeai/app/invocations/segment_anything.py +++ b/invokeai/app/invocations/segment_anything.py @@ -55,7 +55,7 @@ class SegmentAnythingInvocation(BaseInvocation): raise RuntimeError("Unable to load segmentation models") grounding_dino = GroundingSegmentAnythingDetector.build_grounding_dino( - cast(Dict[str, torch.Tensor], grounding_dino_state_dict) + cast(Dict[str, torch.Tensor], grounding_dino_state_dict), TorchDevice.choose_torch_device() ) segment_anything = GroundingSegmentAnythingDetector.build_segment_anything( cast(Dict[str, torch.Tensor], segment_anything_state_dict), TorchDevice.choose_torch_device() diff --git a/invokeai/backend/image_util/grounding_segment_anything/gsa.py b/invokeai/backend/image_util/grounding_segment_anything/gsa.py index 3102091bef..aa07e0218b 100644 --- a/invokeai/backend/image_util/grounding_segment_anything/gsa.py +++ b/invokeai/backend/image_util/grounding_segment_anything/gsa.py @@ -1,5 +1,5 @@ import pathlib -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import numpy as np import supervision as sv @@ -18,13 +18,14 @@ class GroundingSegmentAnythingDetector: self.segment_anything_model: Optional[SamPredictor] = segment_anything_model @staticmethod - def build_grounding_dino(grounding_dino_state_dict: Dict[str, torch.Tensor]): + def build_grounding_dino(grounding_dino_state_dict: Dict[str, torch.Tensor], device: torch.device): grounding_dino_config = pathlib.Path( "./invokeai/backend/image_util/grounding_segment_anything/groundingdino/config/GroundingDINO_SwinT_OGC.py" ) return Model( model_state_dict=grounding_dino_state_dict, model_config_path=grounding_dino_config.as_posix(), + device=device.type, ) @staticmethod @@ -35,12 +36,15 @@ class GroundingSegmentAnythingDetector: def detect_objects( self, - image: np.ndarray, + image: np.ndarray[Any, Any], prompts: List[str], box_threshold: float = 0.5, text_threshold: float = 0.5, nms_threshold: float = 0.8, ): + if not self.grounding_dino_model: + raise RuntimeError("GroundingDINO model could not load.") + detections = self.grounding_dino_model.predict_with_classes( image=image, classes=prompts, box_threshold=box_threshold, text_threshold=text_threshold ) @@ -52,15 +56,18 @@ class GroundingSegmentAnythingDetector: .numpy() .tolist() ) - detections.xyxy = detections.xyxy[nms_idx] detections.confidence = detections.confidence[nms_idx] detections.class_id = detections.class_id[nms_idx] + return detections def segment_detections( - self, image: np.ndarray, detections: sv.Detections, prompts: List[str] - ) -> Dict[str, np.ndarray]: + self, image: np.ndarray[Any, Any], detections: sv.Detections, prompts: List[str] + ) -> Dict[str, np.ndarray[Any, Any]]: + if not self.segment_anything_model: + raise RuntimeError("Segment Anything model could not be loaded") + self.segment_anything_model.set_image(image) result_masks = {} for box, class_id in zip(detections.xyxy, detections.class_id):