From 493eaa7389cf1c63da15cf961fb35dd79bcda355 Mon Sep 17 00:00:00 2001 From: Kyle Schouviller Date: Sat, 22 Oct 2022 14:56:33 -0700 Subject: [PATCH] Improve inpainting by color-correcting result and pasting init image over result using mask --- ldm/generate.py | 16 +++++++---- ldm/invoke/generator/inpaint.py | 49 +++++++++++++++++++++++++++++---- 2 files changed, 55 insertions(+), 10 deletions(-) diff --git a/ldm/generate.py b/ldm/generate.py index 1604dabd65..964863ce6e 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -271,6 +271,8 @@ class Generate: upscale = None, # this is specific to inpainting and causes more extreme inpainting inpaint_replace = 0.0, + # This will help match inpainted areas to the original image more smoothly + mask_blur_radius: int = 8, # Set this True to handle KeyboardInterrupt internally catch_interrupts = False, hires_fix = False, @@ -391,7 +393,7 @@ class Generate: log_tokens =self.log_tokenization ) - init_image,mask_image = self._make_images( + init_image,mask_image,pil_image,pil_mask = self._make_images( init_img, init_mask, width, @@ -431,6 +433,8 @@ class Generate: height=height, init_img=init_img, # embiggen needs to manipulate from the unmodified init_img init_image=init_image, # notice that init_image is different from init_img + pil_image=pil_image, + pil_mask=pil_mask, mask_image=mask_image, strength=strength, threshold=threshold, @@ -438,6 +442,7 @@ class Generate: embiggen=embiggen, embiggen_tiles=embiggen_tiles, inpaint_replace=inpaint_replace, + mask_blur_radius=mask_blur_radius ) if init_color: @@ -621,7 +626,7 @@ class Generate: init_image = None init_mask = None if not img: - return None, None + return None, None, None, None image = self._load_img(img) @@ -647,7 +652,7 @@ class Generate: elif text_mask: init_mask = self._txt2mask(image, text_mask, width, height, fit=fit) - return init_image, init_mask + return init_image, init_mask, image, mask_image def _make_base(self): if not self.generators.get('base'): @@ -895,8 +900,9 @@ class Generate: # The mask is expected to have the region to be inpainted # with alpha transparency. It converts it into a black/white # image with the transparent part black. - def _image_to_mask(self, mask_image, invert=False) -> Image: - if mask_image.mode in ('L','RGB'): + def _image_to_mask(self, mask_image: Image.Image, invert=False) -> Image: + # Obtain the mask from the transparency channel + if mask_image.mode == 'L': mask = mask_image else: # Obtain the mask from the transparency channel diff --git a/ldm/invoke/generator/inpaint.py b/ldm/invoke/generator/inpaint.py index bc4b6133b3..ee67b90c46 100644 --- a/ldm/invoke/generator/inpaint.py +++ b/ldm/invoke/generator/inpaint.py @@ -3,7 +3,11 @@ ldm.invoke.generator.inpaint descends from ldm.invoke.generator ''' import torch +import torchvision.transforms as T import numpy as np +import cv2 as cv +from PIL import Image, ImageFilter +from skimage.exposure.histogram_matching import match_histograms from einops import rearrange, repeat from ldm.invoke.devices import choose_autocast from ldm.invoke.generator.img2img import Img2Img @@ -18,12 +22,27 @@ class Inpaint(Img2Img): @torch.no_grad() def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, conditioning,init_image,mask_image,strength, - step_callback=None,inpaint_replace=False,**kwargs): + pil_image: Image.Image, pil_mask: Image.Image, + mask_blur_radius: int = 8, + step_callback=None,inpaint_replace=False, **kwargs): """ Returns a function returning an image derived from the prompt and the initial image + mask. Return value depends on the seed at the time you call it. kwargs are 'init_latent' and 'strength' """ + + # Get the alpha channel of the mask + pil_init_mask = pil_mask.getchannel('A') + pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist + + # Build an image with only visible pixels from source to use as reference for color-matching. + # Note that this doesn't use the mask, which would exclude some source image pixels from the + # histogram and cause slight color changes. + init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3) + init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height) + init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0] + init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram + # klms samplers not supported yet, so ignore previous sampler if isinstance(sampler,KSampler): print( @@ -78,9 +97,29 @@ class Inpaint(Img2Img): init_latent = self.init_latent ) - return self.sample_to_image(samples) + # Get PIL result + gen_result = self.sample_to_image(samples).convert('RGB') + + # Get numpy version + np_gen_result = np.asarray(gen_result, dtype=np.uint8) + + # Color correct + np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1) + matched_result = Image.fromarray(np_matched_result, mode='RGB') + + + # Blur the mask out (into init image) by specified amount + if mask_blur_radius > 0: + nm = np.asarray(pil_init_mask, dtype=np.uint8) + nmd = cv.dilate(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2)) + pmd = Image.fromarray(nmd, mode='L') + blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius)) + else: + blurred_init_mask = pil_init_mask + + # Paste original on color-corrected generation (using blurred mask) + matched_result.paste(pil_image, (0,0), mask = blurred_init_mask) + + return matched_result return make_image - - -