mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Use staticmethods rather than inner functions for the Grounding DINO and SAM model loaders.
This commit is contained in:
parent
0a7048f650
commit
e206890e25
@ -84,19 +84,8 @@ class GroundedSAMInvocation(BaseInvocation):
|
||||
image_dto = context.images.save(image=mask_pil)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
def _detect(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
image: Image.Image,
|
||||
labels: list[str],
|
||||
threshold: float = 0.3,
|
||||
) -> list[DetectionResult]:
|
||||
"""Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
|
||||
# TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
|
||||
# actually makes a difference.
|
||||
labels = [label if label.endswith(".") else label + "." for label in labels]
|
||||
|
||||
def load_grounding_dino(model_path: Path):
|
||||
@staticmethod
|
||||
def _load_grounding_dino(model_path: Path):
|
||||
grounding_dino_pipeline = pipeline(
|
||||
model=str(model_path),
|
||||
task="zero-shot-object-detection",
|
||||
@ -108,19 +97,8 @@ class GroundedSAMInvocation(BaseInvocation):
|
||||
assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
|
||||
return GroundingDinoPipeline(grounding_dino_pipeline)
|
||||
|
||||
with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector:
|
||||
assert isinstance(detector, GroundingDinoPipeline)
|
||||
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
|
||||
|
||||
def _segment(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
image: Image.Image,
|
||||
detection_results: list[DetectionResult],
|
||||
) -> list[DetectionResult]:
|
||||
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
|
||||
|
||||
def load_sam_model(model_path: Path):
|
||||
@staticmethod
|
||||
def _load_sam_model(model_path: Path):
|
||||
sam_model = AutoModelForMaskGeneration.from_pretrained(
|
||||
model_path,
|
||||
local_files_only=True,
|
||||
@ -134,11 +112,37 @@ class GroundedSAMInvocation(BaseInvocation):
|
||||
assert isinstance(sam_processor, SamProcessor)
|
||||
return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
|
||||
|
||||
def _detect(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
image: Image.Image,
|
||||
labels: list[str],
|
||||
threshold: float = 0.3,
|
||||
) -> list[DetectionResult]:
|
||||
"""Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
|
||||
# TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
|
||||
# actually makes a difference.
|
||||
labels = [label if label.endswith(".") else label + "." for label in labels]
|
||||
|
||||
with context.models.load_remote_model(
|
||||
source=GROUNDING_DINO_MODEL_ID, loader=GroundedSAMInvocation._load_grounding_dino
|
||||
) as detector:
|
||||
assert isinstance(detector, GroundingDinoPipeline)
|
||||
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
|
||||
|
||||
def _segment(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
image: Image.Image,
|
||||
detection_results: list[DetectionResult],
|
||||
) -> list[DetectionResult]:
|
||||
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
|
||||
with (
|
||||
context.models.load_remote_model(source=SEGMENT_ANYTHING_MODEL_ID, loader=load_sam_model) as sam_pipeline,
|
||||
context.models.load_remote_model(
|
||||
source=SEGMENT_ANYTHING_MODEL_ID, loader=GroundedSAMInvocation._load_sam_model
|
||||
) as sam_pipeline,
|
||||
):
|
||||
assert isinstance(sam_pipeline, SegmentAnythingModel)
|
||||
|
||||
masks = sam_pipeline.segment(image=image, detection_results=detection_results)
|
||||
|
||||
masks = self._to_numpy_masks(masks)
|
||||
|
Loading…
x
Reference in New Issue
Block a user