mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Use staticmethods rather than inner functions for the Grounding DINO and SAM model loaders.
This commit is contained in:
parent
0a7048f650
commit
e206890e25
@ -84,6 +84,34 @@ class GroundedSAMInvocation(BaseInvocation):
|
|||||||
image_dto = context.images.save(image=mask_pil)
|
image_dto = context.images.save(image=mask_pil)
|
||||||
return ImageOutput.build(image_dto)
|
return ImageOutput.build(image_dto)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load_grounding_dino(model_path: Path):
|
||||||
|
grounding_dino_pipeline = pipeline(
|
||||||
|
model=str(model_path),
|
||||||
|
task="zero-shot-object-detection",
|
||||||
|
local_files_only=True,
|
||||||
|
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
|
||||||
|
# model, and figure out how to make it work in the pipeline.
|
||||||
|
# torch_dtype=TorchDevice.choose_torch_dtype(),
|
||||||
|
)
|
||||||
|
assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
|
||||||
|
return GroundingDinoPipeline(grounding_dino_pipeline)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load_sam_model(model_path: Path):
|
||||||
|
sam_model = AutoModelForMaskGeneration.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
local_files_only=True,
|
||||||
|
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
|
||||||
|
# model, and figure out how to make it work in the pipeline.
|
||||||
|
# torch_dtype=TorchDevice.choose_torch_dtype(),
|
||||||
|
)
|
||||||
|
assert isinstance(sam_model, SamModel)
|
||||||
|
|
||||||
|
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
|
||||||
|
assert isinstance(sam_processor, SamProcessor)
|
||||||
|
return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
|
||||||
|
|
||||||
def _detect(
|
def _detect(
|
||||||
self,
|
self,
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
@ -96,19 +124,9 @@ class GroundedSAMInvocation(BaseInvocation):
|
|||||||
# actually makes a difference.
|
# actually makes a difference.
|
||||||
labels = [label if label.endswith(".") else label + "." for label in labels]
|
labels = [label if label.endswith(".") else label + "." for label in labels]
|
||||||
|
|
||||||
def load_grounding_dino(model_path: Path):
|
with context.models.load_remote_model(
|
||||||
grounding_dino_pipeline = pipeline(
|
source=GROUNDING_DINO_MODEL_ID, loader=GroundedSAMInvocation._load_grounding_dino
|
||||||
model=str(model_path),
|
) as detector:
|
||||||
task="zero-shot-object-detection",
|
|
||||||
local_files_only=True,
|
|
||||||
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
|
|
||||||
# model, and figure out how to make it work in the pipeline.
|
|
||||||
# torch_dtype=TorchDevice.choose_torch_dtype(),
|
|
||||||
)
|
|
||||||
assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
|
|
||||||
return GroundingDinoPipeline(grounding_dino_pipeline)
|
|
||||||
|
|
||||||
with context.models.load_remote_model(source=GROUNDING_DINO_MODEL_ID, loader=load_grounding_dino) as detector:
|
|
||||||
assert isinstance(detector, GroundingDinoPipeline)
|
assert isinstance(detector, GroundingDinoPipeline)
|
||||||
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
|
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
|
||||||
|
|
||||||
@ -119,26 +137,12 @@ class GroundedSAMInvocation(BaseInvocation):
|
|||||||
detection_results: list[DetectionResult],
|
detection_results: list[DetectionResult],
|
||||||
) -> list[DetectionResult]:
|
) -> list[DetectionResult]:
|
||||||
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
|
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
|
||||||
|
|
||||||
def load_sam_model(model_path: Path):
|
|
||||||
sam_model = AutoModelForMaskGeneration.from_pretrained(
|
|
||||||
model_path,
|
|
||||||
local_files_only=True,
|
|
||||||
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
|
|
||||||
# model, and figure out how to make it work in the pipeline.
|
|
||||||
# torch_dtype=TorchDevice.choose_torch_dtype(),
|
|
||||||
)
|
|
||||||
assert isinstance(sam_model, SamModel)
|
|
||||||
|
|
||||||
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
|
|
||||||
assert isinstance(sam_processor, SamProcessor)
|
|
||||||
return SegmentAnythingModel(sam_model=sam_model, sam_processor=sam_processor)
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
context.models.load_remote_model(source=SEGMENT_ANYTHING_MODEL_ID, loader=load_sam_model) as sam_pipeline,
|
context.models.load_remote_model(
|
||||||
|
source=SEGMENT_ANYTHING_MODEL_ID, loader=GroundedSAMInvocation._load_sam_model
|
||||||
|
) as sam_pipeline,
|
||||||
):
|
):
|
||||||
assert isinstance(sam_pipeline, SegmentAnythingModel)
|
assert isinstance(sam_pipeline, SegmentAnythingModel)
|
||||||
|
|
||||||
masks = sam_pipeline.segment(image=image, detection_results=detection_results)
|
masks = sam_pipeline.segment(image=image, detection_results=detection_results)
|
||||||
|
|
||||||
masks = self._to_numpy_masks(masks)
|
masks = self._to_numpy_masks(masks)
|
||||||
|
Loading…
Reference in New Issue
Block a user