txt2mask.py now tracking development again

This commit is contained in:
Damian at mba 2022-10-21 12:42:07 +02:00
parent 2bf9f1f0d8
commit e574a1574f

View File

@@ -29,9 +29,9 @@ work fine.
 import torch
 import numpy as np
-from clipseg_models.clipseg import CLIPDensePredT
+from models.clipseg import CLIPDensePredT
 from einops import rearrange, repeat
-from PIL import Image, ImageOps
+from PIL import Image
 from torchvision import transforms
 CLIP_VERSION = 'ViT-B/16'
@@ -50,14 +50,9 @@ class SegmentedGrayscale(object):
 discrete_heatmap = self.heatmap.lt(threshold).int()
 return self._rescale(Image.fromarray(np.uint8(discrete_heatmap*255),mode='L'))
-def to_transparent(self,invert:bool=False)->Image:
+def to_transparent(self)->Image:
 transparent_image = self.image.copy()
-gs = self.to_grayscale()
-# The following line looks like a bug, but isn't.
-# For img2img, we want the selected regions to be transparent,
-# but to_grayscale() returns the opposite.
-gs = ImageOps.invert(gs) if not invert else gs
-transparent_image.putalpha(gs)
+transparent_image.putalpha(self.to_grayscale())
 return transparent_image
 # unscales and uncrops the 352x352 heatmap so that it matches the image again
@@ -84,7 +79,7 @@ class Txt2Mask(object):
 self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False)
 @torch.no_grad()
-def segment(self, image, prompt:str) -> SegmentedGrayscale:
+def segment(self, image:Image, prompt:str) -> SegmentedGrayscale:
 '''
 Given a prompt string such as "a bagel", tries to identify the object in the
 provided image and returns a SegmentedGrayscale object in which the brighter
@@ -99,10 +94,6 @@ class Txt2Mask(object):
 transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)), # must be multiple of 64...
 ])
-if type(image) is str:
-image = Image.open(image).convert('RGB')
-image = ImageOps.exif_transpose(image)
 img = self._scale_and_crop(image)
 img = transform(img).unsqueeze(0)