Move some logic from GroundedSAMInvocation to the backend classes.
@@ -4,6 +4,8 @@ import torch
 from PIL import Image
 from transformers.pipelines import ZeroShotObjectDetectionPipeline
 
+from invokeai.backend.grounded_sam.detection_result import DetectionResult
+
 
 class GroundingDinoPipeline:
     """A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory
@@ -13,8 +15,10 @@ class GroundingDinoPipeline:
     def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
         self._pipeline = pipeline
 
-    def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1):
-        return self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
+    def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1) -> list[DetectionResult]:
+        results = self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
+        results = [DetectionResult.from_dict(result) for result in results]
+        return results
 
     def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":
         self._pipeline.model.to(device=device, dtype=dtype)
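For context, a minimal usage sketch of the refactored detect() follows: after this change it returns DetectionResult objects rather than raw dicts. The pipeline construction, the Hugging Face checkpoint, and the GroundingDinoPipeline import path below are illustrative assumptions, not taken from this commit.

from PIL import Image
from transformers import pipeline

# Assumed import path for the wrapper shown in the diff above.
from invokeai.backend.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline

# Illustrative checkpoint; any Grounding DINO zero-shot-object-detection model would do.
raw_pipeline = pipeline("zero-shot-object-detection", model="IDEA-Research/grounding-dino-tiny")
dino = GroundingDinoPipeline(raw_pipeline)

image = Image.open("photo.png").convert("RGB")
# After this commit, detect() wraps each raw detection dict in a DetectionResult
# via DetectionResult.from_dict() before returning.
detections = dino.detect(image=image, candidate_labels=["dog"], threshold=0.3)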
@@ -5,6 +5,8 @@ from PIL import Image
 from transformers.models.sam import SamModel
 from transformers.models.sam.processing_sam import SamProcessor
 
+from invokeai.backend.grounded_sam.detection_result import DetectionResult
+
 
 class SegmentAnythingModel:
     """A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
@@ -23,7 +25,8 @@ class SegmentAnythingModel:
 
         return calc_module_size(self._sam_model)
 
-    def segment(self, image: Image.Image, boxes: list[list[list[int]]]) -> torch.Tensor:
+    def segment(self, image: Image.Image, detection_results: list[DetectionResult]) -> torch.Tensor:
+        boxes = self._to_box_array(detection_results)
         inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device)
         outputs = self._sam_model(**inputs)
         masks = self._sam_processor.post_process_masks(
@@ -36,3 +39,8 @@ class SegmentAnythingModel:
         assert len(masks) == 1
         masks = masks[0]
         return masks
+
+    def _to_box_array(self, detection_results: list[DetectionResult]) -> list[list[list[int]]]:
+        """Convert a list of DetectionResults to the bbox format expected by the Segment Anything model."""
+        boxes = [result.box.to_box() for result in detection_results]
+        return [boxes]
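Taken together, the two changes let the invocation hand detection results straight through: GroundingDinoPipeline.detect() produces DetectionResult objects, and SegmentAnythingModel.segment() now performs its own box-array conversion via _to_box_array(), producing the [[[x_min, y_min, x_max, y_max], ...]] nesting the SAM processor expects for a single image. A hedged end-to-end sketch follows; the checkpoints, constructor keyword names, and import paths are assumptions for illustration, not the invocation's actual model-manager wiring.

from PIL import Image
from transformers import pipeline
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor

# Assumed import paths for the two wrappers touched by this commit.
from invokeai.backend.grounded_sam.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.grounded_sam.segment_anything_model import SegmentAnythingModel

# Illustrative checkpoints; constructor keyword names are assumed.
dino = GroundingDinoPipeline(pipeline("zero-shot-object-detection", model="IDEA-Research/grounding-dino-tiny"))
sam = SegmentAnythingModel(
    sam_model=SamModel.from_pretrained("facebook/sam-vit-base"),
    sam_processor=SamProcessor.from_pretrained("facebook/sam-vit-base"),
)

image = Image.open("photo.png").convert("RGB")
detections = dino.detect(image=image, candidate_labels=["dog"], threshold=0.3)
# segment() now accepts DetectionResult objects directly; the box conversion
# happens inside the backend class instead of in GroundedSAMInvocation.
masks = sam.segment(image=image, detection_results=detections)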