Make GroundingDinoPipeline and SegmentAnythingModel subclasses of RawModel for type checking purposes.

Author: Ryan Dick, 2024-07-31 10:25:34 -04:00
Parent: 9f448fecb7
Commit: 73386826d6
2 changed files with 6 additions and 6 deletions
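
For context, both wrappers now inherit from invokeai.backend.raw_model.RawModel, the base type the model manager uses when handling loaded models. RawModel's definition is not part of this diff, so the following is only a minimal sketch of the interface it is assumed to declare (matching the to() and calc_size() methods seen in the hunks below), not the real class:

# Assumed sketch of the RawModel interface; the actual class in
# invokeai/backend/raw_model.py may differ in detail.
from abc import ABC, abstractmethod
from typing import Optional

import torch


class RawModel(ABC):
    """Base type for model wrappers tracked by the model manager."""

    @abstractmethod
    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        """Move the wrapped weights to the given device and/or dtype."""

    @abstractmethod
    def calc_size(self) -> int:
        """Return the model's size in memory, in bytes."""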

File 1 of 2 (GroundingDinoPipeline wrapper module):

@@ -5,9 +5,10 @@ from PIL import Image
 from transformers.pipelines import ZeroShotObjectDetectionPipeline
 
 from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult
+from invokeai.backend.raw_model import RawModel
 
 
-class GroundingDinoPipeline:
+class GroundingDinoPipeline(RawModel):
     """A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory
     management system.
     """
@@ -20,14 +21,13 @@ class GroundingDinoPipeline:
         results = [DetectionResult.model_validate(result) for result in results]
         return results
 
-    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
         # HACK(ryand): The GroundingDinoPipeline does not work on MPS devices. We only allow it to be moved to CPU or
         # CUDA.
         if device is not None and device.type not in {"cpu", "cuda"}:
             device = None
         self._pipeline.model.to(device=device, dtype=dtype)
         self._pipeline.device = self._pipeline.model.device
-        return self
 
     def calc_size(self) -> int:
         # HACK(ryand): Fix the circular import issue.

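With GroundingDinoPipeline registered as a RawModel subclass, any model-manager code that is typed against RawModel now accepts the wrapper without casts or type: ignore comments. A hypothetical illustration (cache_model below is a stand-in, not an API from this commit):

from invokeai.backend.raw_model import RawModel


def cache_model(key: str, model: RawModel) -> None:
    # Stand-in for a model-manager API typed against RawModel. Before this
    # commit, passing a GroundingDinoPipeline here failed static type
    # checking because the wrapper did not derive from RawModel.
    assert isinstance(model, RawModel)
    print(f"caching {key}: {model.calc_size()} bytes")


# pipeline = GroundingDinoPipeline(...)      # constructed elsewhere
# cache_model("grounding-dino", pipeline)    # now satisfies both the type checker and isinstance
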
File 2 of 2 (SegmentAnythingModel wrapper module):

@@ -6,21 +6,21 @@ from transformers.models.sam import SamModel
 from transformers.models.sam.processing_sam import SamProcessor
 
 from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult
+from invokeai.backend.raw_model import RawModel
 
 
-class SegmentAnythingModel:
+class SegmentAnythingModel(RawModel):
     """A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
 
     def __init__(self, sam_model: SamModel, sam_processor: SamProcessor):
         self._sam_model = sam_model
         self._sam_processor = sam_processor
 
-    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "SegmentAnythingModel":
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
         # HACK(ryand): The SAM pipeline does not work on MPS devices. We only allow it to be moved to CPU or CUDA.
         if device is not None and device.type not in {"cpu", "cuda"}:
             device = None
         self._sam_model.to(device=device, dtype=dtype)
-        return self
 
     def calc_size(self) -> int:
         # HACK(ryand): Fix the circular import issue.
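
A related detail in both files: to() loses its return type annotation and its trailing return self, so after this commit it mutates the wrapper in place and returns None. A small usage note, assuming an already constructed SegmentAnythingModel instance named sam_model:

import torch

# Before this commit, callers could chain the call:
#     sam_model = sam_model.to(device=torch.device("cuda"))
# After this commit, to() returns None, so it is invoked as a statement
# and the existing reference keeps pointing at the moved model:
sam_model.to(device=torch.device("cuda"), dtype=torch.float16)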