mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Rename SegmentAnythingModel -> SegmentAnythingPipeline.
This commit is contained in:
parent
63581ec980
commit
b9dc3460ba
@ -13,7 +13,7 @@ from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputF
|
|||||||
from invokeai.app.invocations.primitives import MaskOutput
|
from invokeai.app.invocations.primitives import MaskOutput
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
|
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
|
||||||
from invokeai.backend.image_util.segment_anything.segment_anything_model import SegmentAnythingModel
|
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
|
||||||
|
|
||||||
SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base"
|
SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base"
|
||||||
|
|
||||||
@ -75,7 +75,7 @@ class SegmentAnythingModelInvocation(BaseInvocation):
|
|||||||
|
|
||||||
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
|
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
|
||||||
assert isinstance(sam_processor, SamProcessor)
|
assert isinstance(sam_processor, SamProcessor)
|
||||||
return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
|
return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
|
||||||
|
|
||||||
def _segment(
|
def _segment(
|
||||||
self,
|
self,
|
||||||
@ -91,7 +91,7 @@ class SegmentAnythingModelInvocation(BaseInvocation):
|
|||||||
source=SEGMENT_ANYTHING_MODEL_ID, loader=SegmentAnythingModelInvocation._load_sam_model
|
source=SEGMENT_ANYTHING_MODEL_ID, loader=SegmentAnythingModelInvocation._load_sam_model
|
||||||
) as sam_pipeline,
|
) as sam_pipeline,
|
||||||
):
|
):
|
||||||
assert isinstance(sam_pipeline, SegmentAnythingModel)
|
assert isinstance(sam_pipeline, SegmentAnythingPipeline)
|
||||||
masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
|
masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
|
||||||
|
|
||||||
masks = self._process_masks(masks)
|
masks = self._process_masks(masks)
|
||||||
|
@ -8,7 +8,7 @@ from transformers.models.sam.processing_sam import SamProcessor
|
|||||||
from invokeai.backend.raw_model import RawModel
|
from invokeai.backend.raw_model import RawModel
|
||||||
|
|
||||||
|
|
||||||
class SegmentAnythingModel(RawModel):
|
class SegmentAnythingPipeline(RawModel):
|
||||||
"""A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
|
"""A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
|
||||||
|
|
||||||
def __init__(self, sam_model: SamModel, sam_processor: SamProcessor):
|
def __init__(self, sam_model: SamModel, sam_processor: SamProcessor):
|
@ -12,7 +12,7 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
|||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
|
|
||||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
||||||
from invokeai.backend.image_util.segment_anything.segment_anything_model import SegmentAnythingModel
|
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
|
||||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||||
from invokeai.backend.lora import LoRAModelRaw
|
from invokeai.backend.lora import LoRAModelRaw
|
||||||
from invokeai.backend.model_manager.config import AnyModel
|
from invokeai.backend.model_manager.config import AnyModel
|
||||||
@ -44,7 +44,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
|
|||||||
LoRAModelRaw,
|
LoRAModelRaw,
|
||||||
SpandrelImageToImageModel,
|
SpandrelImageToImageModel,
|
||||||
GroundingDinoPipeline,
|
GroundingDinoPipeline,
|
||||||
SegmentAnythingModel,
|
SegmentAnythingPipeline,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
return model.calc_size()
|
return model.calc_size()
|
||||||
|
Loading…
Reference in New Issue
Block a user