From 6c9a2761f5a5f5b05b05fb1e353756acfbedc6c5 Mon Sep 17 00:00:00 2001
From: spezialspezial <75758219+spezialspezial@users.noreply.github.com>
Date: Wed, 2 Nov 2022 00:33:46 +0100
Subject: [PATCH] Optional refined model for Txt2Mask

Don't merge right now; just wanted to show the necessary changes.

---
 ldm/invoke/txt2mask.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/ldm/invoke/txt2mask.py b/ldm/invoke/txt2mask.py
index a5deec277a..4e77f02857 100644
--- a/ldm/invoke/txt2mask.py
+++ b/ldm/invoke/txt2mask.py
@@ -36,6 +36,7 @@ from torchvision import transforms
 
 CLIP_VERSION = 'ViT-B/16'
 CLIPSEG_WEIGHTS = 'src/clipseg/weights/rd64-uni.pth'
+CLIPSEG_WEIGHTS_REFINED = 'src/clipseg/weights/rd64-uni-refined.pth'
 CLIPSEG_SIZE = 352
 
 class SegmentedGrayscale(object):
@@ -72,14 +73,14 @@ class Txt2Mask(object):
     Create new Txt2Mask object. The optional device argument can be one of
     'cuda', 'mps' or 'cpu'.
     '''
-    def __init__(self,device='cpu'):
+    def __init__(self,device='cpu',refined=False):
         print('>> Initializing clipseg model for text to mask inference')
         self.device = device
-        self.model = CLIPDensePredT(version=CLIP_VERSION, reduce_dim=64, )
+        self.model = CLIPDensePredT(version=CLIP_VERSION, reduce_dim=64, complex_trans_conv=refined)
         self.model.eval()
         # initially we keep everything in cpu to conserve space
         self.model.to('cpu')
-        self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False)
+        self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS_REFINED if refined else CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False)
 
     @torch.no_grad()
     def segment(self, image, prompt:str) -> SegmentedGrayscale:
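
Note for reviewers (not part of the patch): a minimal sketch of how the new
refined flag would be exercised from calling code, assuming the Txt2Mask API
as modified above. The image path and prompt are placeholder values.

    from PIL import Image
    from ldm.invoke.txt2mask import Txt2Mask

    # refined=True loads rd64-uni-refined.pth and enables the
    # complex_trans_conv decoder variant in CLIPDensePredT
    txt2mask = Txt2Mask(device='cpu', refined=True)

    image = Image.open('example.png')                 # placeholder input
    result = txt2mask.segment(image, 'a red hat')     # -> SegmentedGrayscale

The default refined=False keeps the current rd64-uni.pth behavior, so
existing callers are unaffected.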