ported code refactor changes from PR #1221

- Pass a PIL.Image to img2img and inpaint rather than a tensor.
- To support clipseg, inpaint now accepts a mask in "L" or "1" format
  (see the sketch below).
Lincoln Stein 2022-10-22 20:06:45 -04:00
parent f25c1f900f
commit 9472945299
3 changed files with 88 additions and 82 deletions
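
For orientation, a minimal sketch of the calling convention after this change
(the file name and sizes below are hypothetical, not from the commit): img2img
and inpaint now take PIL images directly, and the inpaint mask may be a
single-channel "L" or "1" image such as clipseg produces.

    from PIL import Image

    init_image = Image.open('photo.png').convert('RGB')  # hypothetical input file
    mask_l = Image.new('L', init_image.size, 0)          # 8-bit clipseg-style mask
    mask_1 = mask_l.convert('1')                         # 1-bit mask is also accepted
    # Both are handed to the generators as-is; tensor conversion now happens
    # inside Img2Img._image_to_tensor (see the img2img.py diff below).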

ldm/generate.py

@@ -393,7 +393,7 @@ class Generate:
                 log_tokens = self.log_tokenization
             )

-            init_image,mask_image,pil_image,pil_mask = self._make_images(
+            init_image, mask_image = self._make_images(
                 init_img,
                 init_mask,
                 width,
@@ -433,8 +433,6 @@ class Generate:
                 height=height,
                 init_img=init_img,        # embiggen needs to manipulate from the unmodified init_img
                 init_image=init_image,    # notice that init_image is different from init_img
-                pil_image=pil_image,
-                pil_mask=pil_mask,
                 mask_image=mask_image,
                 strength=strength,
                 threshold=threshold,
@@ -626,7 +624,7 @@ class Generate:
         init_image = None
         init_mask = None
         if not img:
-            return None, None, None, None
+            return None, None

         image = self._load_img(img)
@@ -636,23 +634,22 @@ class Generate:
         # if image has a transparent area and no mask was provided, then try to generate mask
         if self._has_transparency(image):
             self._transparency_check_and_warning(image, mask)
-            # this returns a torch tensor
             init_mask = self._create_init_mask(image, width, height, fit=fit)

         if (image.width * image.height) > (self.width * self.height) and self.size_matters:
             print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
             self.size_matters = False

-        init_image = self._create_init_image(image,width,height,fit=fit)   # this returns a torch tensor
+        init_image = self._create_init_image(image,width,height,fit=fit)

         if mask:
-            mask_image = self._load_img(mask)  # this returns an Image
+            mask_image = self._load_img(mask)
             init_mask = self._create_init_mask(mask_image,width,height,fit=fit)

         elif text_mask:
             init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)

-        return init_image, init_mask, image, mask_image
+        return init_image,init_mask

     def _make_base(self):
         if not self.generators.get('base'):
@@ -869,33 +866,15 @@ class Generate:

     def _create_init_image(self, image, width, height, fit=True):
         image = image.convert('RGB')
-        if fit:
-            image = self._fit_image(image, (width, height))
-        else:
-            image = self._squeeze_image(image)
-        image = np.array(image).astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        image = 2.0 * image - 1.0
-        return image.to(self.device)
+        image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
+        return image

     def _create_init_mask(self, image, width, height, fit=True):
         # convert into a black/white mask
         image = self._image_to_mask(image)
         image = image.convert('RGB')
-
-        # now we adjust the size
-        if fit:
-            image = self._fit_image(image, (width, height))
-        else:
-            image = self._squeeze_image(image)
-        image = image.resize((image.width//downsampling, image.height //
-                              downsampling), resample=Image.Resampling.NEAREST)
-        image = np.array(image)
-        image = image.astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        return image.to(self.device)
+        image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
+        return image

     # The mask is expected to have the region to be inpainted
     # with alpha transparency. It converts it into a black/white
@@ -912,7 +891,6 @@ class Generate:
         mask = ImageOps.invert(mask)
         return mask

-    # TODO: The latter part of this method repeats code from _create_init_mask()
     def _txt2mask(self, image:Image, text_mask:list, width, height, fit=True) -> Image:
         prompt = text_mask[0]
         confidence_level = text_mask[1] if len(text_mask)>1 else 0.5
@@ -922,18 +900,8 @@ class Generate:
         segmented = self.txt2mask.segment(image, prompt)
         mask = segmented.to_mask(float(confidence_level))
         mask = mask.convert('RGB')
-
-        # now we adjust the size
-        if fit:
-            mask = self._fit_image(mask, (width, height))
-        else:
-            mask = self._squeeze_image(mask)
-        mask = mask.resize((mask.width//downsampling, mask.height //
-                            downsampling), resample=Image.Resampling.NEAREST)
-        mask = np.array(mask)
-        mask = mask.astype(np.float32) / 255.0
-        mask = mask[None].transpose(0, 3, 1, 2)
-        mask = torch.from_numpy(mask)
-        return mask.to(self.device)
+        mask = self._fit_image(mask, (width, height)) if fit else self._squeeze_image(mask)
+        return mask

     def _has_transparency(self, image):
         if image.info.get("transparency", None) is not None:
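
A note on the mask resize deleted from _create_init_mask() and _txt2mask()
above: the mask must match the latent resolution, i.e. the image size divided
by the model's downsampling factor, and that resize now happens in the inpaint
generator instead (see the inpaint.py diff below). A minimal standalone
sketch, assuming the factor is 8 (the usual value of
ldm.invoke.generator.base.downsampling for these models):

    from PIL import Image

    downsampling = 8                        # stand-in for the imported constant
    mask = Image.new('L', (512, 512))
    latent_mask = mask.resize(
        (mask.width // downsampling, mask.height // downsampling),
        resample=Image.Resampling.NEAREST,  # nearest keeps mask values crisp
    )
    assert latent_mask.size == (64, 64)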

ldm/invoke/generator/img2img.py

@@ -4,9 +4,12 @@ ldm.invoke.generator.img2img descends from ldm.invoke.generator

 import torch
 import numpy as np
-from ldm.invoke.devices import choose_autocast
-from ldm.invoke.generator.base import Generator
-from ldm.models.diffusion.ddim import DDIMSampler
+import PIL
+from torch import Tensor
+from PIL import Image
+from ldm.invoke.devices import choose_autocast
+from ldm.invoke.generator.base import Generator
+from ldm.models.diffusion.ddim import DDIMSampler

 class Img2Img(Generator):
     def __init__(self, model, precision):
@@ -25,6 +28,9 @@ class Img2Img(Generator):
             ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
         )

+        if isinstance(init_image, PIL.Image.Image):
+            init_image = self._image_to_tensor(init_image)
+
         scope = choose_autocast(self.precision)
         with scope(self.model.device.type):
             self.init_latent = self.model.get_first_stage_encoding(
@@ -68,3 +74,11 @@ class Img2Img(Generator):
             shape = init_latent.shape
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
         return x
+
+    def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image[None].transpose(0, 3, 1, 2)
+        image = torch.from_numpy(image)
+        if normalize:
+            image = 2.0 * image - 1.0
+        return image.to(self.model.device)
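
As a quick sanity check of the helper above, the same transform written
standalone (the 8x8 image and CPU tensors are illustrative assumptions):

    import numpy as np
    import torch
    from PIL import Image

    image = Image.new('RGB', (8, 8), (255, 0, 0))
    x = np.array(image).astype(np.float32) / 255.0  # HWC floats in [0, 1]
    x = x[None].transpose(0, 3, 1, 2)               # add batch dim -> NCHW
    x = torch.from_numpy(x)
    x = 2.0 * x - 1.0                               # normalize=True maps to [-1, 1]
    assert x.shape == (1, 3, 8, 8)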

ldm/invoke/generator/inpaint.py

@@ -6,6 +6,7 @@ import torch
 import torchvision.transforms as T
 import numpy as np
 import cv2 as cv
+import PIL
 from PIL import Image, ImageFilter
 from skimage.exposure.histogram_matching import match_histograms
 from einops import rearrange, repeat
@@ -13,16 +14,19 @@ from ldm.invoke.devices import choose_autocast
 from ldm.invoke.generator.img2img import Img2Img
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.ksampler import KSampler
+from ldm.invoke.generator.base import downsampling

 class Inpaint(Img2Img):
     def __init__(self, model, precision):
         self.init_latent = None
+        self.pil_image = None
+        self.pil_mask = None
+        self.mask_blur_radius = 0
         super().__init__(model, precision)

     @torch.no_grad()
     def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
                        conditioning,init_image,mask_image,strength,
-                       pil_image: Image.Image, pil_mask: Image.Image,
                        mask_blur_radius: int = 8,
                        step_callback=None,inpaint_replace=False, **kwargs):
         """
@@ -31,17 +35,21 @@ class Inpaint(Img2Img):
         the time you call it. kwargs are 'init_latent' and 'strength'
         """

-        # Get the alpha channel of the mask
-        pil_init_mask = pil_mask.getchannel('A')
-        pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist
-
-        # Build an image with only visible pixels from source to use as reference for color-matching.
-        # Note that this doesn't use the mask, which would exclude some source image pixels from the
-        # histogram and cause slight color changes.
-        init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3)
-        init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
-        init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
-        init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram
+        if isinstance(init_image, PIL.Image.Image):
+            self.pil_image = init_image
+            init_image = self._image_to_tensor(init_image)
+
+        if isinstance(mask_image, PIL.Image.Image):
+            self.pil_mask = mask_image
+            mask_image = mask_image.resize(
+                (
+                    mask_image.width // downsampling,
+                    mask_image.height // downsampling
+                ),
+                resample=Image.Resampling.NEAREST
+            )
+            mask_image = self._image_to_tensor(mask_image,normalize=False)
+
+        self.mask_blur_radius = mask_blur_radius

         # klms samplers not supported yet, so ignore previous sampler
         if isinstance(sampler,KSampler):
@@ -96,30 +104,46 @@ class Inpaint(Img2Img):
                 mask = mask_image,
                 init_latent = self.init_latent
             )
-
-            # Get PIL result
-            gen_result = self.sample_to_image(samples).convert('RGB')
-
-            # Get numpy version
-            np_gen_result = np.asarray(gen_result, dtype=np.uint8)
-
-            # Color correct
-            np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1)
-            matched_result = Image.fromarray(np_matched_result, mode='RGB')
-
-            # Blur the mask out (into init image) by specified amount
-            if mask_blur_radius > 0:
-                nm = np.asarray(pil_init_mask, dtype=np.uint8)
-                nmd = cv.dilate(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
-                pmd = Image.fromarray(nmd, mode='L')
-                blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
-            else:
-                blurred_init_mask = pil_init_mask
-
-            # Paste original on color-corrected generation (using blurred mask)
-            matched_result.paste(pil_image, (0,0), mask = blurred_init_mask)
-            return matched_result
+            return self.sample_to_image(samples)

         return make_image

+    def sample_to_image(self, samples)->Image:
+        gen_result = super().sample_to_image(samples).convert('RGB')
+        pil_mask = self.pil_mask
+        pil_image = self.pil_image
+        mask_blur_radius = self.mask_blur_radius
+
+        # Get the original alpha channel of the mask
+        pil_init_mask = pil_mask.convert('L')
+        pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist
+
+        # Build an image with only visible pixels from source to use as reference for color-matching.
+        # Note that this doesn't use the mask, which would exclude some source image pixels from the
+        # histogram and cause slight color changes.
+        init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3)
+        init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
+        init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
+        init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram
+
+        # Get numpy version
+        np_gen_result = np.asarray(gen_result, dtype=np.uint8)
+
+        # Color correct
+        np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1)
+        matched_result = Image.fromarray(np_matched_result, mode='RGB')

+        # Blur the mask out (into init image) by specified amount
+        if mask_blur_radius > 0:
+            nm = np.asarray(pil_init_mask, dtype=np.uint8)
+            nmd = cv.dilate(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
+            pmd = Image.fromarray(nmd, mode='L')
+            blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
+        else:
+            blurred_init_mask = pil_init_mask
+
+        # Paste original on color-corrected generation (using blurred mask)
+        matched_result.paste(pil_image, (0,0), mask = blurred_init_mask)
+        return matched_result
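
The paste-back at the end of sample_to_image() reads well in isolation: dilate
the mask, box-blur it, and use it as the alpha channel when pasting the
original image over the color-corrected generation. A self-contained sketch
with synthetic stand-ins for pil_image, matched_result and the mask:

    import numpy as np
    import cv2 as cv
    from PIL import Image, ImageFilter

    pil_image        = Image.new('RGB', (64, 64), (200, 180, 160))  # original
    matched_result   = Image.new('RGB', (64, 64), (90, 90, 90))     # corrected generation
    pil_init_mask    = Image.new('L', (64, 64), 0)                  # white = keep original
    mask_blur_radius = 8

    nm  = np.asarray(pil_init_mask, dtype=np.uint8)
    nmd = cv.dilate(nm, kernel=np.ones((3, 3), dtype=np.uint8),
                    iterations=int(mask_blur_radius / 2))
    blurred_init_mask = Image.fromarray(nmd, mode='L').filter(
        ImageFilter.BoxBlur(mask_blur_radius))
    matched_result.paste(pil_image, (0, 0), mask=blurred_init_mask)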