mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
txt2mask.py now tracking development again
This commit is contained in:
parent
2bf9f1f0d8
commit
e574a1574f
@ -29,9 +29,9 @@ work fine.
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from clipseg_models.clipseg import CLIPDensePredT
|
||||
from models.clipseg import CLIPDensePredT
|
||||
from einops import rearrange, repeat
|
||||
from PIL import Image, ImageOps
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
CLIP_VERSION = 'ViT-B/16'
|
||||
@ -50,14 +50,9 @@ class SegmentedGrayscale(object):
|
||||
discrete_heatmap = self.heatmap.lt(threshold).int()
|
||||
return self._rescale(Image.fromarray(np.uint8(discrete_heatmap*255),mode='L'))
|
||||
|
||||
def to_transparent(self,invert:bool=False)->Image:
|
||||
def to_transparent(self)->Image:
|
||||
transparent_image = self.image.copy()
|
||||
gs = self.to_grayscale()
|
||||
# The following line looks like a bug, but isn't.
|
||||
# For img2img, we want the selected regions to be transparent,
|
||||
# but to_grayscale() returns the opposite.
|
||||
gs = ImageOps.invert(gs) if not invert else gs
|
||||
transparent_image.putalpha(gs)
|
||||
transparent_image.putalpha(self.to_grayscale())
|
||||
return transparent_image
|
||||
|
||||
# unscales and uncrops the 352x352 heatmap so that it matches the image again
|
||||
@ -84,7 +79,7 @@ class Txt2Mask(object):
|
||||
self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def segment(self, image, prompt:str) -> SegmentedGrayscale:
|
||||
def segment(self, image:Image, prompt:str) -> SegmentedGrayscale:
|
||||
'''
|
||||
Given a prompt string such as "a bagel", tries to identify the object in the
|
||||
provided image and returns a SegmentedGrayscale object in which the brighter
|
||||
@ -99,10 +94,6 @@ class Txt2Mask(object):
|
||||
transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)), # must be multiple of 64...
|
||||
])
|
||||
|
||||
if type(image) is str:
|
||||
image = Image.open(image).convert('RGB')
|
||||
|
||||
image = ImageOps.exif_transpose(image)
|
||||
img = self._scale_and_crop(image)
|
||||
img = transform(img).unsqueeze(0)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user