Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
commit e574a1574f
parent 2bf9f1f0d8

    txt2mask.py now tracking development again
@@ -29,9 +29,9 @@ work fine.
 import torch
 import numpy as np
-from clipseg_models.clipseg import CLIPDensePredT
+from models.clipseg import CLIPDensePredT
 from einops import rearrange, repeat
-from PIL import Image, ImageOps
+from PIL import Image
 from torchvision import transforms
 
 CLIP_VERSION = 'ViT-B/16'
@@ -50,14 +50,9 @@ class SegmentedGrayscale(object):
         discrete_heatmap = self.heatmap.lt(threshold).int()
         return self._rescale(Image.fromarray(np.uint8(discrete_heatmap*255),mode='L'))
 
-    def to_transparent(self,invert:bool=False)->Image:
+    def to_transparent(self)->Image:
         transparent_image = self.image.copy()
-        gs = self.to_grayscale()
-        # The following line looks like a bug, but isn't.
-        # For img2img, we want the selected regions to be transparent,
-        # but to_grayscale() returns the opposite.
-        gs = ImageOps.invert(gs) if not invert else gs
-        transparent_image.putalpha(gs)
+        transparent_image.putalpha(self.to_grayscale())
         return transparent_image
 
     # unscales and uncrops the 352x352 heatmap so that it matches the image again
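Note that this hunk changes behavior as well as signature: the old to_transparent() inverted the grayscale mask by default (so that, per the deleted comment, the selected regions came out transparent for img2img), while the new version applies to_grayscale() directly as the alpha channel. A minimal sketch of how a caller could recover the old default behavior, assuming PIL is available and `segmented` is a SegmentedGrayscale returned by Txt2Mask.segment():

    from PIL import ImageOps

    # Rebuild the removed invert path by hand:
    # inverted grayscale as alpha makes the selected regions transparent.
    rgba = segmented.image.copy()
    rgba.putalpha(ImageOps.invert(segmented.to_grayscale()))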
@@ -84,7 +79,7 @@ class Txt2Mask(object):
         self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False)
 
     @torch.no_grad()
-    def segment(self, image, prompt:str) -> SegmentedGrayscale:
+    def segment(self, image:Image, prompt:str) -> SegmentedGrayscale:
         '''
         Given a prompt string such as "a bagel", tries to identify the object in the
         provided image and returns a SegmentedGrayscale object in which the brighter
@@ -99,10 +94,6 @@ class Txt2Mask(object):
             transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)), # must be multiple of 64...
         ])
 
-        if type(image) is str:
-            image = Image.open(image).convert('RGB')
-
-        image = ImageOps.exif_transpose(image)
         img = self._scale_and_crop(image)
         img = transform(img).unsqueeze(0)
 
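Taken together, the segment() hunks narrow the accepted input: the method no longer opens path strings or applies EXIF orientation itself, so callers must pass a ready PIL Image. A minimal usage sketch under that assumption (the module path and constructor arguments are assumed, not shown in this diff):

    from PIL import Image, ImageOps
    from ldm.invoke.txt2mask import Txt2Mask  # import path assumed

    # Loading and orienting the image is now the caller's job.
    image = Image.open('photo.png').convert('RGB')
    image = ImageOps.exif_transpose(image)

    txt2mask = Txt2Mask()  # constructor arguments not shown in this diff
    segmented = txt2mask.segment(image, 'a bagel')
    mask = segmented.to_grayscale()    # brighter pixels = stronger match
    rgba = segmented.to_transparent()  # grayscale mask becomes the alpha channel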