InvokeAI/invokeai/backend/grounded_sam/segment_anything_model.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

36 lines
1.4 KiB
Python
Raw Normal View History

from typing import Optional
import torch
from PIL import Image
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
class SegmentAnythingModel:
"""A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
def __init__(self, sam_model: SamModel, sam_processor: SamProcessor):
self._sam_model = sam_model
self._sam_processor = sam_processor
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "SegmentAnythingModel":
self._sam_model.to(device=device, dtype=dtype)
return self
def calc_size(self) -> int:
# HACK(ryand): Fix the circular import issue.
from invokeai.backend.model_manager.load.model_util import calc_module_size
return calc_module_size(self._sam_model)
def segment(self, image: Image.Image, boxes: list[list[list[int]]]) -> torch.Tensor:
inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device)
outputs = self._sam_model(**inputs)
masks = self._sam_processor.post_process_masks(
masks=outputs.pred_masks,
original_sizes=inputs.original_sizes,
reshaped_input_sizes=inputs.reshaped_input_sizes,
)[0]
return masks