mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Expose all model options in the GroundingDinoInvocation and the SegmentAnythingInvocation.
This commit is contained in:
parent
675ffc2757
commit
27ac61a4fb
@ -1,4 +1,5 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@ -12,7 +13,11 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
|||||||
from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
|
from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
|
||||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
||||||
|
|
||||||
GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
|
GroundingDinoModelKey = Literal["grounding-dino-tiny", "grounding-dino-base"]
|
||||||
|
GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = {
|
||||||
|
"grounding-dino-tiny": "IDEA-Research/grounding-dino-tiny",
|
||||||
|
"grounding-dino-base": "IDEA-Research/grounding-dino-base",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -30,6 +35,7 @@ class GroundingDinoInvocation(BaseInvocation):
|
|||||||
# - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
|
# - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
|
||||||
# - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
|
# - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
|
||||||
|
|
||||||
|
model: GroundingDinoModelKey = InputField(description="The Grounding DINO model to use.")
|
||||||
prompt: str = InputField(description="The prompt describing the object to segment.")
|
prompt: str = InputField(description="The prompt describing the object to segment.")
|
||||||
image: ImageField = InputField(description="The image to segment.")
|
image: ImageField = InputField(description="The image to segment.")
|
||||||
detection_threshold: float = InputField(
|
detection_threshold: float = InputField(
|
||||||
@ -88,7 +94,7 @@ class GroundingDinoInvocation(BaseInvocation):
|
|||||||
labels = [label if label.endswith(".") else label + "." for label in labels]
|
labels = [label if label.endswith(".") else label + "." for label in labels]
|
||||||
|
|
||||||
with context.models.load_remote_model(
|
with context.models.load_remote_model(
|
||||||
source=GROUNDING_DINO_MODEL_ID, loader=GroundingDinoInvocation._load_grounding_dino
|
source=GROUNDING_DINO_MODEL_IDS[self.model], loader=GroundingDinoInvocation._load_grounding_dino
|
||||||
) as detector:
|
) as detector:
|
||||||
assert isinstance(detector, GroundingDinoPipeline)
|
assert isinstance(detector, GroundingDinoPipeline)
|
||||||
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
|
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
|
||||||
|
@ -15,7 +15,12 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
|||||||
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
|
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
|
||||||
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
|
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
|
||||||
|
|
||||||
SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base"
|
SegmentAnythingModelKey = Literal["segment-anything-base", "segment-anything-large", "segment-anything-huge"]
|
||||||
|
SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
|
||||||
|
"segment-anything-base": "facebook/sam-vit-base",
|
||||||
|
"segment-anything-large": "facebook/sam-vit-large",
|
||||||
|
"segment-anything-huge": "facebook/sam-vit-huge",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -33,6 +38,7 @@ class SegmentAnythingInvocation(BaseInvocation):
|
|||||||
# - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
|
# - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
|
||||||
# - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
|
# - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
|
||||||
|
|
||||||
|
model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
|
||||||
image: ImageField = InputField(description="The image to segment.")
|
image: ImageField = InputField(description="The image to segment.")
|
||||||
bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
|
bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
|
||||||
apply_polygon_refinement: bool = InputField(
|
apply_polygon_refinement: bool = InputField(
|
||||||
@ -88,7 +94,7 @@ class SegmentAnythingInvocation(BaseInvocation):
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
context.models.load_remote_model(
|
context.models.load_remote_model(
|
||||||
source=SEGMENT_ANYTHING_MODEL_ID, loader=SegmentAnythingInvocation._load_sam_model
|
source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
|
||||||
) as sam_pipeline,
|
) as sam_pipeline,
|
||||||
):
|
):
|
||||||
assert isinstance(sam_pipeline, SegmentAnythingPipeline)
|
assert isinstance(sam_pipeline, SegmentAnythingPipeline)
|
||||||
|
Loading…
Reference in New Issue
Block a user