mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Make GroundingDinoPipeline and SegmentAnythingModel subclasses of RawModel for type checking purposes.
This commit is contained in:
parent
9f448fecb7
commit
73386826d6
@ -5,9 +5,10 @@ from PIL import Image
|
|||||||
from transformers.pipelines import ZeroShotObjectDetectionPipeline
|
from transformers.pipelines import ZeroShotObjectDetectionPipeline
|
||||||
|
|
||||||
from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult
|
from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult
|
||||||
|
from invokeai.backend.raw_model import RawModel
|
||||||
|
|
||||||
|
|
||||||
class GroundingDinoPipeline:
|
class GroundingDinoPipeline(RawModel):
|
||||||
"""A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory
|
"""A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory
|
||||||
management system.
|
management system.
|
||||||
"""
|
"""
|
||||||
@ -20,14 +21,13 @@ class GroundingDinoPipeline:
|
|||||||
results = [DetectionResult.model_validate(result) for result in results]
|
results = [DetectionResult.model_validate(result) for result in results]
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":
|
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||||
# HACK(ryand): The GroundingDinoPipeline does not work on MPS devices. We only allow it to be moved to CPU or
|
# HACK(ryand): The GroundingDinoPipeline does not work on MPS devices. We only allow it to be moved to CPU or
|
||||||
# CUDA.
|
# CUDA.
|
||||||
if device is not None and device.type not in {"cpu", "cuda"}:
|
if device is not None and device.type not in {"cpu", "cuda"}:
|
||||||
device = None
|
device = None
|
||||||
self._pipeline.model.to(device=device, dtype=dtype)
|
self._pipeline.model.to(device=device, dtype=dtype)
|
||||||
self._pipeline.device = self._pipeline.model.device
|
self._pipeline.device = self._pipeline.model.device
|
||||||
return self
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
def calc_size(self) -> int:
|
||||||
# HACK(ryand): Fix the circular import issue.
|
# HACK(ryand): Fix the circular import issue.
|
||||||
|
@ -6,21 +6,21 @@ from transformers.models.sam import SamModel
|
|||||||
from transformers.models.sam.processing_sam import SamProcessor
|
from transformers.models.sam.processing_sam import SamProcessor
|
||||||
|
|
||||||
from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult
|
from invokeai.backend.image_util.grounded_sam.detection_result import DetectionResult
|
||||||
|
from invokeai.backend.raw_model import RawModel
|
||||||
|
|
||||||
|
|
||||||
class SegmentAnythingModel:
|
class SegmentAnythingModel(RawModel):
|
||||||
"""A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
|
"""A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
|
||||||
|
|
||||||
def __init__(self, sam_model: SamModel, sam_processor: SamProcessor):
|
def __init__(self, sam_model: SamModel, sam_processor: SamProcessor):
|
||||||
self._sam_model = sam_model
|
self._sam_model = sam_model
|
||||||
self._sam_processor = sam_processor
|
self._sam_processor = sam_processor
|
||||||
|
|
||||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "SegmentAnythingModel":
|
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||||
# HACK(ryand): The SAM pipeline does not work on MPS devices. We only allow it to be moved to CPU or CUDA.
|
# HACK(ryand): The SAM pipeline does not work on MPS devices. We only allow it to be moved to CPU or CUDA.
|
||||||
if device is not None and device.type not in {"cpu", "cuda"}:
|
if device is not None and device.type not in {"cpu", "cuda"}:
|
||||||
device = None
|
device = None
|
||||||
self._sam_model.to(device=device, dtype=dtype)
|
self._sam_model.to(device=device, dtype=dtype)
|
||||||
return self
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
def calc_size(self) -> int:
|
||||||
# HACK(ryand): Fix the circular import issue.
|
# HACK(ryand): Fix the circular import issue.
|
||||||
|
Loading…
Reference in New Issue
Block a user