mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
28 lines
968 B
Python
28 lines
968 B
Python
|
from typing import Optional
|
||
|
|
||
|
import torch
|
||
|
from transformers.pipelines import ZeroShotObjectDetectionPipeline
|
||
|
|
||
|
|
||
|
class GroundingDinoPipeline:
    """Thin adapter around a ZeroShotObjectDetectionPipeline.

    It adds the `to()` and `calc_size()` hooks so that the wrapped pipeline is compatible with the model manager's
    memory management system. Calling the adapter simply calls the wrapped pipeline.
    """

    def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
        # The wrapped pipeline; every other method operates on this object.
        self._pipeline = pipeline

    def __call__(self, *args, **kwargs):
        """Forward all positional and keyword arguments to the wrapped pipeline and return its result unchanged."""
        wrapped = self._pipeline
        return wrapped(*args, **kwargs)

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "GroundingDinoPipeline":
        """Call the wrapped model's `to()` with the given device and dtype.

        Args:
            device: Passed through as the `device` keyword of the model's `to()`. Defaults to None.
            dtype: Passed through as the `dtype` keyword of the model's `to()`. Defaults to None.

        Returns:
            This wrapper, so that calls can be chained.
        """
        wrapped = self._pipeline
        model = wrapped.model
        # The return value is ignored, as in the original code (presumably the model is updated in place).
        model.to(device=device, dtype=dtype)
        # Copy the model's device onto the pipeline so the two do not disagree after the move.
        wrapped.device = model.device
        return self

    def calc_size(self) -> int:
        """Return the size of the wrapped pipeline's model, as computed by `calc_module_size`."""
        # HACK(ryand): Fix the circular import issue.
        from invokeai.backend.model_manager.load.model_util import calc_module_size

        model = self._pipeline.model
        return calc_module_size(model)
|