diff --git a/invokeai/app/invocations/grounding_dino.py b/invokeai/app/invocations/grounding_dino.py
index a7d522e784..1e3d5cea0c 100644
--- a/invokeai/app/invocations/grounding_dino.py
+++ b/invokeai/app/invocations/grounding_dino.py
@@ -1,4 +1,5 @@
 from pathlib import Path
+from typing import Literal
 
 import torch
 from PIL import Image
@@ -12,7 +13,11 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
 from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
 
-GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
+GroundingDinoModelKey = Literal["grounding-dino-tiny", "grounding-dino-base"]
+GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = {
+    "grounding-dino-tiny": "IDEA-Research/grounding-dino-tiny",
+    "grounding-dino-base": "IDEA-Research/grounding-dino-base",
+}
 
 
 @invocation(
@@ -30,6 +35,7 @@ class GroundingDinoInvocation(BaseInvocation):
     # - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
     # - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
 
+    model: GroundingDinoModelKey = InputField(description="The Grounding DINO model to use.")
     prompt: str = InputField(description="The prompt describing the object to segment.")
     image: ImageField = InputField(description="The image to segment.")
     detection_threshold: float = InputField(
@@ -88,7 +94,7 @@ class GroundingDinoInvocation(BaseInvocation):
         labels = [label if label.endswith(".") else label + "." for label in labels]
 
         with context.models.load_remote_model(
-            source=GROUNDING_DINO_MODEL_ID, loader=GroundingDinoInvocation._load_grounding_dino
+            source=GROUNDING_DINO_MODEL_IDS[self.model], loader=GroundingDinoInvocation._load_grounding_dino
         ) as detector:
             assert isinstance(detector, GroundingDinoPipeline)
             return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
diff --git a/invokeai/app/invocations/segment_anything.py b/invokeai/app/invocations/segment_anything.py
index 6240dd6946..b49b1a39e3 100644
--- a/invokeai/app/invocations/segment_anything.py
+++ b/invokeai/app/invocations/segment_anything.py
@@ -15,7 +15,12 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
 from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
 
-SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base"
+SegmentAnythingModelKey = Literal["segment-anything-base", "segment-anything-large", "segment-anything-huge"]
+SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
+    "segment-anything-base": "facebook/sam-vit-base",
+    "segment-anything-large": "facebook/sam-vit-large",
+    "segment-anything-huge": "facebook/sam-vit-huge",
+}
 
 
 @invocation(
@@ -33,6 +38,7 @@ class SegmentAnythingInvocation(BaseInvocation):
     # - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
     # - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
 
+    model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
     image: ImageField = InputField(description="The image to segment.")
     bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
     apply_polygon_refinement: bool = InputField(
@@ -88,7 +94,7 @@ class SegmentAnythingInvocation(BaseInvocation):
 
         with (
             context.models.load_remote_model(
-                source=SEGMENT_ANYTHING_MODEL_ID, loader=SegmentAnythingInvocation._load_sam_model
+                source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
             ) as sam_pipeline,
         ):
             assert isinstance(sam_pipeline, SegmentAnythingPipeline)
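Not part of the patch, but for orientation: a minimal, self-contained sketch of the pattern the diff introduces, i.e. resolving a Literal key from the new `model` field to a Hugging Face model ID before loading. The direct transformers pipeline call below is an assumption for illustration only; inside InvokeAI the load goes through context.models.load_remote_model() with the invocation's own loader, as shown in the hunks above.

from typing import Literal

from transformers import pipeline

# Key -> Hugging Face model ID mapping, copied from the grounding_dino.py hunk above.
GroundingDinoModelKey = Literal["grounding-dino-tiny", "grounding-dino-base"]
GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = {
    "grounding-dino-tiny": "IDEA-Research/grounding-dino-tiny",
    "grounding-dino-base": "IDEA-Research/grounding-dino-base",
}

def load_detector(model: GroundingDinoModelKey):
    # Illustrative only: a plain zero-shot-object-detection pipeline stands in for
    # InvokeAI's cached load_remote_model() + GroundingDinoPipeline wrapper.
    return pipeline(task="zero-shot-object-detection", model=GROUNDING_DINO_MODEL_IDS[model])

detector = load_detector("grounding-dino-base")  # resolves to IDEA-Research/grounding-dino-base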