From 57bff2a66364ceb606cd0eee50d4257951d1d4e9 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sun, 16 Oct 2022 16:45:07 -0400
Subject: [PATCH] clipseg library and environment in place

---
 environment.yml           |  1 +
 ldm/invoke/txt2mask.py    | 90 +++++++++++++++++++++++++++++++++++++++
 scripts/preload_models.py | 42 +++++++++++++-----
 3 files changed, 122 insertions(+), 11 deletions(-)
 create mode 100644 ldm/invoke/txt2mask.py

diff --git a/environment.yml b/environment.yml
index f387e722c3..14e599fa20 100644
--- a/environment.yml
+++ b/environment.yml
@@ -37,4 +37,5 @@ dependencies:
     - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
     - -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
     - -e git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
+    - -e git+https://github.com/invoke-ai/clipseg.git#egg=clipseg
     - -e .
diff --git a/ldm/invoke/txt2mask.py b/ldm/invoke/txt2mask.py
new file mode 100644
index 0000000000..3e2b84d591
--- /dev/null
+++ b/ldm/invoke/txt2mask.py
@@ -0,0 +1,90 @@
+'''Makes available the Txt2Mask class, which assists in the automatic
+assignment of masks via text prompt using clipseg.
+
+Here is typical usage:
+
+    from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
+    from PIL import Image
+
+    txt2mask = Txt2Mask(self.device)
+    segmented = txt2mask.segment(Image.open('/path/to/img.png'), 'a bagel')
+
+    # this will return a grayscale Image of the segmented data
+    grayscale = segmented.to_grayscale()
+
+    # this will return a semi-transparent image in which the
+    # selected object(s) are opaque and the rest is at various
+    # levels of transparency
+    transparent = segmented.to_transparent()
+
+    # this will return a masked image suitable for use in inpainting:
+    mask = segmented.to_mask(threshold=0.5)
+
+The threshold used in the call to to_mask() selects pixels for use in
+the mask whose confidence exceeds the indicated value. Values range
+from 0.0 to 1.0; the higher the threshold, the more confident the
+algorithm must be before a pixel is selected. In limited testing, I
+have found that values around 0.5 work fine.
+'''
+
+import torch
+import numpy as np
+from models.clipseg import CLIPDensePredT
+from einops import rearrange, repeat
+from PIL import Image
+from torchvision import transforms
+
+CLIP_VERSION = 'ViT-B/16'
+CLIPSEG_WEIGHTS = 'src/clipseg/weights/rd64-uni.pth'
+
+class SegmentedGrayscale(object):
+    def __init__(self, image:Image, heatmap:torch.Tensor):
+        self.heatmap = heatmap
+        self.image = image
+
+    def to_grayscale(self)->Image:
+        return Image.fromarray(np.uint8(self.heatmap*255))
+
+    def to_mask(self, threshold:float=0.5)->Image:
+        discrete_heatmap = self.heatmap.lt(threshold).int()
+        return Image.fromarray(np.uint8(discrete_heatmap*255), mode='L')
+
+    def to_transparent(self)->Image:
+        transparent_image = self.image.copy()
+        transparent_image.putalpha(self.to_grayscale())
+        return transparent_image
+
+class Txt2Mask(object):
+    '''
+    Create new Txt2Mask object. The optional device argument can be one of
+    'cuda', 'mps' or 'cpu'.
+    '''
+    def __init__(self, device='cpu'):
+        print('>> Initializing clipseg model')
+        self.model = CLIPDensePredT(version=CLIP_VERSION, reduce_dim=64)
+        self.model.eval()
+        self.model.to(device)
+        self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device(device)), strict=False)
+
+    @torch.no_grad()
+    def segment(self, image:Image, prompt:str) -> SegmentedGrayscale:
+        '''
+        Given a prompt string such as "a bagel", tries to identify the object in the
+        provided image and returns a SegmentedGrayscale object in which the brighter
+        pixels indicate where the object is inferred to be.
+        '''
+        prompts = [prompt]  # right now we operate on just a single prompt at a time
+
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+            transforms.Resize((image.height, image.width)),  # must be multiple of 64...
+        ])
+        img = transform(image).unsqueeze(0)
+        preds = self.model(img.repeat(len(prompts),1,1,1), prompts)[0]
+        heatmap = torch.sigmoid(preds[0][0]).cpu()
+        return SegmentedGrayscale(image, heatmap)
+
+
+
+
diff --git a/scripts/preload_models.py b/scripts/preload_models.py
index db484517db..2ef344f8c3 100644
--- a/scripts/preload_models.py
+++ b/scripts/preload_models.py
@@ -10,28 +10,31 @@ import sys
 import transformers
 import os
 import warnings
+import torch
 import urllib.request
+import zipfile
+import traceback
 
 transformers.logging.set_verbosity_error()
 
 # this will preload the Bert tokenizer fles
-print('preloading bert tokenizer...', end='')
-
-tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
+print('Loading bert tokenizer (ignore deprecation errors)...', end='')
+with warnings.catch_warnings():
+    warnings.filterwarnings('ignore', category=DeprecationWarning)
+    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
 print('...success')
+sys.stdout.flush()
 
 # this will download requirements for Kornia
-print('preloading Kornia requirements...', end='')
+print('Loading Kornia requirements...', end='')
 with warnings.catch_warnings():
     warnings.filterwarnings('ignore', category=DeprecationWarning)
     import kornia
 print('...success')
 
 version = 'openai/clip-vit-large-patch14'
-
-print('preloading CLIP model...',end='')
 sys.stdout.flush()
-
+print('Loading CLIP model...',end='')
 tokenizer = CLIPTokenizer.from_pretrained(version)
 transformer = CLIPTextModel.from_pretrained(version)
 print('...success')
@@ -61,7 +64,6 @@ if gfpgan:
         FaceRestoreHelper(1, det_model='retinaface_resnet50')
         print('...success')
     except Exception:
-        import traceback
         print('Error loading ESRGAN:')
         print(traceback.format_exc())
 
@@ -89,13 +91,11 @@ if gfpgan:
             urllib.request.urlretrieve(model_url,model_dest)
         print('...success')
     except Exception:
-        import traceback
         print('Error loading GFPGAN:')
         print(traceback.format_exc())
 
 print('preloading CodeFormer model file...',end='')
 try:
-    import urllib.request
     model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
     model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
     if not os.path.exists(model_dest):
@@ -103,7 +103,27 @@ try:
         os.makedirs(os.path.dirname(model_dest), exist_ok=True)
         urllib.request.urlretrieve(model_url,model_dest)
 except Exception:
-    import traceback
     print('Error loading CodeFormer:')
     print(traceback.format_exc())
 print('...success')
+
+print('Loading clipseg model for text-based masking...',end='')
+try:
+    model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
+    model_dest = 'src/clipseg/clipseg_weights.zip'
+    if not os.path.exists(model_dest):
+        os.makedirs(os.path.dirname(model_dest), exist_ok=True)
+        urllib.request.urlretrieve(model_url,model_dest)
+        with zipfile.ZipFile(model_dest,'r') as zip:
+            zip.extractall('src/clipseg')
+        os.rename('src/clipseg/clipseg_weights','src/clipseg/weights')
+        from models.clipseg import CLIPDensePredT
+        model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
+        model.eval()
+        model.load_state_dict(torch.load('src/clipseg/weights/rd64-uni-refined.pth'), strict=False)
+except Exception:
+    print('Error installing clipseg model:')
+    print(traceback.format_exc())
+print('...success')
+
+
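
Usage sketch (not part of the patch): the snippet below exercises the new
Txt2Mask API end to end. It assumes the clipseg dependency added to
environment.yml is installed and that the weights have already been fetched
by scripts/preload_models.py; the image path, prompt, output locations and
threshold are illustrative only.

    import torch
    from PIL import Image
    from ldm.invoke.txt2mask import Txt2Mask

    # pick whichever device is available; Txt2Mask defaults to 'cpu'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    txt2mask = Txt2Mask(device)

    # segment() wraps the clipseg confidence heatmap in a SegmentedGrayscale
    image = Image.open('outputs/example.png').convert('RGB')
    segmented = txt2mask.segment(image, 'a bagel')

    # brighter pixels mean higher confidence that the prompt was matched
    segmented.to_grayscale().save('outputs/example_heatmap.png')

    # binary mask for inpainting; ~0.5 worked in the author's limited testing
    segmented.to_mask(threshold=0.5).save('outputs/example_mask.png')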