Optional refined model for Txt2Mask

Don't merge right now, just wanted to show the necessary changes
This commit is contained in:
spezialspezial 2022-11-02 00:33:46 +01:00 committed by GitHub
parent 2bdd738f03
commit 6c9a2761f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -36,6 +36,7 @@ from torchvision import transforms
CLIP_VERSION = 'ViT-B/16' CLIP_VERSION = 'ViT-B/16'
CLIPSEG_WEIGHTS = 'src/clipseg/weights/rd64-uni.pth' CLIPSEG_WEIGHTS = 'src/clipseg/weights/rd64-uni.pth'
CLIPSEG_WEIGHTS_REFINED = 'src/clipseg/weights/rd64-uni-refined.pth'
CLIPSEG_SIZE = 352 CLIPSEG_SIZE = 352
class SegmentedGrayscale(object): class SegmentedGrayscale(object):
@ -72,14 +73,14 @@ class Txt2Mask(object):
Create new Txt2Mask object. The optional device argument can be one of Create new Txt2Mask object. The optional device argument can be one of
'cuda', 'mps' or 'cpu'. 'cuda', 'mps' or 'cpu'.
''' '''
def __init__(self,device='cpu'): def __init__(self,device='cpu',refined=False):
print('>> Initializing clipseg model for text to mask inference') print('>> Initializing clipseg model for text to mask inference')
self.device = device self.device = device
self.model = CLIPDensePredT(version=CLIP_VERSION, reduce_dim=64, ) self.model = CLIPDensePredT(version=CLIP_VERSION, reduce_dim=64, complex_trans_conv=refined)
self.model.eval() self.model.eval()
# initially we keep everything in cpu to conserve space # initially we keep everything in cpu to conserve space
self.model.to('cpu') self.model.to('cpu')
self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False) self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS_REFINED if refined else CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False)
@torch.no_grad() @torch.no_grad()
def segment(self, image, prompt:str) -> SegmentedGrayscale: def segment(self, image, prompt:str) -> SegmentedGrayscale: