Prevent Grounding DINO and Segment Anything from being moved to MPS - they don't work on MPS devices.

This commit is contained in:
Ryan Dick 2024-07-30 23:04:15 +02:00
parent 2da9f913f3
commit 5701c79fab
2 changed files with 7 additions and 0 deletions

View File

@ -21,6 +21,10 @@ class GroundingDinoPipeline:
return results
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":
# HACK(ryand): The GroundingDinoPipeline does not work on MPS devices. We only allow it to be moved to CPU or
# CUDA.
if device is not None and device.type not in {"cpu", "cuda"}:
device = None
self._pipeline.model.to(device=device, dtype=dtype)
self._pipeline.device = self._pipeline.model.device
return self

View File

@ -16,6 +16,9 @@ class SegmentAnythingModel:
self._sam_processor = sam_processor
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "SegmentAnythingModel":
# HACK(ryand): The SAM pipeline does not work on MPS devices. We only allow it to be moved to CPU or CUDA.
if device is not None and device.type not in {"cpu", "cuda"}:
device = None
self._sam_model.to(device=device, dtype=dtype)
return self