Optional refined model for Txt2Mask

Don't merge right now, just wanted to show the necessary changes
This commit is contained in:
spezialspezial 2022-11-02 00:33:46 +01:00 committed by GitHub
parent 2bdd738f03
commit 6c9a2761f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -36,6 +36,7 @@ from torchvision import transforms
CLIP_VERSION = 'ViT-B/16'
CLIPSEG_WEIGHTS = 'src/clipseg/weights/rd64-uni.pth'
CLIPSEG_WEIGHTS_REFINED = 'src/clipseg/weights/rd64-uni-refined.pth'
CLIPSEG_SIZE = 352
class SegmentedGrayscale(object):
@ -72,14 +73,14 @@ class Txt2Mask(object):
Create new Txt2Mask object. The optional device argument can be one of
'cuda', 'mps' or 'cpu'.
'''
def __init__(self,device='cpu'):
def __init__(self,device='cpu',refined=False):
print('>> Initializing clipseg model for text to mask inference')
self.device = device
self.model = CLIPDensePredT(version=CLIP_VERSION, reduce_dim=64, )
self.model = CLIPDensePredT(version=CLIP_VERSION, reduce_dim=64, complex_trans_conv=refined)
self.model.eval()
# initially we keep everything in cpu to conserve space
self.model.to('cpu')
self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False)
self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS_REFINED if refined else CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False)
@torch.no_grad()
def segment(self, image, prompt:str) -> SegmentedGrayscale: