Use static methods rather than inner functions for the Grounding DINO and SAM model loaders.

This commit is contained in:
Ryan Dick 2024-07-31 09:28:52 -04:00
parent 0a7048f650
commit e206890e25

View File

@ -84,19 +84,8 @@ class GroundedSAMInvocation(BaseInvocation):
image_dto = context.images.save(image=mask_pil) image_dto = context.images.save(image=mask_pil)
return ImageOutput.build(image_dto) return ImageOutput.build(image_dto)
def _detect( @staticmethod
self, def _load_grounding_dino(model_path: Path):
context: InvocationContext,
image: Image.Image,
labels: list[str],
threshold: float = 0.3,
) -> list[DetectionResult]:
"""Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
# TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
# actually makes a difference.
labels = [label if label.endswith(".") else label + "." for label in labels]
def load_grounding_dino(model_path: Path):
grounding_dino_pipeline = pipeline( grounding_dino_pipeline = pipeline(
model=str(model_path), model=str(model_path),
task="zero-shot-object-detection", task="zero-shot-object-detection",
@ -108,19 +97,8 @@ class GroundedSAMInvocation(BaseInvocation):
assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline) assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
return GroundingDinoPipeline(grounding_dino_pipeline) return GroundingDinoPipeline(grounding_dino_pipeline)
with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector: @staticmethod
assert isinstance(detector, GroundingDinoPipeline) def _load_sam_model(model_path: Path):
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
def _segment(
self,
context: InvocationContext,
image: Image.Image,
detection_results: list[DetectionResult],
) -> list[DetectionResult]:
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
def load_sam_model(model_path: Path):
sam_model = AutoModelForMaskGeneration.from_pretrained( sam_model = AutoModelForMaskGeneration.from_pretrained(
model_path, model_path,
local_files_only=True, local_files_only=True,
@ -134,11 +112,37 @@ class GroundedSAMInvocation(BaseInvocation):
assert isinstance(sam_processor, SamProcessor) assert isinstance(sam_processor, SamProcessor)
return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor) return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
def _detect(
    self,
    context: InvocationContext,
    image: Image.Image,
    labels: list[str],
    threshold: float = 0.3,
) -> list[DetectionResult]:
    """Run the Grounding DINO detector over an image for a set of text labels.

    Args:
        context: Invocation context; its ``models.load_remote_model`` context manager supplies the detector.
        image: The image to run detection on.
        labels: Candidate labels. Any label that does not already end with "." gets one appended before
            being handed to the detector.
        threshold: Forwarded unchanged to the detector's ``detect`` call.

    Returns:
        Whatever ``detector.detect`` returns for the given image, labels and threshold.
    """
    # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
    # actually makes a difference.
    dotted_labels = [name if name.endswith(".") else name + "." for name in labels]

    # The loader is a staticmethod on the class, so it is referenced through the class rather than an instance.
    loader = GroundedSAMInvocation._load_grounding_dino
    with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=loader) as detector:
        assert isinstance(detector, GroundingDinoPipeline)
        return detector.detect(image=image, candidate_labels=dotted_labels, threshold=threshold)
def _segment(
self,
context: InvocationContext,
image: Image.Image,
detection_results: list[DetectionResult],
) -> list[DetectionResult]:
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
with ( with (
context.models.load_remote_model(source=SEGMENT_ANYTHING_MODEL_ID, loader=load_sam_model) as sam_pipeline, context.models.load_remote_model(
source=SEGMENT_ANYTHING_MODEL_ID, loader=GroundedSAMInvocation._load_sam_model
) as sam_pipeline,
): ):
assert isinstance(sam_pipeline, SegmentAnythingModel) assert isinstance(sam_pipeline, SegmentAnythingModel)
masks = sam_pipeline.segment(image=image, detection_results=detection_results) masks = sam_pipeline.segment(image=image, detection_results=detection_results)
masks = self._to_numpy_masks(masks) masks = self._to_numpy_masks(masks)