mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Prevent Grounding DINO and Segment Anything from being moved to MPS - they don't work on MPS devices.
This commit is contained in:
@ -21,6 +21,10 @@ class GroundingDinoPipeline:
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":
|
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":
|
||||||
|
# HACK(ryand): The GroundingDinoPipeline does not work on MPS devices. We only allow it to be moved to CPU or
|
||||||
|
# CUDA.
|
||||||
|
if device is not None and device.type not in {"cpu", "cuda"}:
|
||||||
|
device = None
|
||||||
self._pipeline.model.to(device=device, dtype=dtype)
|
self._pipeline.model.to(device=device, dtype=dtype)
|
||||||
self._pipeline.device = self._pipeline.model.device
|
self._pipeline.device = self._pipeline.model.device
|
||||||
return self
|
return self
|
||||||
|
@ -16,6 +16,9 @@ class SegmentAnythingModel:
|
|||||||
self._sam_processor = sam_processor
|
self._sam_processor = sam_processor
|
||||||
|
|
||||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "SegmentAnythingModel":
|
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "SegmentAnythingModel":
|
||||||
|
# HACK(ryand): The SAM pipeline does not work on MPS devices. We only allow it to be moved to CPU or CUDA.
|
||||||
|
if device is not None and device.type not in {"cpu", "cuda"}:
|
||||||
|
device = None
|
||||||
self._sam_model.to(device=device, dtype=dtype)
|
self._sam_model.to(device=device, dtype=dtype)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user