diff --git a/environment-AMD.yml b/environment-AMD.yml index 8b5fb07e55..5a0e46998d 100644 --- a/environment-AMD.yml +++ b/environment-AMD.yml @@ -42,4 +42,5 @@ dependencies: - git+https://github.com/invoke-ai/Real-ESRGAN.git#egg=realesrgan - git+https://github.com/invoke-ai/GFPGAN.git#egg=gfpgan - git+https://github.com/invoke-ai/clipseg.git@relaxed-python-requirement#egg=clipseg + - git+https://github.com/invoke-ai/PyPatchMatch@0.1.1#egg=pypatchmatch - -e . diff --git a/ldm/generate.py b/ldm/generate.py index 12127e69ea..bbc2cc5078 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -263,6 +263,8 @@ class Generate: ), 'call to img2img() must include the init_img argument' return self.prompt2png(prompt, outdir, **kwargs) + from ldm.invoke.generator.inpaint import infill_methods + def prompt2image( self, # these are common @@ -323,8 +325,10 @@ class Generate: seam_strength: float = 0.7, seam_steps: int = 10, tile_size: int = 32, + infill_method = infill_methods[0], # The infill method to use force_outpaint: bool = False, enable_image_debugging = False, + **args, ): # eat up additional cruft """ @@ -505,6 +509,7 @@ class Generate: seam_strength = seam_strength, seam_steps = seam_steps, tile_size = tile_size, + infill_method = infill_method, force_outpaint = force_outpaint, inpaint_height = inpaint_height, inpaint_width = inpaint_width, diff --git a/ldm/invoke/generator/inpaint.py b/ldm/invoke/generator/inpaint.py index e6c8dc6517..1443dedc09 100644 --- a/ldm/invoke/generator/inpaint.py +++ b/ldm/invoke/generator/inpaint.py @@ -17,6 +17,16 @@ from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ksampler import KSampler from ldm.invoke.generator.base import downsampling from ldm.util import debug_image +from patchmatch import patch_match + + +infill_methods: list[str] = list() + +if patch_match.patchmatch_available: + infill_methods.append('patchmatch') + +infill_methods.append('tile') + class Inpaint(Img2Img): def __init__(self, model, precision): @@ -43,18 +53,24 @@ class Inpaint(Img2Img): writeable=False ) + def infill_patchmatch(self, im: Image.Image) -> Image: + if im.mode != 'RGBA': + return im + + # Skip patchmatch if patchmatch isn't available + if not patch_match.patchmatch_available: + return im + + # Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though) + im_patched_np = patch_match.inpaint(im.convert('RGB'), ImageOps.invert(im.split()[-1]), patch_size = 3) + im_patched = Image.fromarray(im_patched_np, mode = 'RGB') + return im_patched + def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: int = None) -> Image: # Only fill if there's an alpha layer if im.mode != 'RGBA': return im - # # HACK PATCH MATCH - # from src.PyPatchMatch import patch_match - # im_patched_np = patch_match.inpaint(im.convert('RGB'), ImageOps.invert(im.split()[-1]), patch_size = 3) - # im_patched = Image.fromarray(im_patched_np, mode = 'RGB') - # return im_patched - # # /HACK - a = np.asarray(im, dtype=np.uint8) tile_size = (tile_size, tile_size) @@ -161,6 +177,7 @@ class Inpaint(Img2Img): tile_size: int = 32, step_callback=None, inpaint_replace=False, enable_image_debugging=False, + infill_method = infill_methods[0], # The infill method to use **kwargs): """ Returns a function returning an image derived from the prompt and @@ -173,13 +190,15 @@ class Inpaint(Img2Img): if isinstance(init_image, PIL.Image.Image): self.pil_image = init_image.copy() - # Fill missing areas of original image - init_filled = self.tile_fill_missing( - self.pil_image.copy(), - seed = self.seed if (self.seed is not None - and self.seed >= 0) else self.new_seed(), - tile_size = tile_size - ) + # Do infill + if infill_method == 'patchmatch' and patch_match.patchmatch_available: + init_filled = self.infill_patchmatch(self.pil_image.copy()) + else: # if infill_method == 'tile': # Only two methods right now, so always use 'tile' if not patchmatch + init_filled = self.tile_fill_missing( + self.pil_image.copy(), + seed = self.seed, + tile_size = tile_size + ) init_filled.paste(init_image, (0,0), init_image.split()[-1]) debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging)