Merge branch 'patch-9' of https://github.com/spezialspezial/stable-diffusion into spezialspezial-patch-9

This commit is contained in:
Lincoln Stein 2022-11-02 18:24:55 -04:00
commit 895c47fd11

View File

@ -36,6 +36,7 @@ from torchvision import transforms
CLIP_VERSION = 'ViT-B/16'
CLIPSEG_WEIGHTS = 'src/clipseg/weights/rd64-uni.pth'
CLIPSEG_WEIGHTS_REFINED = 'src/clipseg/weights/rd64-uni-refined.pth'
CLIPSEG_SIZE = 352
class SegmentedGrayscale(object):
@ -72,14 +73,14 @@ class Txt2Mask(object):
Create new Txt2Mask object. The optional device argument can be one of
'cuda', 'mps' or 'cpu'.
'''
def __init__(self,device='cpu'):
def __init__(self,device='cpu',refined=False):
print('>> Initializing clipseg model for text to mask inference')
self.device = device
self.model = CLIPDensePredT(version=CLIP_VERSION, reduce_dim=64, )
self.model = CLIPDensePredT(version=CLIP_VERSION, reduce_dim=64, complex_trans_conv=refined)
self.model.eval()
# initially we keep everything in cpu to conserve space
self.model.to('cpu')
self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False)
self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS_REFINED if refined else CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False)
@torch.no_grad()
def segment(self, image, prompt:str) -> SegmentedGrayscale: