clipseg library and environment in place

Lincoln Stein 2022-10-16 16:45:07 -04:00
parent c974c95e2b
commit 32122e0312
3 changed files with 122 additions and 11 deletions

@@ -37,4 +37,5 @@ dependencies:
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
- -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
- -e git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
- -e git+https://github.com/invoke-ai/clipseg.git#egg=clipseg
- -e .
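
The added line above pulls in the invoke-ai clipseg fork as an editable pip dependency. A minimal sanity check (not part of the commit, and assuming the conda environment has been re-created) is to confirm that the checkout's models package is importable, since the new txt2mask.py relies on it:

    # hypothetical sanity check -- not part of this commit
    # confirms the editable clipseg checkout exposes the model class used below
    from models.clipseg import CLIPDensePredT
    print('clipseg import OK:', CLIPDensePredT.__name__)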

ldm/invoke/txt2mask.py (new file, 90 lines added)

@@ -0,0 +1,90 @@
'''Makes available the Txt2Mask class, which assists in the automatic
assignment of masks via text prompt using clipseg.

Here is typical usage:

from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
from PIL import Image

txt2mask = Txt2Mask(self.device)
segmented = txt2mask.segment(Image.open('/path/to/img.png'), 'a bagel')

# this will return a grayscale Image of the segmented data
grayscale = segmented.to_grayscale()

# this will return a semi-transparent image in which the
# selected object(s) are opaque and the rest is at various
# levels of transparency
transparent = segmented.to_transparent()

# this will return a masked image suitable for use in inpainting:
mask = segmented.to_mask(threshold=0.5)

The threshold used in the call to to_mask() selects pixels for the mask
whose confidence exceeds the indicated value. Values range from 0.0 to
1.0; the higher the threshold, the more confident the algorithm must be
before a pixel is included. In limited testing, I have found that
values around 0.5 work well.
'''
import torch
import numpy as np
from models.clipseg import CLIPDensePredT
from einops import rearrange, repeat
from PIL import Image
from torchvision import transforms
CLIP_VERSION = 'ViT-B/16'
CLIPSEG_WEIGHTS = 'src/clipseg/weights/rd64-uni.pth'
class SegmentedGrayscale(object):
    def __init__(self, image:Image, heatmap:torch.Tensor):
        self.heatmap = heatmap
        self.image = image

    def to_grayscale(self)->Image:
        # scale the 0.0-1.0 heatmap into an 8-bit grayscale Image
        return Image.fromarray(np.uint8(self.heatmap*255))

    def to_mask(self, threshold:float=0.5)->Image:
        # threshold the heatmap into a hard 0/255 single-channel mask
        discrete_heatmap = self.heatmap.lt(threshold).int()
        return Image.fromarray(np.uint8(discrete_heatmap*255), mode='L')

    def to_transparent(self)->Image:
        # use the grayscale heatmap as the alpha channel of a copy of the image
        transparent_image = self.image.copy()
        transparent_image.putalpha(self.to_grayscale())
        return transparent_image
class Txt2Mask(object):
    '''
    Create new Txt2Mask object. The optional device argument can be one of
    'cuda', 'mps' or 'cpu'.
    '''
    def __init__(self, device='cpu'):
        print('>> Initializing clipseg model')
        self.model = CLIPDensePredT(version=CLIP_VERSION, reduce_dim=64)
        self.model.eval()
        self.model.to(device)
        self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device(device)), strict=False)
    @torch.no_grad()
    def segment(self, image:Image, prompt:str) -> SegmentedGrayscale:
        '''
        Given a prompt string such as "a bagel", tries to identify the object in the
        provided image and returns a SegmentedGrayscale object in which the brighter
        pixels indicate where the object is inferred to be.
        '''
        prompts = [prompt]   # right now we operate on just a single prompt at a time

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Resize((image.height, image.width)),   # Resize expects (height, width); must be multiple of 64...
        ])
        img = transform(image).unsqueeze(0)

        preds = self.model(img.repeat(len(prompts), 1, 1, 1), prompts)[0]
        heatmap = torch.sigmoid(preds[0][0]).cpu()
        return SegmentedGrayscale(image, heatmap)
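
For reference, a minimal usage sketch (not part of the commit) that exercises the class exactly as the module docstring describes; the image path, prompt, and output filenames are placeholders:

    # hypothetical driver script -- not part of this commit
    from PIL import Image
    from ldm.invoke.txt2mask import Txt2Mask

    txt2mask  = Txt2Mask(device='cpu')     # or 'cuda' / 'mps' if available
    segmented = txt2mask.segment(Image.open('photo.png'), 'a bagel')
    segmented.to_grayscale().save('heatmap.png')       # confidence heatmap
    segmented.to_transparent().save('overlay.png')     # original with heatmap as alpha
    segmented.to_mask(threshold=0.5).save('mask.png')  # hard mask for inpainting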

@@ -10,28 +10,31 @@ import sys
import transformers
import os
import warnings
import torch
import urllib.request
import zipfile
import traceback
transformers.logging.set_verbosity_error()
# this will preload the Bert tokenizer files
print('preloading bert tokenizer...', end='')
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
print('Loading bert tokenizer (ignore deprecation errors)...', end='')
with warnings.catch_warnings():
    warnings.filterwarnings('ignore', category=DeprecationWarning)
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
print('...success')
sys.stdout.flush()
# this will download requirements for Kornia
print('preloading Kornia requirements...', end='')
print('Loading Kornia requirements...', end='')
with warnings.catch_warnings():
    warnings.filterwarnings('ignore', category=DeprecationWarning)
    import kornia
print('...success')
version = 'openai/clip-vit-large-patch14'
print('preloading CLIP model...',end='')
sys.stdout.flush()
print('Loading CLIP model...',end='')
tokenizer = CLIPTokenizer.from_pretrained(version)
transformer = CLIPTextModel.from_pretrained(version)
print('...success')
@@ -61,7 +64,6 @@ if gfpgan:
        FaceRestoreHelper(1, det_model='retinaface_resnet50')
        print('...success')
    except Exception:
        import traceback
        print('Error loading ESRGAN:')
        print(traceback.format_exc())
@@ -89,13 +91,11 @@ if gfpgan:
            urllib.request.urlretrieve(model_url,model_dest)
        print('...success')
    except Exception:
        import traceback
        print('Error loading GFPGAN:')
        print(traceback.format_exc())

print('preloading CodeFormer model file...',end='')
try:
    import urllib.request
    model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
    model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
    if not os.path.exists(model_dest):
@@ -103,7 +103,27 @@ try:
        os.makedirs(os.path.dirname(model_dest), exist_ok=True)
        urllib.request.urlretrieve(model_url,model_dest)
except Exception:
    import traceback
    print('Error loading CodeFormer:')
    print(traceback.format_exc())
print('...success')

print('Loading clipseg model for text-based masking...',end='')
try:
    model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
    model_dest = 'src/clipseg/clipseg_weights.zip'
    if not os.path.exists(model_dest):
        os.makedirs(os.path.dirname(model_dest), exist_ok=True)
        urllib.request.urlretrieve(model_url,model_dest)
        with zipfile.ZipFile(model_dest,'r') as zip:
            zip.extractall('src/clipseg')
        os.rename('src/clipseg/clipseg_weights','src/clipseg/weights')
    from models.clipseg import CLIPDensePredT
    model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
    model.eval()
    model.load_state_dict(torch.load('src/clipseg/weights/rd64-uni-refined.pth'), strict=False)
except Exception:
    print('Error installing clipseg model:')
    print(traceback.format_exc())
print('...success')
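
A quick follow-up check (not part of the commit) can confirm the download and rename landed the weight files where both this preload step (rd64-uni-refined.pth) and txt2mask.py (rd64-uni.pth) expect to find them:

    # hypothetical post-install check -- not part of this commit
    import os
    for weights in ('src/clipseg/weights/rd64-uni.pth',
                    'src/clipseg/weights/rd64-uni-refined.pth'):
        status = 'found' if os.path.exists(weights) else 'MISSING'
        print(f'{weights}: {status}')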