txt2mask.py now tracking development again

This commit is contained in:
Damian at mba 2022-10-21 12:42:07 +02:00
parent 2bf9f1f0d8
commit e574a1574f

View File

@ -29,9 +29,9 @@ work fine.
import torch import torch
import numpy as np import numpy as np
from clipseg_models.clipseg import CLIPDensePredT from models.clipseg import CLIPDensePredT
from einops import rearrange, repeat from einops import rearrange, repeat
from PIL import Image, ImageOps from PIL import Image
from torchvision import transforms from torchvision import transforms
CLIP_VERSION = 'ViT-B/16' CLIP_VERSION = 'ViT-B/16'
@ -50,14 +50,9 @@ class SegmentedGrayscale(object):
discrete_heatmap = self.heatmap.lt(threshold).int() discrete_heatmap = self.heatmap.lt(threshold).int()
return self._rescale(Image.fromarray(np.uint8(discrete_heatmap*255),mode='L')) return self._rescale(Image.fromarray(np.uint8(discrete_heatmap*255),mode='L'))
def to_transparent(self,invert:bool=False)->Image: def to_transparent(self)->Image:
transparent_image = self.image.copy() transparent_image = self.image.copy()
gs = self.to_grayscale() transparent_image.putalpha(self.to_grayscale())
# The following line looks like a bug, but isn't.
# For img2img, we want the selected regions to be transparent,
# but to_grayscale() returns the opposite.
gs = ImageOps.invert(gs) if not invert else gs
transparent_image.putalpha(gs)
return transparent_image return transparent_image
# unscales and uncrops the 352x352 heatmap so that it matches the image again # unscales and uncrops the 352x352 heatmap so that it matches the image again
@ -84,7 +79,7 @@ class Txt2Mask(object):
self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False) self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False)
@torch.no_grad() @torch.no_grad()
def segment(self, image, prompt:str) -> SegmentedGrayscale: def segment(self, image:Image, prompt:str) -> SegmentedGrayscale:
''' '''
Given a prompt string such as "a bagel", tries to identify the object in the Given a prompt string such as "a bagel", tries to identify the object in the
provided image and returns a SegmentedGrayscale object in which the brighter provided image and returns a SegmentedGrayscale object in which the brighter
@ -99,10 +94,6 @@ class Txt2Mask(object):
transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)), # must be multiple of 64... transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)), # must be multiple of 64...
]) ])
if type(image) is str:
image = Image.open(image).convert('RGB')
image = ImageOps.exif_transpose(image)
img = self._scale_and_crop(image) img = self._scale_and_crop(image)
img = transform(img).unsqueeze(0) img = transform(img).unsqueeze(0)