mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
clipseg library and environment in place
This commit is contained in:
parent
c974c95e2b
commit
32122e0312
@ -37,4 +37,5 @@ dependencies:
|
|||||||
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
|
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
|
||||||
- -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
|
- -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
|
||||||
- -e git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
|
- -e git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
|
||||||
|
- -e git+https://github.com/invoke-ai/clipseg.git#egg=clipseg
|
||||||
- -e .
|
- -e .
|
||||||
|
90
ldm/invoke/txt2mask.py
Normal file
90
ldm/invoke/txt2mask.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
'''Makes available the Txt2Mask class, which assists in the automatic
|
||||||
|
assignment of masks via text prompt using clipseg.
|
||||||
|
|
||||||
|
Here is typical usage:
|
||||||
|
|
||||||
|
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
txt2mask = Txt2Mask(self.device)
|
||||||
|
segmented = txt2mask.segment(Image.open('/path/to/img.png'),'a bagel')
|
||||||
|
|
||||||
|
# this will return a grayscale Image of the segmented data
|
||||||
|
grayscale = segmented.to_grayscale()
|
||||||
|
|
||||||
|
# this will return a semi-transparent image in which the
|
||||||
|
# selected object(s) are opaque and the rest is at various
|
||||||
|
# levels of transparency
|
||||||
|
transparent = segmented.to_transparent()
|
||||||
|
|
||||||
|
# this will return a masked image suitable for use in inpainting:
|
||||||
|
mask = segmented.to_mask(threshold=0.5)
|
||||||
|
|
||||||
|
The threshold used in the call to to_mask() selects pixels for use in
|
||||||
|
the mask that exceed the indicated confidence threshold. Values range
|
||||||
|
from 0.0 to 1.0. The higher the threshold, the more confident the
|
||||||
|
algorithm is. In limited testing, I have found that values around 0.5
|
||||||
|
work fine.
|
||||||
|
'''
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from models.clipseg import CLIPDensePredT
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
CLIP_VERSION = 'ViT-B/16'
|
||||||
|
CLIPSEG_WEIGHTS = 'src/clipseg/weights/rd64-uni.pth'
|
||||||
|
|
||||||
|
class SegmentedGrayscale(object):
|
||||||
|
def __init__(self, image:Image, heatmap:torch.Tensor):
|
||||||
|
self.heatmap = heatmap
|
||||||
|
self.image = image
|
||||||
|
|
||||||
|
def to_grayscale(self)->Image:
|
||||||
|
return Image.fromarray(np.uint8(self.heatmap*255))
|
||||||
|
|
||||||
|
def to_mask(self,threshold:float=0.5)->Image:
|
||||||
|
discrete_heatmap = self.heatmap.lt(threshold).int()
|
||||||
|
return Image.fromarray(np.uint8(discrete_heatmap*255),mode='L')
|
||||||
|
|
||||||
|
def to_transparent(self)->Image:
|
||||||
|
transparent_image = self.image.copy()
|
||||||
|
transparent_image.putalpha(self.to_image)
|
||||||
|
return transparent_image
|
||||||
|
|
||||||
|
class Txt2Mask(object):
|
||||||
|
'''
|
||||||
|
Create new Txt2Mask object. The optional device argument can be one of
|
||||||
|
'cuda', 'mps' or 'cpu'.
|
||||||
|
'''
|
||||||
|
def __init__(self,device='cpu'):
|
||||||
|
print('>> Initializing clipseg model')
|
||||||
|
self.model = CLIPDensePredT(version=CLIP_VERSION, reduce_dim=64, )
|
||||||
|
self.model.eval()
|
||||||
|
self.model.to(device)
|
||||||
|
self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device(device)), strict=False)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def segment(self, image:Image, prompt:str) -> SegmentedGrayscale:
|
||||||
|
'''
|
||||||
|
Given a prompt string such as "a bagel", tries to identify the object in the
|
||||||
|
provided image and returns a SegmentedGrayscale object in which the brighter
|
||||||
|
pixels indicate where the object is inferred to be.
|
||||||
|
'''
|
||||||
|
prompts = [prompt] # right now we operate on just a single prompt at a time
|
||||||
|
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
|
transforms.Resize((image.width, image.height)), # must be multiple of 64...
|
||||||
|
])
|
||||||
|
img = transform(image).unsqueeze(0)
|
||||||
|
preds = self.model(img.repeat(len(prompts),1,1,1), prompts)[0]
|
||||||
|
heatmap = torch.sigmoid(preds[0][0]).cpu()
|
||||||
|
return SegmentedGrayscale(image, heatmap)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -10,28 +10,31 @@ import sys
|
|||||||
import transformers
|
import transformers
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
|
import torch
|
||||||
import urllib.request
|
import urllib.request
|
||||||
|
import zipfile
|
||||||
|
import traceback
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
# this will preload the Bert tokenizer fles
|
# this will preload the Bert tokenizer fles
|
||||||
print('preloading bert tokenizer...', end='')
|
print('Loading bert tokenizer (ignore deprecation errors)...', end='')
|
||||||
|
with warnings.catch_warnings():
|
||||||
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||||
|
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
||||||
print('...success')
|
print('...success')
|
||||||
|
sys.stdout.flush()
|
||||||
|
|
||||||
# this will download requirements for Kornia
|
# this will download requirements for Kornia
|
||||||
print('preloading Kornia requirements...', end='')
|
print('Loading Kornia requirements...', end='')
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||||
import kornia
|
import kornia
|
||||||
print('...success')
|
print('...success')
|
||||||
|
|
||||||
version = 'openai/clip-vit-large-patch14'
|
version = 'openai/clip-vit-large-patch14'
|
||||||
|
|
||||||
print('preloading CLIP model...',end='')
|
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
print('Loading CLIP model...',end='')
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(version)
|
tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||||
transformer = CLIPTextModel.from_pretrained(version)
|
transformer = CLIPTextModel.from_pretrained(version)
|
||||||
print('...success')
|
print('...success')
|
||||||
@ -61,7 +64,6 @@ if gfpgan:
|
|||||||
FaceRestoreHelper(1, det_model='retinaface_resnet50')
|
FaceRestoreHelper(1, det_model='retinaface_resnet50')
|
||||||
print('...success')
|
print('...success')
|
||||||
except Exception:
|
except Exception:
|
||||||
import traceback
|
|
||||||
print('Error loading ESRGAN:')
|
print('Error loading ESRGAN:')
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|
||||||
@ -89,13 +91,11 @@ if gfpgan:
|
|||||||
urllib.request.urlretrieve(model_url,model_dest)
|
urllib.request.urlretrieve(model_url,model_dest)
|
||||||
print('...success')
|
print('...success')
|
||||||
except Exception:
|
except Exception:
|
||||||
import traceback
|
|
||||||
print('Error loading GFPGAN:')
|
print('Error loading GFPGAN:')
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|
||||||
print('preloading CodeFormer model file...',end='')
|
print('preloading CodeFormer model file...',end='')
|
||||||
try:
|
try:
|
||||||
import urllib.request
|
|
||||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||||
model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
|
model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
|
||||||
if not os.path.exists(model_dest):
|
if not os.path.exists(model_dest):
|
||||||
@ -103,7 +103,27 @@ try:
|
|||||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||||
urllib.request.urlretrieve(model_url,model_dest)
|
urllib.request.urlretrieve(model_url,model_dest)
|
||||||
except Exception:
|
except Exception:
|
||||||
import traceback
|
|
||||||
print('Error loading CodeFormer:')
|
print('Error loading CodeFormer:')
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
print('...success')
|
print('...success')
|
||||||
|
|
||||||
|
print('Loading clipseq model for text-based masking...',end='')
|
||||||
|
try:
|
||||||
|
model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
|
||||||
|
model_dest = 'src/clipseg/clipseg_weights.zip'
|
||||||
|
if not os.path.exists(model_dest):
|
||||||
|
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||||
|
urllib.request.urlretrieve(model_url,model_dest)
|
||||||
|
with zipfile.ZipFile(model_dest,'r') as zip:
|
||||||
|
zip.extractall('src/clipseg')
|
||||||
|
os.rename('src/clipseg/clipseg_weights','src/clipseg/weights')
|
||||||
|
from models.clipseg import CLIPDensePredT
|
||||||
|
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, )
|
||||||
|
model.eval()
|
||||||
|
model.load_state_dict(torch.load('src/clipseg/weights/rd64-uni-refined.pth'), strict=False)
|
||||||
|
except Exception:
|
||||||
|
print('Error installing clipseg model:')
|
||||||
|
print(traceback.format_exc())
|
||||||
|
print('...success')
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user