diff --git a/ldm/invoke/generator/inpaint.py b/ldm/invoke/generator/inpaint.py index 8fbcf249aa..a6843599a9 100644 --- a/ldm/invoke/generator/inpaint.py +++ b/ldm/invoke/generator/inpaint.py @@ -125,18 +125,26 @@ class Inpaint(Img2Img): pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist # Build an image with only visible pixels from source to use as reference for color-matching. - # Note that this doesn't use the mask, which would exclude some source image pixels from the - # histogram and cause slight color changes. - init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3) - init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height) - init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0] - init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram + init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8) + init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8) + init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8) - # Get numpy version + # Get numpy version of result np_gen_result = np.asarray(gen_result, dtype=np.uint8) + # Mask and calculate mean and standard deviation + mask_pixels = init_a_pixels * init_mask_pixels > 0 + np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :] + np_gen_result_masked = np_gen_result[mask_pixels, :] + + init_means = np_init_rgb_pixels_masked.mean(axis=0) + init_std = np_init_rgb_pixels_masked.std(axis=0) + gen_means = np_gen_result_masked.mean(axis=0) + gen_std = np_gen_result_masked.std(axis=0) + # Color correct - np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1) + np_matched_result = np_gen_result.copy() + np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8) matched_result = Image.fromarray(np_matched_result, mode='RGB') # Blur the mask out (into init image) by specified amount @@ -151,4 +159,3 @@ class Inpaint(Img2Img): # Paste original on color-corrected generation (using blurred mask) matched_result.paste(pil_image, (0,0), mask = blurred_init_mask) return matched_result -