Improve inpainting by color-correcting result and pasting init image over result using mask

This commit is contained in:
Kyle Schouviller
2022-10-22 14:56:33 -07:00
committed by Lincoln Stein
parent ce6d618e3b
commit 493eaa7389
2 changed files with 55 additions and 10 deletions

View File

@ -271,6 +271,8 @@ class Generate:
upscale = None,
# this is specific to inpainting and causes more extreme inpainting
inpaint_replace = 0.0,
# This will help match inpainted areas to the original image more smoothly
mask_blur_radius: int = 8,
# Set this True to handle KeyboardInterrupt internally
catch_interrupts = False,
hires_fix = False,
@ -391,7 +393,7 @@ class Generate:
log_tokens =self.log_tokenization
)
init_image,mask_image = self._make_images(
init_image,mask_image,pil_image,pil_mask = self._make_images(
init_img,
init_mask,
width,
@ -431,6 +433,8 @@ class Generate:
height=height,
init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
init_image=init_image, # notice that init_image is different from init_img
pil_image=pil_image,
pil_mask=pil_mask,
mask_image=mask_image,
strength=strength,
threshold=threshold,
@ -438,6 +442,7 @@ class Generate:
embiggen=embiggen,
embiggen_tiles=embiggen_tiles,
inpaint_replace=inpaint_replace,
mask_blur_radius=mask_blur_radius
)
if init_color:
@ -621,7 +626,7 @@ class Generate:
init_image = None
init_mask = None
if not img:
return None, None
return None, None, None, None
image = self._load_img(img)
@ -647,7 +652,7 @@ class Generate:
elif text_mask:
init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
return init_image, init_mask
return init_image, init_mask, image, mask_image
def _make_base(self):
if not self.generators.get('base'):
@ -895,8 +900,9 @@ class Generate:
# The mask is expected to have the region to be inpainted
# with alpha transparency. It converts it into a black/white
# image with the transparent part black.
def _image_to_mask(self, mask_image, invert=False) -> Image:
if mask_image.mode in ('L','RGB'):
def _image_to_mask(self, mask_image: Image.Image, invert=False) -> Image:
# Obtain the mask from the transparency channel
if mask_image.mode == 'L':
mask = mask_image
else:
# Obtain the mask from the transparency channel