From b9dc3460ba502810ab283c9a548f923636979d68 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 1 Aug 2024 09:57:47 -0400 Subject: [PATCH] Rename SegmentAnythingModel -> SegmentAnythingPipeline. --- invokeai/app/invocations/segment_anything_model.py | 6 +++--- ...gment_anything_model.py => segment_anything_pipeline.py} | 2 +- invokeai/backend/model_manager/load/model_util.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) rename invokeai/backend/image_util/segment_anything/{segment_anything_model.py => segment_anything_pipeline.py} (98%) diff --git a/invokeai/app/invocations/segment_anything_model.py b/invokeai/app/invocations/segment_anything_model.py index a652e68338..13fc71f231 100644 --- a/invokeai/app/invocations/segment_anything_model.py +++ b/invokeai/app/invocations/segment_anything_model.py @@ -13,7 +13,7 @@ from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputF from invokeai.app.invocations.primitives import MaskOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask -from invokeai.backend.image_util.segment_anything.segment_anything_model import SegmentAnythingModel +from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base" @@ -75,7 +75,7 @@ class SegmentAnythingModelInvocation(BaseInvocation): sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True) assert isinstance(sam_processor, SamProcessor) - return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor) + return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor) def _segment( self, @@ -91,7 +91,7 @@ class SegmentAnythingModelInvocation(BaseInvocation): source=SEGMENT_ANYTHING_MODEL_ID, loader=SegmentAnythingModelInvocation._load_sam_model ) as sam_pipeline, ): - assert isinstance(sam_pipeline, SegmentAnythingModel) + assert isinstance(sam_pipeline, SegmentAnythingPipeline) masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes) masks = self._process_masks(masks) diff --git a/invokeai/backend/image_util/segment_anything/segment_anything_model.py b/invokeai/backend/image_util/segment_anything/segment_anything_pipeline.py similarity index 98% rename from invokeai/backend/image_util/segment_anything/segment_anything_model.py rename to invokeai/backend/image_util/segment_anything/segment_anything_pipeline.py index d358db5c00..0818fb2fbd 100644 --- a/invokeai/backend/image_util/segment_anything/segment_anything_model.py +++ b/invokeai/backend/image_util/segment_anything/segment_anything_pipeline.py @@ -8,7 +8,7 @@ from transformers.models.sam.processing_sam import SamProcessor from invokeai.backend.raw_model import RawModel -class SegmentAnythingModel(RawModel): +class SegmentAnythingPipeline(RawModel): """A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager.""" def __init__(self, sam_model: SamModel, sam_processor: SamProcessor): diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py index 9c2570978a..2656a8b6ef 100644 --- a/invokeai/backend/model_manager/load/model_util.py +++ b/invokeai/backend/model_manager/load/model_util.py @@ -12,7 +12,7 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin from transformers import CLIPTokenizer from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline -from invokeai.backend.image_util.segment_anything.segment_anything_model import SegmentAnythingModel +from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager.config import AnyModel @@ -44,7 +44,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int: LoRAModelRaw, SpandrelImageToImageModel, GroundingDinoPipeline, - SegmentAnythingModel, + SegmentAnythingPipeline, ), ): return model.calc_size()