reorder mask creation and scaling to avoid deleting colors under transparency

Lincoln Stein 2022-10-03 21:08:32 -04:00
parent 2c8806341f
commit feb405f19a
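In essence, the change threads width, height and fit through both _create_init_mask and _create_init_image and calls the mask-creation step before the init image is converted to RGB and rescaled, so the mask is derived while the original alpha channel (and the colors sitting under it) is still intact. A rough standalone sketch of that ordering, using illustrative names rather than the exact generate.py API:

```python
from PIL import Image

def prepare_init_and_mask(image: Image.Image, width: int, height: int):
    """Illustrative ordering only: derive the mask from the alpha
    channel first, then convert and scale the RGB init image."""
    init_mask = None
    if image.mode == 'RGBA':
        # 1. Build the mask while the original alpha channel (and the
        #    colors underneath it) are still untouched.
        alpha = image.getchannel('A')
        init_mask = alpha.point(lambda a: 255 if a < 128 else 0)  # mask polarity is arbitrary for this sketch
        init_mask = init_mask.resize((width, height), Image.Resampling.NEAREST)

    # 2. Only afterwards drop the alpha channel and resize the init image.
    init_image = image.convert('RGB').resize((width, height), Image.Resampling.LANCZOS)
    return init_image, init_mask
```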


@@ -19,6 +19,7 @@ import cv2
import skimage
from omegaconf import OmegaConf
from ldm.dream.generator.base import downsampling
from PIL import Image, ImageOps
from torch import nn
from pytorch_lightning import seed_everything, logging
@@ -612,11 +613,7 @@ class Generate:
img,
width,
height,
fit=fit
) # this returns an Image
if out_direction:
image = self._create_outpaint_image(image, out_direction)
init_image = self._create_init_image(image) # this returns a torch tensor
) # this returns an Image
# if image has a transparent area and no mask was provided, then try to generate mask
if self._has_transparency(image) and not mask:
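For context, `_has_transparency` checks whether the loaded image carries a usable alpha channel before the automatic mask is attempted; a rough equivalent (illustrative only, the real helper in generate.py may differ in detail):

```python
from PIL import Image

def has_transparency(image: Image.Image) -> bool:
    # Palette images can declare a transparent index in their metadata.
    if image.mode == 'P' and 'transparency' in image.info:
        return True
    if image.mode in ('RGBA', 'LA'):
        # Fully opaque alpha is 255 everywhere; anything lower means at
        # least one transparent or semi-transparent pixel exists.
        return image.getchannel('A').getextrema()[0] < 255
    return False
```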
@@ -629,12 +626,17 @@ class Generate:
'>> a transparency mask, or provide mask explicitly using --init_mask (-M).'
)
# this returns a torch tensor
init_mask = self._create_init_mask(image)
init_mask = self._create_init_mask(image,width,height,fit=fit)
if (image.width * image.height) > (self.width * self.height):
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor
if mask:
mask_image = self._load_img(
mask, width, height, fit=fit) # this returns an Image
init_mask = self._create_init_mask(mask_image)
mask, width, height) # this returns an Image
init_mask = self._create_init_mask(mask_image,width,height,fit=fit)
return init_image, init_mask
@@ -850,7 +852,7 @@ class Generate:
return model
def _load_img(self, img, width, height, fit=False):
def _load_img(self, img, width, height)->Image:
if isinstance(img, Image.Image):
image = img
print(
@@ -868,14 +870,15 @@ class Generate:
print(
f'>> loaded input image of size {image.width}x{image.height}'
)
return image
def _create_init_image(self, image, width, height, fit=True):
image = image.convert('RGB')
if fit:
image = self._fit_image(image, (width, height))
else:
image = self._squeeze_image(image)
return image
def _create_init_image(self, image):
image = image.convert('RGB')
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
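The tail of `_create_init_image` follows the standard PIL-to-tensor recipe: an HWC uint8 array in [0, 255] becomes a batched NCHW float32 tensor in [0, 1]. As a standalone sketch of just those lines:

```python
import numpy as np
import torch
from PIL import Image

def pil_to_batched_tensor(image: Image.Image) -> torch.Tensor:
    """HWC uint8 in [0, 255] -> NCHW float32 in [0, 1], batch size 1."""
    arr = np.array(image.convert('RGB')).astype(np.float32) / 255.0  # (H, W, C)
    arr = arr[None].transpose(0, 3, 1, 2)                            # (1, C, H, W)
    return torch.from_numpy(arr)
```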
@@ -942,18 +945,20 @@ class Generate:
return new_img
def _create_init_mask(self, image):
def _create_init_mask(self, image, width, height, fit=True):
# convert into a black/white mask
image = self._image_to_mask(image)
image = image.convert('RGB')
# BUG: We need to use the model's downsample factor rather than hardcoding "8"
from ldm.dream.generator.base import downsampling
# now we adjust the size
if fit:
image = self._fit_image(image, (width, height))
else:
image = self._squeeze_image(image)
image = image.resize((image.width//downsampling, image.height //
downsampling), resample=Image.Resampling.NEAREST)
# print(
# f'>> DEBUG: writing the mask to mask.png'
# )
# image.save('mask.png')
image = np.array(image)
image = image.astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
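The mask is shrunk to the latent grid using the `downsampling` factor imported above (8 for Stable Diffusion's VAE, hence the BUG note about not hardcoding it), and NEAREST resampling is what keeps the mask strictly black and white after the resize. A hedged sketch of that step in isolation:

```python
from PIL import Image

LATENT_DOWNSAMPLING = 8  # assumption: the VAE reduces each spatial dimension 8x

def mask_to_latent_size(mask: Image.Image) -> Image.Image:
    """Shrink a black/white mask to latent resolution. NEAREST keeps the
    mask binary; a smoothing filter would introduce gray, partially
    masked pixels."""
    return mask.resize(
        (mask.width // LATENT_DOWNSAMPLING, mask.height // LATENT_DOWNSAMPLING),
        resample=Image.Resampling.NEAREST,
    )
```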
@@ -1035,10 +1040,6 @@ class Generate:
height = h
width = w
resize_needed = True
if (width * height) > (self.width * self.height):
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
return width, height, resize_needed
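`_resolution_check` (only its tail appears in this last hunk) exists because latent-diffusion models want image dimensions that are multiples of 64; the warning about inputs larger than the defaults has simply moved up into the img2img preparation above. A rough sketch of the multiple-of-64 trimming (assumed behavior, simplified from the real method):

```python
def resolution_check(width: int, height: int):
    """Round each dimension down to the nearest multiple of 64 and report
    whether a resize is needed."""
    w, h = width - width % 64, height - height % 64
    resize_needed = (w != width) or (h != height)
    return w, h, resize_needed
```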