Use staticmethods rather than inner functions for the Grounding DINO and SAM model loaders.

This commit is contained in:
Ryan Dick 2024-07-31 09:28:52 -04:00
parent 0a7048f650
commit e206890e25

View File

@ -84,6 +84,34 @@ class GroundedSAMInvocation(BaseInvocation):
image_dto = context.images.save(image=mask_pil)
return ImageOutput.build(image_dto)
@staticmethod
def _load_grounding_dino(model_path: Path):
grounding_dino_pipeline = pipeline(
model=str(model_path),
task="zero-shot-object-detection",
local_files_only=True,
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
# model, and figure out how to make it work in the pipeline.
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
return GroundingDinoPipeline(grounding_dino_pipeline)
@staticmethod
def _load_sam_model(model_path: Path):
sam_model = AutoModelForMaskGeneration.from_pretrained(
model_path,
local_files_only=True,
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
# model, and figure out how to make it work in the pipeline.
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
assert isinstance(sam_model, SamModel)
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
assert isinstance(sam_processor, SamProcessor)
return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
def _detect(
self,
context: InvocationContext,
@ -96,19 +124,9 @@ class GroundedSAMInvocation(BaseInvocation):
# actually makes a difference.
labels = [label if label.endswith(".") else label + "." for label in labels]
def load_grounding_dino(model_path: Path):
grounding_dino_pipeline = pipeline(
model=str(model_path),
task="zero-shot-object-detection",
local_files_only=True,
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
# model, and figure out how to make it work in the pipeline.
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
return GroundingDinoPipeline(grounding_dino_pipeline)
with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector:
with context.models.load_remote_model(
source=GROUNDING_DINO_MODEL_ID, loader=GroundedSAMInvocation._load_grounding_dino
) as detector:
assert isinstance(detector, GroundingDinoPipeline)
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
@ -119,26 +137,12 @@ class GroundedSAMInvocation(BaseInvocation):
detection_results: list[DetectionResult],
) -> list[DetectionResult]:
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
def load_sam_model(model_path: Path):
sam_model = AutoModelForMaskGeneration.from_pretrained(
model_path,
local_files_only=True,
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
# model, and figure out how to make it work in the pipeline.
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
assert isinstance(sam_model, SamModel)
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
assert isinstance(sam_processor, SamProcessor)
return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
with (
context.models.load_remote_model(source=SEGMENT_ANYTHING_MODEL_ID, loader=load_sam_model) as sam_pipeline,
context.models.load_remote_model(
source=SEGMENT_ANYTHING_MODEL_ID, loader=GroundedSAMInvocation._load_sam_model
) as sam_pipeline,
):
assert isinstance(sam_pipeline, SegmentAnythingModel)
masks = sam_pipeline.segment(image=image, detection_results=detection_results)
masks = self._to_numpy_masks(masks)