mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Use staticmethods rather than inner functions for the Grounding DINO and SAM model loaders.
This commit is contained in:
parent
0a7048f650
commit
e206890e25
@ -84,19 +84,8 @@ class GroundedSAMInvocation(BaseInvocation):
|
|||||||
image_dto = context.images.save(image=mask_pil)
|
image_dto = context.images.save(image=mask_pil)
|
||||||
return ImageOutput.build(image_dto)
|
return ImageOutput.build(image_dto)
|
||||||
|
|
||||||
def _detect(
|
@staticmethod
|
||||||
self,
|
def _load_grounding_dino(model_path: Path):
|
||||||
context: InvocationContext,
|
|
||||||
image: Image.Image,
|
|
||||||
labels: list[str],
|
|
||||||
threshold: float = 0.3,
|
|
||||||
) -> list[DetectionResult]:
|
|
||||||
"""Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
|
|
||||||
# TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
|
|
||||||
# actually makes a difference.
|
|
||||||
labels = [label if label.endswith(".") else label + "." for label in labels]
|
|
||||||
|
|
||||||
def load_grounding_dino(model_path: Path):
|
|
||||||
grounding_dino_pipeline = pipeline(
|
grounding_dino_pipeline = pipeline(
|
||||||
model=str(model_path),
|
model=str(model_path),
|
||||||
task="zero-shot-object-detection",
|
task="zero-shot-object-detection",
|
||||||
@ -108,19 +97,8 @@ class GroundedSAMInvocation(BaseInvocation):
|
|||||||
assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
|
assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
|
||||||
return GroundingDinoPipeline(grounding_dino_pipeline)
|
return GroundingDinoPipeline(grounding_dino_pipeline)
|
||||||
|
|
||||||
with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector:
|
@staticmethod
|
||||||
assert isinstance(detector, GroundingDinoPipeline)
|
def _load_sam_model(model_path: Path):
|
||||||
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
|
|
||||||
|
|
||||||
def _segment(
|
|
||||||
self,
|
|
||||||
context: InvocationContext,
|
|
||||||
image: Image.Image,
|
|
||||||
detection_results: list[DetectionResult],
|
|
||||||
) -> list[DetectionResult]:
|
|
||||||
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
|
|
||||||
|
|
||||||
def load_sam_model(model_path: Path):
|
|
||||||
sam_model = AutoModelForMaskGeneration.from_pretrained(
|
sam_model = AutoModelForMaskGeneration.from_pretrained(
|
||||||
model_path,
|
model_path,
|
||||||
local_files_only=True,
|
local_files_only=True,
|
||||||
@ -134,11 +112,37 @@ class GroundedSAMInvocation(BaseInvocation):
|
|||||||
assert isinstance(sam_processor, SamProcessor)
|
assert isinstance(sam_processor, SamProcessor)
|
||||||
return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
|
return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
|
||||||
|
|
||||||
|
def _detect(
|
||||||
|
self,
|
||||||
|
context: InvocationContext,
|
||||||
|
image: Image.Image,
|
||||||
|
labels: list[str],
|
||||||
|
threshold: float = 0.3,
|
||||||
|
) -> list[DetectionResult]:
|
||||||
|
"""Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
|
||||||
|
# TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
|
||||||
|
# actually makes a difference.
|
||||||
|
labels = [label if label.endswith(".") else label + "." for label in labels]
|
||||||
|
|
||||||
|
with context.models.load_remote_model(
|
||||||
|
source=GROUNDING_DINO_MODEL_ID, loader=GroundedSAMInvocation._load_grounding_dino
|
||||||
|
) as detector:
|
||||||
|
assert isinstance(detector, GroundingDinoPipeline)
|
||||||
|
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
|
||||||
|
|
||||||
|
def _segment(
|
||||||
|
self,
|
||||||
|
context: InvocationContext,
|
||||||
|
image: Image.Image,
|
||||||
|
detection_results: list[DetectionResult],
|
||||||
|
) -> list[DetectionResult]:
|
||||||
|
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
|
||||||
with (
|
with (
|
||||||
context.models.load_remote_model(source=SEGMENT_ANYTHING_MODEL_ID, loader=load_sam_model) as sam_pipeline,
|
context.models.load_remote_model(
|
||||||
|
source=SEGMENT_ANYTHING_MODEL_ID, loader=GroundedSAMInvocation._load_sam_model
|
||||||
|
) as sam_pipeline,
|
||||||
):
|
):
|
||||||
assert isinstance(sam_pipeline, SegmentAnythingModel)
|
assert isinstance(sam_pipeline, SegmentAnythingModel)
|
||||||
|
|
||||||
masks = sam_pipeline.segment(image=image, detection_results=detection_results)
|
masks = sam_pipeline.segment(image=image, detection_results=detection_results)
|
||||||
|
|
||||||
masks = self._to_numpy_masks(masks)
|
masks = self._to_numpy_masks(masks)
|
||||||
|
Loading…
Reference in New Issue
Block a user