'''
ldm.invoke.generator.inpaint descends from ldm.invoke.generator
'''

import math
import PIL
import cv2 as cv
import numpy as np
import torch
from PIL import Image, ImageFilter, ImageOps, ImageChops
from einops import repeat

from ldm.invoke.devices import choose_autocast
from ldm.invoke.generator.base import downsampling
from ldm.invoke.generator.img2img import Img2Img
from ldm.invoke.globals import Globals
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.ksampler import KSampler
from ldm.util import debug_image

# Names of the infill strategies usable in this process; 'tile' is always last.
infill_methods: list[str] = list()

# Bind the name unconditionally so availability checks further down can never
# raise NameError when patchmatch loading is disabled.
patch_match = None

if Globals.try_patchmatch:
    from patchmatch import patch_match
    if patch_match.patchmatch_available:
        print('>> Patchmatch initialized')
        infill_methods.append('patchmatch')
    else:
        print('>> Patchmatch not loaded, please see https://github.com/invoke-ai/InvokeAI/blob/patchmatch-install-docs/docs/installation/INSTALL_PATCHMATCH.md')
else:
    print('>> Patchmatch loading disabled')
infill_methods.append('tile')


class Inpaint(Img2Img):
    """Img2Img generator that regenerates only the masked part of an image.

    Also contains the outpainting helpers (infill of transparent regions and
    seam painting along the mask edge).
    """

    def __init__(self, model, precision):
        self.init_latent = None
        self.pil_image = None
        self.pil_mask = None
        self.mask_blur_radius = 0
        # Defaults for attributes that sample_to_image() reads; they are
        # overwritten on every get_make_image() call.
        self.enable_image_debugging = False
        self.inpaint_width = None
        self.inpaint_height = None
        super().__init__(model, precision)

    # Outpaint support code
    def get_tile_images(self, image: np.ndarray, width=8, height=8):
        """Return a read-only strided view of `image` split into tiles.

        `image` must be a 3-d array (rows, cols, depth). The result has shape
        (nrows, ncols, height, width, depth). Returns None when the image
        dimensions are not exact multiples of the tile dimensions.
        """
        _nrows, _ncols, depth = image.shape
        _strides = image.strides

        nrows, _m = divmod(_nrows, height)
        ncols, _n = divmod(_ncols, width)
        if _m != 0 or _n != 0:
            return None

        return np.lib.stride_tricks.as_strided(
            np.ravel(image),
            shape=(nrows, ncols, height, width, depth),
            strides=(height * _strides[0], width * _strides[1], *_strides),
            writeable=False
        )

    def infill_patchmatch(self, im: Image.Image) -> Image.Image:
        """Fill the transparent areas of an RGBA image using patchmatch.

        Returns `im` unchanged when it is not RGBA or when patchmatch is not
        loaded/available; otherwise returns a new RGB image.
        """
        if im.mode != 'RGBA':
            return im

        # Skip patchmatch if patchmatch isn't available
        if patch_match is None or not patch_match.patchmatch_available:
            return im

        # Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
        im_patched_np = patch_match.inpaint(im.convert('RGB'), ImageOps.invert(im.split()[-1]), patch_size = 3)
        im_patched = Image.fromarray(im_patched_np, mode = 'RGB')
        return im_patched

    def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: int = None) -> Image.Image:
        """Replace tiles containing transparent pixels with random opaque tiles.

        The image is cut into `tile_size` x `tile_size` tiles. Any tile with at
        least one pixel of alpha 0 is overwritten by a randomly chosen tile in
        which every pixel has alpha > 0. `seed` seeds the random choice.

        Returns `im` unchanged when it is not RGBA or when no fully valid tile
        exists; otherwise returns a new RGBA image of the same size.
        """
        # Only fill if there's an alpha layer
        if im.mode != 'RGBA':
            return im

        a = np.asarray(im, dtype=np.uint8)
        height, width = a.shape[:2]

        # get_tile_images() only works on exact multiples of the tile size, so
        # pad the bottom/right edges with fully transparent pixels. Tiles that
        # touch the padding are treated as missing, and the padding is cropped
        # off again before returning.
        pad_rows = -height % tile_size
        pad_cols = -width % tile_size
        if pad_rows or pad_cols:
            a = np.pad(a, ((0, pad_rows), (0, pad_cols), (0, 0)), mode='constant', constant_values=0)

        # Get the image as tiles of a specified size (copy: the view is read-only)
        tiles = self.get_tile_images(a, tile_size, tile_size).copy()

        # Get the mask as tiles
        alpha_tiles = tiles[:, :, :, :, 3]

        # Find any mask tiles with any fully transparent pixels (we will be replacing these later)
        tile_count = math.prod(alpha_tiles.shape[0:2])
        pixels_per_tile = math.prod(alpha_tiles.shape[2:])
        tiles_mask = (alpha_tiles > 0).reshape((tile_count, pixels_per_tile)).all(axis=1)

        # Get RGB tiles in single array and filter by the mask
        tshape = tiles.shape
        tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), *tiles.shape[2:]))
        filtered_tiles = tiles_all[tiles_mask]

        if filtered_tiles.shape[0] == 0:
            return im

        # Find all invalid tiles and replace with a random valid tile
        invalid_tiles = ~tiles_mask
        replace_count = int(np.count_nonzero(invalid_tiles))
        rng = np.random.default_rng(seed=seed)
        tiles_all[invalid_tiles] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]

        # Convert back to an image
        tiles_all = tiles_all.reshape(tshape)
        tiles_all = tiles_all.swapaxes(1, 2)
        st = tiles_all.reshape((math.prod(tiles_all.shape[0:2]), math.prod(tiles_all.shape[2:4]), tiles_all.shape[4]))
        if pad_rows or pad_cols:
            # Drop the padding added above
            st = np.ascontiguousarray(st[:height, :width])
        si = Image.fromarray(st, mode='RGBA')

        return si

    def mask_edge(self, mask: Image.Image, edge_size: int, edge_blur: int) -> Image.Image:
        """Build a mask that is dark along the edges of `mask`.

        Partially transparent pixels (values other than 0 and 255) and Canny
        edges are marked, dilated by `edge_size / 2` iterations of a 3x3
        kernel, optionally box-blurred by `edge_blur`, and finally inverted.
        """
        npimg = np.asarray(mask, dtype=np.uint8)

        # Detect any partially transparent regions
        npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))

        # Detect hard edges
        npedge = cv.Canny(npimg, threshold1=100, threshold2=200)

        # Combine; both arrays are uint8, so take the maximum rather than the
        # sum, which would wrap around (255 + 255 -> 254)
        npmask = np.maximum(npgradient, npedge)

        # Expand
        npmask = cv.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2))

        new_mask = Image.fromarray(npmask)

        if edge_blur > 0:
            new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur))

        return ImageOps.invert(new_mask)

    def seam_paint(self,
        im: Image.Image,
        seam_size: int,
        seam_blur: int,
        prompt,sampler,steps,cfg_scale,ddim_eta,
        conditioning,strength,
        noise,
        step_callback
    ) -> Image.Image:
        """Run a second inpainting pass restricted to the seam of the mask.

        The seam mask is derived from the alpha channel of `self.pil_image`
        via mask_edge(). `noise` is accepted for interface compatibility but
        is not used: fresh noise sized to `im` is drawn with self.get_noise().

        Note that this calls get_make_image(), which overwrites
        self.pil_image / self.pil_mask and related state.
        """
        hard_mask = self.pil_image.split()[-1].copy()
        mask = self.mask_edge(hard_mask, seam_size, seam_blur)

        make_image = self.get_make_image(
            prompt,
            sampler,
            steps,
            cfg_scale,
            ddim_eta,
            conditioning,
            init_image = im.copy().convert('RGBA'),
            mask_image = mask.convert('RGB'), # Code currently requires an RGB mask
            strength = strength,
            mask_blur_radius = 0,
            seam_size = 0,
            step_callback = step_callback,
            inpaint_width = im.width,
            inpaint_height = im.height
        )

        seam_noise = self.get_noise(im.width, im.height)

        result = make_image(seam_noise)

        return result

    @torch.no_grad()
    def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
                       conditioning,init_image,mask_image,strength,
                       mask_blur_radius: int = 8,
                       # Seam settings - when 0, doesn't fill seam
                       seam_size: int = 0,
                       seam_blur: int = 0,
                       seam_strength: float = 0.7,
                       seam_steps: int = 10,
                       tile_size: int = 32,
                       step_callback=None,
                       inpaint_replace=False, enable_image_debugging=False,
                       infill_method = infill_methods[0], # The infill method to use
                       inpaint_width=None,
                       inpaint_height=None,
                       **kwargs):
        """
        Returns a function returning an image derived from the prompt and
        the initial image + mask.  Return value depends on the seed at
        the time you call it.  kwargs are 'init_latent' and 'strength'

        Side effects: stores the PIL image/mask, blur radius, debugging flag,
        requested inpaint size and the encoded init latent on `self`.

        When `inpaint_replace` is greater than 0, kwargs must also contain
        'width' and 'height' (used to draw the replacement latent noise).
        """

        self.enable_image_debugging = enable_image_debugging

        self.inpaint_width = inpaint_width
        self.inpaint_height = inpaint_height

        if isinstance(init_image, PIL.Image.Image):
            self.pil_image = init_image.copy()

            # Do infill
            use_patchmatch = (
                infill_method == 'patchmatch'
                and patch_match is not None
                and patch_match.patchmatch_available
            )
            if use_patchmatch:
                init_filled = self.infill_patchmatch(self.pil_image.copy())
            else: # if infill_method == 'tile': # Only two methods right now, so always use 'tile' if not patchmatch
                init_filled = self.tile_fill_missing(
                    self.pil_image.copy(),
                    seed = self.seed,
                    tile_size = tile_size
                )
            # Put the original pixels back on top of the infill, using the
            # original's last band as the paste mask
            init_filled.paste(init_image, (0,0), init_image.split()[-1])

            # Resize if requested for inpainting
            if inpaint_width and inpaint_height:
                init_filled = init_filled.resize((inpaint_width, inpaint_height))

            debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging)

            # Create init tensor
            init_image = self._image_to_tensor(init_filled.convert('RGB'))

        if isinstance(mask_image, PIL.Image.Image):
            self.pil_mask = mask_image.copy()
            debug_image(mask_image, "mask_image BEFORE multiply with pil_image", debug_status=self.enable_image_debugging)

            # NOTE(review): relies on self.pil_image being set (a PIL init
            # image on this or an earlier call) - confirm against callers
            mask_image = ImageChops.multiply(mask_image, self.pil_image.split()[-1].convert('RGB'))
            self.pil_mask = mask_image

            # Resize if requested for inpainting
            if inpaint_width and inpaint_height:
                mask_image = mask_image.resize((inpaint_width, inpaint_height))

            debug_image(mask_image, "mask_image AFTER multiply with pil_image", debug_status=self.enable_image_debugging)
            mask_image = mask_image.resize(
                (
                    mask_image.width // downsampling,
                    mask_image.height // downsampling
                ),
                resample=Image.Resampling.NEAREST
            )
            mask_image = self._image_to_tensor(mask_image,normalize=False)

        self.mask_blur_radius = mask_blur_radius

        # klms samplers not supported yet, so ignore previous sampler
        if isinstance(sampler,KSampler):
            print(
                ">> Using recommended DDIM sampler for inpainting."
            )
            sampler = DDIMSampler(self.model, device=self.model.device)

        sampler.make_schedule(
            ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
        )

        # Keep a single channel of the mask and replicate it to 4 channels
        mask_image = mask_image[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
        mask_image = repeat(mask_image, '1 ... -> b ...', b=1)

        scope = choose_autocast(self.precision)
        with scope(self.model.device.type):
            self.init_latent = self.model.get_first_stage_encoding(
                self.model.encode_first_stage(init_image)
            ) # move to latent space

        t_enc = int(strength * steps)
        # todo: support cross-attention control
        uc, c, _ = conditioning

        print(f">> target t_enc is {t_enc} steps")

        @torch.no_grad()
        def make_image(x_T):
            # encode (scaled latent)
            z_enc = sampler.stochastic_encode(
                self.init_latent,
                torch.tensor([t_enc]).to(self.model.device),
                noise=x_T
            )

            # to replace masked area with latent noise, weighted by inpaint_replace strength
            if inpaint_replace > 0.0:
                print(f'>> inpaint will replace what was under the mask with a strength of {inpaint_replace}')
                l_noise = self.get_noise(kwargs['width'],kwargs['height'])
                inverted_mask = 1.0-mask_image  # there will be 1s where the mask is
                masked_region = (1.0-inpaint_replace) * inverted_mask * z_enc + inpaint_replace * inverted_mask * l_noise
                z_enc = z_enc * mask_image + masked_region

            # decode it
            samples = sampler.decode(
                z_enc,
                c,
                t_enc,
                img_callback                 = step_callback,
                unconditional_guidance_scale = cfg_scale,
                unconditional_conditioning   = uc,
                mask                         = mask_image,
                init_latent                  = self.init_latent
            )

            result = self.sample_to_image(samples)

            # Seam paint if this is our first pass (seam_size set to 0 during seam painting)
            if seam_size > 0:
                # Capture the current inputs first: seam_paint() calls
                # get_make_image() and overwrites the state on self
                old_image = self.pil_image or init_image
                old_mask = self.pil_mask or mask_image

                result = self.seam_paint(
                    result,
                    seam_size,
                    seam_blur,
                    prompt,
                    sampler,
                    seam_steps,
                    cfg_scale,
                    ddim_eta,
                    conditioning,
                    seam_strength,
                    x_T,
                    step_callback)

                # Restore original settings
                self.get_make_image(prompt,sampler,steps,cfg_scale,ddim_eta,
                        conditioning,
                        old_image,
                        old_mask,
                        strength,
                        mask_blur_radius, seam_size, seam_blur, seam_strength,
                        seam_steps, tile_size, step_callback,
                        inpaint_replace, enable_image_debugging,
                        inpaint_width = inpaint_width,
                        inpaint_height = inpaint_height,
                        infill_method = infill_method,
                        **kwargs)

            return result

        return make_image

    def sample_to_image(self, samples)->Image.Image:
        """Convert samples to an RGB image and blend it back into the original.

        Without a stored PIL image and mask the raw generated image is
        returned. Otherwise the result is resized back to the original image
        size (when an inpaint size was requested) and passed through
        repaste_and_color_correct() with the stored image, mask and blur
        radius.
        """
        gen_result = super().sample_to_image(samples).convert('RGB')
        debug_image(gen_result, "gen_result", debug_status=self.enable_image_debugging)

        # Nothing to resize against or repaste onto
        if self.pil_image is None or self.pil_mask is None:
            return gen_result

        # Resize if necessary
        if self.inpaint_width and self.inpaint_height:
            gen_result = gen_result.resize(self.pil_image.size)

        corrected_result = super().repaste_and_color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius)
        debug_image(corrected_result, "corrected_result", debug_status=self.enable_image_debugging)

        return corrected_result