From f1ca78909722b7544529c62fc064d105a2951e97 Mon Sep 17 00:00:00 2001 From: Kyle Schouviller Date: Tue, 25 Oct 2022 17:10:28 -0700 Subject: [PATCH 1/6] Better inpainting color-correction --- ldm/invoke/generator/inpaint.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/ldm/invoke/generator/inpaint.py b/ldm/invoke/generator/inpaint.py index 8fbcf249aa..a6843599a9 100644 --- a/ldm/invoke/generator/inpaint.py +++ b/ldm/invoke/generator/inpaint.py @@ -125,18 +125,26 @@ class Inpaint(Img2Img): pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist # Build an image with only visible pixels from source to use as reference for color-matching. - # Note that this doesn't use the mask, which would exclude some source image pixels from the - # histogram and cause slight color changes. - init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3) - init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height) - init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0] - init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram + init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8) + init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8) + init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8) - # Get numpy version + # Get numpy version of result np_gen_result = np.asarray(gen_result, dtype=np.uint8) + # Mask and calculate mean and standard deviation + mask_pixels = init_a_pixels * init_mask_pixels > 0 + np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :] + np_gen_result_masked = np_gen_result[mask_pixels, :] + + init_means = np_init_rgb_pixels_masked.mean(axis=0) + init_std = np_init_rgb_pixels_masked.std(axis=0) + gen_means = np_gen_result_masked.mean(axis=0) + gen_std = np_gen_result_masked.std(axis=0) + # Color correct - np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1) + np_matched_result = np_gen_result.copy() + np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8) matched_result = Image.fromarray(np_matched_result, mode='RGB') # Blur the mask out (into init image) by specified amount @@ -151,4 +159,3 @@ class Inpaint(Img2Img): # Paste original on color-corrected generation (using blurred mask) matched_result.paste(pil_image, (0,0), mask = blurred_init_mask) return matched_result - From eaf6d46a7befe8a7da835df1e6145d2585b29aa1 Mon Sep 17 00:00:00 2001 From: Kyle Schouviller Date: Wed, 26 Oct 2022 00:39:36 -0700 Subject: [PATCH 2/6] Adding outpainting implementation (as part of inpaint). --- ldm/generate.py | 22 +++- ldm/invoke/generator/inpaint.py | 193 ++++++++++++++++++++++++++++---- 2 files changed, 189 insertions(+), 26 deletions(-) diff --git a/ldm/generate.py b/ldm/generate.py index 3ede1710e1..d6f402ab06 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -293,6 +293,13 @@ class Generate: catch_interrupts = False, hires_fix = False, use_mps_noise = False, + # Seam settings for outpainting + seam_size: int = 0, + seam_blur: int = 0, + seam_strength: float = 0.7, + seam_steps: int = 10, + tile_size: int = 32, + force_outpaint: bool = False, **args, ): # eat up additional cruft """ @@ -420,7 +427,7 @@ class Generate: ) # TODO: Hacky selection of operation to perform. Needs to be refactored. - if (init_image is not None) and (mask_image is not None): + if ((init_image is not None) and (mask_image is not None)) or force_outpaint: generator = self._make_inpaint() elif (embiggen != None or embiggen_tiles != None): generator = self._make_embiggen() @@ -464,7 +471,13 @@ class Generate: embiggen_tiles=embiggen_tiles, inpaint_replace=inpaint_replace, mask_blur_radius=mask_blur_radius, - safety_checker=checker + safety_checker=checker, + seam_size = seam_size, + seam_blur = seam_blur, + seam_strength = seam_strength, + seam_steps = seam_steps, + tile_size = tile_size, + force_outpaint = force_outpaint ) if init_color: @@ -888,8 +901,9 @@ class Generate: image = ImageOps.exif_transpose(image) return image - def _create_init_image(self, image, width, height, fit=True): - image = image.convert('RGB') + def _create_init_image(self, image: Image.Image, width, height, fit=True): + if image.mode != 'RGBA': + image = image.convert('RGB') image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image) return image diff --git a/ldm/invoke/generator/inpaint.py b/ldm/invoke/generator/inpaint.py index a6843599a9..be8f9d1b53 100644 --- a/ldm/invoke/generator/inpaint.py +++ b/ldm/invoke/generator/inpaint.py @@ -2,12 +2,13 @@ ldm.invoke.generator.inpaint descends from ldm.invoke.generator ''' +import math import torch import torchvision.transforms as T import numpy as np import cv2 as cv import PIL -from PIL import Image, ImageFilter +from PIL import Image, ImageFilter, ImageOps from skimage.exposure.histogram_matching import match_histograms from einops import rearrange, repeat from ldm.invoke.devices import choose_autocast @@ -24,11 +25,128 @@ class Inpaint(Img2Img): self.mask_blur_radius = 0 super().__init__(model, precision) + # Outpaint support code + def get_tile_images(self, image: np.ndarray, width=8, height=8): + _nrows, _ncols, depth = image.shape + _strides = image.strides + + nrows, _m = divmod(_nrows, height) + ncols, _n = divmod(_ncols, width) + if _m != 0 or _n != 0: + return None + + return np.lib.stride_tricks.as_strided( + np.ravel(image), + shape=(nrows, ncols, height, width, depth), + strides=(height * _strides[0], width * _strides[1], *_strides), + writeable=False + ) + + def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: int = None) -> Image: + a = np.asarray(im, dtype=np.uint8) + + tile_size = (tile_size, tile_size) + + # Get the image as tiles of a specified size + tiles = self.get_tile_images(a,*tile_size).copy() + + # Get the mask as tiles + tiles_mask = tiles[:,:,:,:,3] + + # Find any mask tiles with any fully transparent pixels (we will be replacing these later) + tmask_shape = tiles_mask.shape + tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape)) + n,ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:]) + tiles_mask = (tiles_mask > 0) + tiles_mask = tiles_mask.reshape((n,ny)).all(axis = 1) + + # Get RGB tiles in single array and filter by the mask + tshape = tiles.shape + tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), * tiles.shape[2:])) + filtered_tiles = tiles_all[tiles_mask] + + if len(filtered_tiles) == 0: + return im + + # Find all invalid tiles and replace with a random valid tile + replace_count = (tiles_mask == False).sum() + rng = np.random.default_rng(seed = seed) + tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count),:,:,:] + + # Convert back to an image + tiles_all = tiles_all.reshape(tshape) + tiles_all = tiles_all.swapaxes(1,2) + st = tiles_all.reshape((math.prod(tiles_all.shape[0:2]), math.prod(tiles_all.shape[2:4]), tiles_all.shape[4])) + si = Image.fromarray(st, mode='RGBA') + + return si + + + def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image: + npimg = np.asarray(mask, dtype=np.uint8) + + # Detect any partially transparent regions + npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0))) + + # Detect hard edges + npedge = cv.Canny(npimg, threshold1=100, threshold2=200) + + # Combine + npmask = npgradient + npedge + + # Expand + npmask = cv.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2)) + + new_mask = Image.fromarray(npmask) + + if edge_blur > 0: + new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur)) + + return ImageOps.invert(new_mask) + + + def seam_paint(self, + im: Image.Image, + seam_size: int, + seam_blur: int, + prompt,sampler,steps,cfg_scale,ddim_eta, + conditioning,strength, + noise + ) -> Image.Image: + hard_mask = self.pil_image.split()[-1].copy() + mask = self.mask_edge(hard_mask, seam_size, seam_blur) + + make_image = self.get_make_image( + prompt, + sampler, + steps, + cfg_scale, + ddim_eta, + conditioning, + init_image = im.copy().convert('RGBA'), + mask_image = mask.convert('RGB'), # Code currently requires an RGB mask + strength = strength, + mask_blur_radius = 0, + seam_size = 0 + ) + + result = make_image(noise) + + return result + + @torch.no_grad() def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, conditioning,init_image,mask_image,strength, mask_blur_radius: int = 8, - step_callback=None,inpaint_replace=False, **kwargs): + # Seam settings - when 0, doesn't fill seam + seam_size: int = 0, + seam_blur: int = 0, + seam_strength: float = 0.7, + seam_steps: int = 10, + tile_size: int = 32, + step_callback=None, + inpaint_replace=False, **kwargs): """ Returns a function returning an image derived from the prompt and the initial image + mask. Return value depends on the seed at @@ -37,7 +155,17 @@ class Inpaint(Img2Img): if isinstance(init_image, PIL.Image.Image): self.pil_image = init_image - init_image = self._image_to_tensor(init_image) + + # Fill missing areas of original image + init_filled = self.tile_fill_missing( + self.pil_image.copy(), + seed = self.seed, + tile_size = tile_size + ) + init_filled.paste(init_image, (0,0), init_image.split()[-1]) + + # Create init tensor + init_image = self._image_to_tensor(init_filled.convert('RGB')) if isinstance(mask_image, PIL.Image.Image): self.pil_mask = mask_image @@ -105,45 +233,55 @@ class Inpaint(Img2Img): mask = mask_image, init_latent = self.init_latent ) - return self.sample_to_image(samples) + + result = self.sample_to_image(samples) + + # Seam paint if this is our first pass (seam_size set to 0 during seam painting) + if seam_size > 0: + result = self.seam_paint( + result, + seam_size, + seam_blur, + prompt, + sampler, + seam_steps, + cfg_scale, + ddim_eta, + conditioning, + seam_strength, + x_T) + + return result return make_image - def sample_to_image(self, samples)->Image.Image: - gen_result = super().sample_to_image(samples).convert('RGB') - - if self.pil_image is None or self.pil_mask is None: - return gen_result - - pil_mask = self.pil_mask - pil_image = self.pil_image - mask_blur_radius = self.mask_blur_radius + def color_correct(self, image: Image.Image, base_image: Image.Image, mask: Image.Image, mask_blur_radius: int) -> Image.Image: # Get the original alpha channel of the mask if there is one. # Otherwise it is some other black/white image format ('1', 'L' or 'RGB') - pil_init_mask = pil_mask.getchannel('A') if pil_mask.mode == 'RGBA' else pil_mask.convert('L') - pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist + pil_init_mask = mask.getchannel('A') if mask.mode == 'RGBA' else mask.convert('L') + pil_init_image = base_image.convert('RGBA') # Add an alpha channel if one doesn't exist # Build an image with only visible pixels from source to use as reference for color-matching. - init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8) + init_rgb_pixels = np.asarray(base_image.convert('RGB'), dtype=np.uint8) init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8) init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8) # Get numpy version of result - np_gen_result = np.asarray(gen_result, dtype=np.uint8) + np_image = np.asarray(image, dtype=np.uint8) # Mask and calculate mean and standard deviation mask_pixels = init_a_pixels * init_mask_pixels > 0 np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :] - np_gen_result_masked = np_gen_result[mask_pixels, :] + np_image_masked = np_image[mask_pixels, :] init_means = np_init_rgb_pixels_masked.mean(axis=0) init_std = np_init_rgb_pixels_masked.std(axis=0) - gen_means = np_gen_result_masked.mean(axis=0) - gen_std = np_gen_result_masked.std(axis=0) + gen_means = np_image_masked.mean(axis=0) + gen_std = np_image_masked.std(axis=0) # Color correct - np_matched_result = np_gen_result.copy() + np_matched_result = np_image.copy() np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8) matched_result = Image.fromarray(np_matched_result, mode='RGB') @@ -157,5 +295,16 @@ class Inpaint(Img2Img): blurred_init_mask = pil_init_mask # Paste original on color-corrected generation (using blurred mask) - matched_result.paste(pil_image, (0,0), mask = blurred_init_mask) + matched_result.paste(base_image, (0,0), mask = blurred_init_mask) return matched_result + + + def sample_to_image(self, samples)->Image.Image: + gen_result = super().sample_to_image(samples).convert('RGB') + + if self.pil_image is None or self.pil_mask is None: + return gen_result + + corrected_result = self.color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius) + + return corrected_result From 2a44411f5b9faaaa603f07be0dceb364f7815825 Mon Sep 17 00:00:00 2001 From: Kyle Schouviller Date: Wed, 26 Oct 2022 12:09:38 -0700 Subject: [PATCH 3/6] Force RGB for img2img --- ldm/invoke/generator/img2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py index 613f1aca31..b9a721fcbb 100644 --- a/ldm/invoke/generator/img2img.py +++ b/ldm/invoke/generator/img2img.py @@ -29,7 +29,7 @@ class Img2Img(Generator): ) if isinstance(init_image, PIL.Image.Image): - init_image = self._image_to_tensor(init_image) + init_image = self._image_to_tensor(init_image.convert('RGB')) scope = choose_autocast(self.precision) with scope(self.model.device.type): From dac1ab0a05cb31d7a965f6577b273dc1ed479bb3 Mon Sep 17 00:00:00 2001 From: Kyle Schouviller Date: Tue, 25 Oct 2022 17:10:28 -0700 Subject: [PATCH 4/6] Better inpainting color-correction --- ldm/invoke/generator/inpaint.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/ldm/invoke/generator/inpaint.py b/ldm/invoke/generator/inpaint.py index 8fbcf249aa..a6843599a9 100644 --- a/ldm/invoke/generator/inpaint.py +++ b/ldm/invoke/generator/inpaint.py @@ -125,18 +125,26 @@ class Inpaint(Img2Img): pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist # Build an image with only visible pixels from source to use as reference for color-matching. - # Note that this doesn't use the mask, which would exclude some source image pixels from the - # histogram and cause slight color changes. - init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3) - init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height) - init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0] - init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram + init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8) + init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8) + init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8) - # Get numpy version + # Get numpy version of result np_gen_result = np.asarray(gen_result, dtype=np.uint8) + # Mask and calculate mean and standard deviation + mask_pixels = init_a_pixels * init_mask_pixels > 0 + np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :] + np_gen_result_masked = np_gen_result[mask_pixels, :] + + init_means = np_init_rgb_pixels_masked.mean(axis=0) + init_std = np_init_rgb_pixels_masked.std(axis=0) + gen_means = np_gen_result_masked.mean(axis=0) + gen_std = np_gen_result_masked.std(axis=0) + # Color correct - np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1) + np_matched_result = np_gen_result.copy() + np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8) matched_result = Image.fromarray(np_matched_result, mode='RGB') # Blur the mask out (into init image) by specified amount @@ -151,4 +159,3 @@ class Inpaint(Img2Img): # Paste original on color-corrected generation (using blurred mask) matched_result.paste(pil_image, (0,0), mask = blurred_init_mask) return matched_result - From bd8bb8c80b28e9df670ecd33b684764722e717b8 Mon Sep 17 00:00:00 2001 From: Kyle Schouviller Date: Wed, 26 Oct 2022 00:39:36 -0700 Subject: [PATCH 5/6] Adding outpainting implementation (as part of inpaint). --- ldm/generate.py | 22 +++- ldm/invoke/generator/inpaint.py | 193 ++++++++++++++++++++++++++++---- 2 files changed, 189 insertions(+), 26 deletions(-) diff --git a/ldm/generate.py b/ldm/generate.py index 3ede1710e1..d6f402ab06 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -293,6 +293,13 @@ class Generate: catch_interrupts = False, hires_fix = False, use_mps_noise = False, + # Seam settings for outpainting + seam_size: int = 0, + seam_blur: int = 0, + seam_strength: float = 0.7, + seam_steps: int = 10, + tile_size: int = 32, + force_outpaint: bool = False, **args, ): # eat up additional cruft """ @@ -420,7 +427,7 @@ class Generate: ) # TODO: Hacky selection of operation to perform. Needs to be refactored. - if (init_image is not None) and (mask_image is not None): + if ((init_image is not None) and (mask_image is not None)) or force_outpaint: generator = self._make_inpaint() elif (embiggen != None or embiggen_tiles != None): generator = self._make_embiggen() @@ -464,7 +471,13 @@ class Generate: embiggen_tiles=embiggen_tiles, inpaint_replace=inpaint_replace, mask_blur_radius=mask_blur_radius, - safety_checker=checker + safety_checker=checker, + seam_size = seam_size, + seam_blur = seam_blur, + seam_strength = seam_strength, + seam_steps = seam_steps, + tile_size = tile_size, + force_outpaint = force_outpaint ) if init_color: @@ -888,8 +901,9 @@ class Generate: image = ImageOps.exif_transpose(image) return image - def _create_init_image(self, image, width, height, fit=True): - image = image.convert('RGB') + def _create_init_image(self, image: Image.Image, width, height, fit=True): + if image.mode != 'RGBA': + image = image.convert('RGB') image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image) return image diff --git a/ldm/invoke/generator/inpaint.py b/ldm/invoke/generator/inpaint.py index a6843599a9..be8f9d1b53 100644 --- a/ldm/invoke/generator/inpaint.py +++ b/ldm/invoke/generator/inpaint.py @@ -2,12 +2,13 @@ ldm.invoke.generator.inpaint descends from ldm.invoke.generator ''' +import math import torch import torchvision.transforms as T import numpy as np import cv2 as cv import PIL -from PIL import Image, ImageFilter +from PIL import Image, ImageFilter, ImageOps from skimage.exposure.histogram_matching import match_histograms from einops import rearrange, repeat from ldm.invoke.devices import choose_autocast @@ -24,11 +25,128 @@ class Inpaint(Img2Img): self.mask_blur_radius = 0 super().__init__(model, precision) + # Outpaint support code + def get_tile_images(self, image: np.ndarray, width=8, height=8): + _nrows, _ncols, depth = image.shape + _strides = image.strides + + nrows, _m = divmod(_nrows, height) + ncols, _n = divmod(_ncols, width) + if _m != 0 or _n != 0: + return None + + return np.lib.stride_tricks.as_strided( + np.ravel(image), + shape=(nrows, ncols, height, width, depth), + strides=(height * _strides[0], width * _strides[1], *_strides), + writeable=False + ) + + def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: int = None) -> Image: + a = np.asarray(im, dtype=np.uint8) + + tile_size = (tile_size, tile_size) + + # Get the image as tiles of a specified size + tiles = self.get_tile_images(a,*tile_size).copy() + + # Get the mask as tiles + tiles_mask = tiles[:,:,:,:,3] + + # Find any mask tiles with any fully transparent pixels (we will be replacing these later) + tmask_shape = tiles_mask.shape + tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape)) + n,ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:]) + tiles_mask = (tiles_mask > 0) + tiles_mask = tiles_mask.reshape((n,ny)).all(axis = 1) + + # Get RGB tiles in single array and filter by the mask + tshape = tiles.shape + tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), * tiles.shape[2:])) + filtered_tiles = tiles_all[tiles_mask] + + if len(filtered_tiles) == 0: + return im + + # Find all invalid tiles and replace with a random valid tile + replace_count = (tiles_mask == False).sum() + rng = np.random.default_rng(seed = seed) + tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count),:,:,:] + + # Convert back to an image + tiles_all = tiles_all.reshape(tshape) + tiles_all = tiles_all.swapaxes(1,2) + st = tiles_all.reshape((math.prod(tiles_all.shape[0:2]), math.prod(tiles_all.shape[2:4]), tiles_all.shape[4])) + si = Image.fromarray(st, mode='RGBA') + + return si + + + def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image: + npimg = np.asarray(mask, dtype=np.uint8) + + # Detect any partially transparent regions + npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0))) + + # Detect hard edges + npedge = cv.Canny(npimg, threshold1=100, threshold2=200) + + # Combine + npmask = npgradient + npedge + + # Expand + npmask = cv.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2)) + + new_mask = Image.fromarray(npmask) + + if edge_blur > 0: + new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur)) + + return ImageOps.invert(new_mask) + + + def seam_paint(self, + im: Image.Image, + seam_size: int, + seam_blur: int, + prompt,sampler,steps,cfg_scale,ddim_eta, + conditioning,strength, + noise + ) -> Image.Image: + hard_mask = self.pil_image.split()[-1].copy() + mask = self.mask_edge(hard_mask, seam_size, seam_blur) + + make_image = self.get_make_image( + prompt, + sampler, + steps, + cfg_scale, + ddim_eta, + conditioning, + init_image = im.copy().convert('RGBA'), + mask_image = mask.convert('RGB'), # Code currently requires an RGB mask + strength = strength, + mask_blur_radius = 0, + seam_size = 0 + ) + + result = make_image(noise) + + return result + + @torch.no_grad() def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, conditioning,init_image,mask_image,strength, mask_blur_radius: int = 8, - step_callback=None,inpaint_replace=False, **kwargs): + # Seam settings - when 0, doesn't fill seam + seam_size: int = 0, + seam_blur: int = 0, + seam_strength: float = 0.7, + seam_steps: int = 10, + tile_size: int = 32, + step_callback=None, + inpaint_replace=False, **kwargs): """ Returns a function returning an image derived from the prompt and the initial image + mask. Return value depends on the seed at @@ -37,7 +155,17 @@ class Inpaint(Img2Img): if isinstance(init_image, PIL.Image.Image): self.pil_image = init_image - init_image = self._image_to_tensor(init_image) + + # Fill missing areas of original image + init_filled = self.tile_fill_missing( + self.pil_image.copy(), + seed = self.seed, + tile_size = tile_size + ) + init_filled.paste(init_image, (0,0), init_image.split()[-1]) + + # Create init tensor + init_image = self._image_to_tensor(init_filled.convert('RGB')) if isinstance(mask_image, PIL.Image.Image): self.pil_mask = mask_image @@ -105,45 +233,55 @@ class Inpaint(Img2Img): mask = mask_image, init_latent = self.init_latent ) - return self.sample_to_image(samples) + + result = self.sample_to_image(samples) + + # Seam paint if this is our first pass (seam_size set to 0 during seam painting) + if seam_size > 0: + result = self.seam_paint( + result, + seam_size, + seam_blur, + prompt, + sampler, + seam_steps, + cfg_scale, + ddim_eta, + conditioning, + seam_strength, + x_T) + + return result return make_image - def sample_to_image(self, samples)->Image.Image: - gen_result = super().sample_to_image(samples).convert('RGB') - - if self.pil_image is None or self.pil_mask is None: - return gen_result - - pil_mask = self.pil_mask - pil_image = self.pil_image - mask_blur_radius = self.mask_blur_radius + def color_correct(self, image: Image.Image, base_image: Image.Image, mask: Image.Image, mask_blur_radius: int) -> Image.Image: # Get the original alpha channel of the mask if there is one. # Otherwise it is some other black/white image format ('1', 'L' or 'RGB') - pil_init_mask = pil_mask.getchannel('A') if pil_mask.mode == 'RGBA' else pil_mask.convert('L') - pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist + pil_init_mask = mask.getchannel('A') if mask.mode == 'RGBA' else mask.convert('L') + pil_init_image = base_image.convert('RGBA') # Add an alpha channel if one doesn't exist # Build an image with only visible pixels from source to use as reference for color-matching. - init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8) + init_rgb_pixels = np.asarray(base_image.convert('RGB'), dtype=np.uint8) init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8) init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8) # Get numpy version of result - np_gen_result = np.asarray(gen_result, dtype=np.uint8) + np_image = np.asarray(image, dtype=np.uint8) # Mask and calculate mean and standard deviation mask_pixels = init_a_pixels * init_mask_pixels > 0 np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :] - np_gen_result_masked = np_gen_result[mask_pixels, :] + np_image_masked = np_image[mask_pixels, :] init_means = np_init_rgb_pixels_masked.mean(axis=0) init_std = np_init_rgb_pixels_masked.std(axis=0) - gen_means = np_gen_result_masked.mean(axis=0) - gen_std = np_gen_result_masked.std(axis=0) + gen_means = np_image_masked.mean(axis=0) + gen_std = np_image_masked.std(axis=0) # Color correct - np_matched_result = np_gen_result.copy() + np_matched_result = np_image.copy() np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8) matched_result = Image.fromarray(np_matched_result, mode='RGB') @@ -157,5 +295,16 @@ class Inpaint(Img2Img): blurred_init_mask = pil_init_mask # Paste original on color-corrected generation (using blurred mask) - matched_result.paste(pil_image, (0,0), mask = blurred_init_mask) + matched_result.paste(base_image, (0,0), mask = blurred_init_mask) return matched_result + + + def sample_to_image(self, samples)->Image.Image: + gen_result = super().sample_to_image(samples).convert('RGB') + + if self.pil_image is None or self.pil_mask is None: + return gen_result + + corrected_result = self.color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius) + + return corrected_result From d05373d35a03c58b8370d8ec3d8752b99bbde24c Mon Sep 17 00:00:00 2001 From: Kyle Schouviller Date: Wed, 26 Oct 2022 12:09:38 -0700 Subject: [PATCH 6/6] Force RGB for img2img --- ldm/invoke/generator/img2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py index 613f1aca31..b9a721fcbb 100644 --- a/ldm/invoke/generator/img2img.py +++ b/ldm/invoke/generator/img2img.py @@ -29,7 +29,7 @@ class Img2Img(Generator): ) if isinstance(init_image, PIL.Image.Image): - init_image = self._image_to_tensor(init_image) + init_image = self._image_to_tensor(init_image.convert('RGB')) scope = choose_autocast(self.precision) with scope(self.model.device.type):