diff --git a/invokeai/backend/image_util/grounded_sam/grounding_dino_pipeline.py b/invokeai/backend/image_util/grounded_sam/grounding_dino_pipeline.py index c03a3d8cab..d998eb1d57 100644 --- a/invokeai/backend/image_util/grounded_sam/grounding_dino_pipeline.py +++ b/invokeai/backend/image_util/grounded_sam/grounding_dino_pipeline.py @@ -5,9 +5,10 @@ from PIL import Image from transformers.pipelines import ZeroShotObjectDetectionPipeline from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult +from invokeai.backend.raw_model import RawModel -class GroundingDinoPipeline: +class GroundingDinoPipeline(RawModel): """A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory management system. """ @@ -20,14 +21,13 @@ class GroundingDinoPipeline: results = [DetectionResult.model_validate(result) for result in results] return results - def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline": + def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): # HACK(ryand): The GroundingDinoPipeline does not work on MPS devices. We only allow it to be moved to CPU or # CUDA. if device is not None and device.type not in {"cpu", "cuda"}: device = None self._pipeline.model.to(device=device, dtype=dtype) self._pipeline.device = self._pipeline.model.device - return self def calc_size(self) -> int: # HACK(ryand): Fix the circular import issue. diff --git a/invokeai/backend/image_util/grounded_sam/segment_anything_model.py b/invokeai/backend/image_util/grounded_sam/segment_anything_model.py index c9959424f2..5147141303 100644 --- a/invokeai/backend/image_util/grounded_sam/segment_anything_model.py +++ b/invokeai/backend/image_util/grounded_sam/segment_anything_model.py @@ -6,21 +6,21 @@ from transformers.models.sam import SamModel from transformers.models.sam.processing_sam import SamProcessor from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult +from invokeai.backend.raw_model import RawModel -class SegmentAnythingModel: +class SegmentAnythingModel(RawModel): """A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager.""" def __init__(self, sam_model: SamModel, sam_processor: SamProcessor): self._sam_model = sam_model self._sam_processor = sam_processor - def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "SegmentAnythingModel": + def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): # HACK(ryand): The SAM pipeline does not work on MPS devices. We only allow it to be moved to CPU or CUDA. if device is not None and device.type not in {"cpu", "cuda"}: device = None self._sam_model.to(device=device, dtype=dtype) - return self def calc_size(self) -> int: # HACK(ryand): Fix the circular import issue.