Rename SegmentAnythingModel -> SegmentAnythingPipeline.

This commit is contained in:
Ryan Dick 2024-08-01 09:57:47 -04:00
parent 63581ec980
commit b9dc3460ba
3 changed files with 6 additions and 6 deletions

View File

@ -13,7 +13,7 @@ from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputF
from invokeai.app.invocations.primitives import MaskOutput from invokeai.app.invocations.primitives import MaskOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
from invokeai.backend.image_util.segment_anything.segment_anything_model import SegmentAnythingModel from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base" SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base"
@ -75,7 +75,7 @@ class SegmentAnythingModelInvocation(BaseInvocation):
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True) sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
assert isinstance(sam_processor, SamProcessor) assert isinstance(sam_processor, SamProcessor)
return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor) return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
def _segment( def _segment(
self, self,
@ -91,7 +91,7 @@ class SegmentAnythingModelInvocation(BaseInvocation):
source=SEGMENT_ANYTHING_MODEL_ID, loader=SegmentAnythingModelInvocation._load_sam_model source=SEGMENT_ANYTHING_MODEL_ID, loader=SegmentAnythingModelInvocation._load_sam_model
) as sam_pipeline, ) as sam_pipeline,
): ):
assert isinstance(sam_pipeline, SegmentAnythingModel) assert isinstance(sam_pipeline, SegmentAnythingPipeline)
masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes) masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
masks = self._process_masks(masks) masks = self._process_masks(masks)

View File

@ -8,7 +8,7 @@ from transformers.models.sam.processing_sam import SamProcessor
from invokeai.backend.raw_model import RawModel from invokeai.backend.raw_model import RawModel
class SegmentAnythingModel(RawModel): class SegmentAnythingPipeline(RawModel):
"""A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager.""" """A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
def __init__(self, sam_model: SamModel, sam_processor: SamProcessor): def __init__(self, sam_model: SamModel, sam_processor: SamProcessor):

View File

@ -12,7 +12,7 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin
from transformers import CLIPTokenizer from transformers import CLIPTokenizer
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.image_util.segment_anything.segment_anything_model import SegmentAnythingModel from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager.config import AnyModel from invokeai.backend.model_manager.config import AnyModel
@ -44,7 +44,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
LoRAModelRaw, LoRAModelRaw,
SpandrelImageToImageModel, SpandrelImageToImageModel,
GroundingDinoPipeline, GroundingDinoPipeline,
SegmentAnythingModel, SegmentAnythingPipeline,
), ),
): ):
return model.calc_size() return model.calc_size()