mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Improve inpainting by color-correcting result and pasting init image over result using mask
This commit is contained in:
parent
ce6d618e3b
commit
493eaa7389
@ -271,6 +271,8 @@ class Generate:
|
|||||||
upscale = None,
|
upscale = None,
|
||||||
# this is specific to inpainting and causes more extreme inpainting
|
# this is specific to inpainting and causes more extreme inpainting
|
||||||
inpaint_replace = 0.0,
|
inpaint_replace = 0.0,
|
||||||
|
# This will help match inpainted areas to the original image more smoothly
|
||||||
|
mask_blur_radius: int = 8,
|
||||||
# Set this True to handle KeyboardInterrupt internally
|
# Set this True to handle KeyboardInterrupt internally
|
||||||
catch_interrupts = False,
|
catch_interrupts = False,
|
||||||
hires_fix = False,
|
hires_fix = False,
|
||||||
@ -391,7 +393,7 @@ class Generate:
|
|||||||
log_tokens =self.log_tokenization
|
log_tokens =self.log_tokenization
|
||||||
)
|
)
|
||||||
|
|
||||||
init_image,mask_image = self._make_images(
|
init_image,mask_image,pil_image,pil_mask = self._make_images(
|
||||||
init_img,
|
init_img,
|
||||||
init_mask,
|
init_mask,
|
||||||
width,
|
width,
|
||||||
@ -431,6 +433,8 @@ class Generate:
|
|||||||
height=height,
|
height=height,
|
||||||
init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
|
init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
|
||||||
init_image=init_image, # notice that init_image is different from init_img
|
init_image=init_image, # notice that init_image is different from init_img
|
||||||
|
pil_image=pil_image,
|
||||||
|
pil_mask=pil_mask,
|
||||||
mask_image=mask_image,
|
mask_image=mask_image,
|
||||||
strength=strength,
|
strength=strength,
|
||||||
threshold=threshold,
|
threshold=threshold,
|
||||||
@ -438,6 +442,7 @@ class Generate:
|
|||||||
embiggen=embiggen,
|
embiggen=embiggen,
|
||||||
embiggen_tiles=embiggen_tiles,
|
embiggen_tiles=embiggen_tiles,
|
||||||
inpaint_replace=inpaint_replace,
|
inpaint_replace=inpaint_replace,
|
||||||
|
mask_blur_radius=mask_blur_radius
|
||||||
)
|
)
|
||||||
|
|
||||||
if init_color:
|
if init_color:
|
||||||
@ -621,7 +626,7 @@ class Generate:
|
|||||||
init_image = None
|
init_image = None
|
||||||
init_mask = None
|
init_mask = None
|
||||||
if not img:
|
if not img:
|
||||||
return None, None
|
return None, None, None, None
|
||||||
|
|
||||||
image = self._load_img(img)
|
image = self._load_img(img)
|
||||||
|
|
||||||
@ -647,7 +652,7 @@ class Generate:
|
|||||||
elif text_mask:
|
elif text_mask:
|
||||||
init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
|
init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
|
||||||
|
|
||||||
return init_image, init_mask
|
return init_image, init_mask, image, mask_image
|
||||||
|
|
||||||
def _make_base(self):
|
def _make_base(self):
|
||||||
if not self.generators.get('base'):
|
if not self.generators.get('base'):
|
||||||
@ -895,8 +900,9 @@ class Generate:
|
|||||||
# The mask is expected to have the region to be inpainted
|
# The mask is expected to have the region to be inpainted
|
||||||
# with alpha transparency. It converts it into a black/white
|
# with alpha transparency. It converts it into a black/white
|
||||||
# image with the transparent part black.
|
# image with the transparent part black.
|
||||||
def _image_to_mask(self, mask_image, invert=False) -> Image:
|
def _image_to_mask(self, mask_image: Image.Image, invert=False) -> Image:
|
||||||
if mask_image.mode in ('L','RGB'):
|
# Obtain the mask from the transparency channel
|
||||||
|
if mask_image.mode == 'L':
|
||||||
mask = mask_image
|
mask = mask_image
|
||||||
else:
|
else:
|
||||||
# Obtain the mask from the transparency channel
|
# Obtain the mask from the transparency channel
|
||||||
|
@ -3,7 +3,11 @@ ldm.invoke.generator.inpaint descends from ldm.invoke.generator
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torchvision.transforms as T
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import cv2 as cv
|
||||||
|
from PIL import Image, ImageFilter
|
||||||
|
from skimage.exposure.histogram_matching import match_histograms
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from ldm.invoke.devices import choose_autocast
|
from ldm.invoke.devices import choose_autocast
|
||||||
from ldm.invoke.generator.img2img import Img2Img
|
from ldm.invoke.generator.img2img import Img2Img
|
||||||
@ -18,12 +22,27 @@ class Inpaint(Img2Img):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||||
conditioning,init_image,mask_image,strength,
|
conditioning,init_image,mask_image,strength,
|
||||||
step_callback=None,inpaint_replace=False,**kwargs):
|
pil_image: Image.Image, pil_mask: Image.Image,
|
||||||
|
mask_blur_radius: int = 8,
|
||||||
|
step_callback=None,inpaint_replace=False, **kwargs):
|
||||||
"""
|
"""
|
||||||
Returns a function returning an image derived from the prompt and
|
Returns a function returning an image derived from the prompt and
|
||||||
the initial image + mask. Return value depends on the seed at
|
the initial image + mask. Return value depends on the seed at
|
||||||
the time you call it. kwargs are 'init_latent' and 'strength'
|
the time you call it. kwargs are 'init_latent' and 'strength'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Get the alpha channel of the mask
|
||||||
|
pil_init_mask = pil_mask.getchannel('A')
|
||||||
|
pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist
|
||||||
|
|
||||||
|
# Build an image with only visible pixels from source to use as reference for color-matching.
|
||||||
|
# Note that this doesn't use the mask, which would exclude some source image pixels from the
|
||||||
|
# histogram and cause slight color changes.
|
||||||
|
init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3)
|
||||||
|
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
|
||||||
|
init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
|
||||||
|
init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram
|
||||||
|
|
||||||
# klms samplers not supported yet, so ignore previous sampler
|
# klms samplers not supported yet, so ignore previous sampler
|
||||||
if isinstance(sampler,KSampler):
|
if isinstance(sampler,KSampler):
|
||||||
print(
|
print(
|
||||||
@ -78,9 +97,29 @@ class Inpaint(Img2Img):
|
|||||||
init_latent = self.init_latent
|
init_latent = self.init_latent
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.sample_to_image(samples)
|
# Get PIL result
|
||||||
|
gen_result = self.sample_to_image(samples).convert('RGB')
|
||||||
|
|
||||||
|
# Get numpy version
|
||||||
|
np_gen_result = np.asarray(gen_result, dtype=np.uint8)
|
||||||
|
|
||||||
|
# Color correct
|
||||||
|
np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1)
|
||||||
|
matched_result = Image.fromarray(np_matched_result, mode='RGB')
|
||||||
|
|
||||||
|
|
||||||
|
# Blur the mask out (into init image) by specified amount
|
||||||
|
if mask_blur_radius > 0:
|
||||||
|
nm = np.asarray(pil_init_mask, dtype=np.uint8)
|
||||||
|
nmd = cv.dilate(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
|
||||||
|
pmd = Image.fromarray(nmd, mode='L')
|
||||||
|
blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
|
||||||
|
else:
|
||||||
|
blurred_init_mask = pil_init_mask
|
||||||
|
|
||||||
|
# Paste original on color-corrected generation (using blurred mask)
|
||||||
|
matched_result.paste(pil_image, (0,0), mask = blurred_init_mask)
|
||||||
|
|
||||||
|
return matched_result
|
||||||
|
|
||||||
return make_image
|
return make_image
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user