Rename SegmentAnythingModel -> SegmentAnythingPipeline.

This commit is contained in:
Ryan Dick 2024-08-01 09:57:47 -04:00
parent 63581ec980
commit b9dc3460ba
3 changed files with 6 additions and 6 deletions

View File

@ -13,7 +13,7 @@ from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputF
from invokeai.app.invocations.primitives import MaskOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
from invokeai.backend.image_util.segment_anything.segment_anything_model import SegmentAnythingModel
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base"
@ -75,7 +75,7 @@ class SegmentAnythingModelInvocation(BaseInvocation):
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
assert isinstance(sam_processor, SamProcessor)
return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
def _segment(
self,
@ -91,7 +91,7 @@ class SegmentAnythingModelInvocation(BaseInvocation):
source=SEGMENT_ANYTHING_MODEL_ID, loader=SegmentAnythingModelInvocation._load_sam_model
) as sam_pipeline,
):
assert isinstance(sam_pipeline, SegmentAnythingModel)
assert isinstance(sam_pipeline, SegmentAnythingPipeline)
masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
masks = self._process_masks(masks)

View File

@ -8,7 +8,7 @@ from transformers.models.sam.processing_sam import SamProcessor
from invokeai.backend.raw_model import RawModel
class SegmentAnythingModel(RawModel):
class SegmentAnythingPipeline(RawModel):
"""A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
def __init__(self, sam_model: SamModel, sam_processor: SamProcessor):

View File

@ -12,7 +12,7 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin
from transformers import CLIPTokenizer
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.image_util.segment_anything.segment_anything_model import SegmentAnythingModel
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager.config import AnyModel
@ -44,7 +44,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
LoRAModelRaw,
SpandrelImageToImageModel,
GroundingDinoPipeline,
SegmentAnythingModel,
SegmentAnythingPipeline,
),
):
return model.calc_size()