mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Use staticmethods rather than inner functions for the Grounding DINO and SAM model loaders.
This commit is contained in:
parent
0a7048f650
commit
e206890e25
@ -84,6 +84,34 @@ class GroundedSAMInvocation(BaseInvocation):
|
||||
image_dto = context.images.save(image=mask_pil)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
@staticmethod
|
||||
def _load_grounding_dino(model_path: Path):
|
||||
grounding_dino_pipeline = pipeline(
|
||||
model=str(model_path),
|
||||
task="zero-shot-object-detection",
|
||||
local_files_only=True,
|
||||
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
|
||||
# model, and figure out how to make it work in the pipeline.
|
||||
# torch_dtype=TorchDevice.choose_torch_dtype(),
|
||||
)
|
||||
assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
|
||||
return GroundingDinoPipeline(grounding_dino_pipeline)
|
||||
|
||||
@staticmethod
|
||||
def _load_sam_model(model_path: Path):
|
||||
sam_model = AutoModelForMaskGeneration.from_pretrained(
|
||||
model_path,
|
||||
local_files_only=True,
|
||||
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
|
||||
# model, and figure out how to make it work in the pipeline.
|
||||
# torch_dtype=TorchDevice.choose_torch_dtype(),
|
||||
)
|
||||
assert isinstance(sam_model, SamModel)
|
||||
|
||||
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
|
||||
assert isinstance(sam_processor, SamProcessor)
|
||||
return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
|
||||
|
||||
def _detect(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
@ -96,19 +124,9 @@ class GroundedSAMInvocation(BaseInvocation):
|
||||
# actually makes a difference.
|
||||
labels = [label if label.endswith(".") else label + "." for label in labels]
|
||||
|
||||
def load_grounding_dino(model_path: Path):
|
||||
grounding_dino_pipeline = pipeline(
|
||||
model=str(model_path),
|
||||
task="zero-shot-object-detection",
|
||||
local_files_only=True,
|
||||
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
|
||||
# model, and figure out how to make it work in the pipeline.
|
||||
# torch_dtype=TorchDevice.choose_torch_dtype(),
|
||||
)
|
||||
assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
|
||||
return GroundingDinoPipeline(grounding_dino_pipeline)
|
||||
|
||||
with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector:
|
||||
with context.models.load_remote_model(
|
||||
source=GROUNDING_DINO_MODEL_ID, loader=GroundedSAMInvocation._load_grounding_dino
|
||||
) as detector:
|
||||
assert isinstance(detector, GroundingDinoPipeline)
|
||||
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
|
||||
|
||||
@ -119,26 +137,12 @@ class GroundedSAMInvocation(BaseInvocation):
|
||||
detection_results: list[DetectionResult],
|
||||
) -> list[DetectionResult]:
|
||||
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
|
||||
|
||||
def load_sam_model(model_path: Path):
|
||||
sam_model = AutoModelForMaskGeneration.from_pretrained(
|
||||
model_path,
|
||||
local_files_only=True,
|
||||
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
|
||||
# model, and figure out how to make it work in the pipeline.
|
||||
# torch_dtype=TorchDevice.choose_torch_dtype(),
|
||||
)
|
||||
assert isinstance(sam_model, SamModel)
|
||||
|
||||
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
|
||||
assert isinstance(sam_processor, SamProcessor)
|
||||
return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
|
||||
|
||||
with (
|
||||
context.models.load_remote_model(source=SEGMENT_ANYTHING_MODEL_ID, loader=load_sam_model) as sam_pipeline,
|
||||
context.models.load_remote_model(
|
||||
source=SEGMENT_ANYTHING_MODEL_ID, loader=GroundedSAMInvocation._load_sam_model
|
||||
) as sam_pipeline,
|
||||
):
|
||||
assert isinstance(sam_pipeline, SegmentAnythingModel)
|
||||
|
||||
masks = sam_pipeline.segment(image=image, detection_results=detection_results)
|
||||
|
||||
masks = self._to_numpy_masks(masks)
|
||||
|
Loading…
Reference in New Issue
Block a user