InvokeAI/ldm/invoke/generator/inpaint.py
2022-11-30 14:54:19 -08:00

361 lines
13 KiB
Python

'''
ldm.invoke.generator.inpaint descends from ldm.invoke.generator
'''
import math
import PIL
import cv2 as cv
import numpy as np
import torch
from PIL import Image, ImageFilter, ImageOps, ImageChops
from einops import repeat
from ldm.invoke.devices import choose_autocast
from ldm.invoke.generator.base import downsampling
from ldm.invoke.generator.img2img import Img2Img
from ldm.invoke.globals import Globals
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.ksampler import KSampler
from ldm.util import debug_image
# Registry of the infill method names usable in this process, in order of
# preference. 'patchmatch' is registered only when the optional module both
# imports and reports itself available; 'tile' is always registered last.
infill_methods: list[str] = []

if Globals.try_patchmatch:
    try:
        from patchmatch import patch_match
    except ImportError:
        # The package is optional: a missing install must not prevent this
        # module from loading, it only removes 'patchmatch' from the registry.
        patch_match = None
    if patch_match is not None and patch_match.patchmatch_available:
        print('>> Patchmatch initialized')
        infill_methods.append('patchmatch')
    else:
        print('>> Patchmatch not loaded, please see https://github.com/invoke-ai/InvokeAI/blob/patchmatch-install-docs/docs/installation/INSTALL_PATCHMATCH.md')
else:
    # Keep the name bound so later references fail predictably instead of
    # raising NameError.
    patch_match = None
    print('>> Patchmatch loading disabled')

infill_methods.append('tile')
class Inpaint(Img2Img):
    """Img2Img generator that repaints only the masked part of an image.

    Transparent pixels of the RGBA init image are first filled in (with
    patchmatch or with randomly chosen opaque tiles), the filled image is
    encoded to latent space and the sampler is run with a latent mask.
    An optional second "seam" pass repaints the border of the mask.
    """

    def __init__(self, model, precision):
        self.init_latent = None
        self.pil_image = None
        self.pil_mask = None
        self.mask_blur_radius = 0
        # These are normally assigned by get_make_image(); give them defaults
        # so sample_to_image() does not raise AttributeError if it runs first.
        self.enable_image_debugging = False
        self.inpaint_width = None
        self.inpaint_height = None
        super().__init__(model, precision)

    # Outpaint support code
    def get_tile_images(self, image: np.ndarray, width=8, height=8):
        """Return a read-only tiled view of a (rows, cols, depth) array.

        The result has shape (rows // height, cols // width, height, width,
        depth). Returns None when the image size is not an exact multiple of
        the tile size.
        """
        # The view below is built from np.ravel(image) combined with
        # image.strides, which only describes the same memory when the array
        # is C-contiguous.
        image = np.ascontiguousarray(image)
        _nrows, _ncols, depth = image.shape
        _strides = image.strides

        nrows, _m = divmod(_nrows, height)
        ncols, _n = divmod(_ncols, width)
        if _m != 0 or _n != 0:
            return None

        return np.lib.stride_tricks.as_strided(
            np.ravel(image),
            shape=(nrows, ncols, height, width, depth),
            strides=(height * _strides[0], width * _strides[1], *_strides),
            writeable=False
        )

    def infill_patchmatch(self, im: Image.Image) -> Image.Image:
        """Fill the transparent area of an RGBA image using patchmatch.

        Returns `im` unchanged when it is not RGBA or when patchmatch is not
        registered as an infill method; otherwise returns a new RGB image.
        """
        if im.mode != 'RGBA':
            return im

        # Skip patchmatch if patchmatch isn't available. 'patchmatch' is only
        # registered after a successful import, so this check never touches
        # `patch_match` when it was not loaded.
        if 'patchmatch' not in infill_methods:
            return im

        # Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
        # The inverted alpha channel is passed as the mask of pixels to fill.
        im_patched_np = patch_match.inpaint(im.convert('RGB'), ImageOps.invert(im.split()[-1]), patch_size = 3)
        im_patched = Image.fromarray(im_patched_np, mode = 'RGB')
        return im_patched

    def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: int = None) -> Image.Image:
        """Replace tiles containing transparent pixels with random opaque tiles.

        The image is cut into tile_size x tile_size tiles. Every tile that
        contains at least one pixel with alpha == 0 is overwritten by a copy
        of a randomly chosen tile whose pixels all have alpha > 0.

        `seed` seeds the numpy random generator used to pick tiles. Returns
        `im` unchanged when it is not RGBA or when there is no fully opaque
        tile to copy from; otherwise returns a new RGBA image of the same size.
        """
        # Only fill if there's an alpha layer
        if im.mode != 'RGBA':
            return im

        a = np.asarray(im, dtype=np.uint8)
        height, width = a.shape[:2]

        # get_tile_images() only accepts sizes that are exact multiples of the
        # tile size. Pad the bottom/right edges with transparent pixels (the
        # affected tiles are then treated as missing and refilled) and crop
        # the padding away again at the end.
        pad_rows = -height % tile_size
        pad_cols = -width % tile_size
        if pad_rows or pad_cols:
            a = np.pad(a, ((0, pad_rows), (0, pad_cols), (0, 0)), mode='constant')

        # Get the image as tiles of a specified size (writable copy)
        tiles = self.get_tile_images(a, tile_size, tile_size).copy()
        tshape = tiles.shape
        n_tiles = math.prod(tshape[0:2])

        # A tile is valid when every one of its alpha values is non-zero
        tiles_mask = (tiles[:, :, :, :, 3] > 0).reshape((n_tiles, -1)).all(axis = 1)

        # Get RGBA tiles in a single flat list and select the valid ones
        tiles_all = tiles.reshape((n_tiles, *tshape[2:]))
        filtered_tiles = tiles_all[tiles_mask]

        if len(filtered_tiles) == 0:
            return im

        # Find all invalid tiles and replace with a random valid tile
        invalid = np.logical_not(tiles_mask)
        replace_count = invalid.sum()
        rng = np.random.default_rng(seed = seed)
        tiles_all[invalid] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count),:,:,:]

        # Convert back to an image: (rows, cols, h, w, d) -> (rows*h, cols*w, d)
        tiles_all = tiles_all.reshape(tshape)
        tiles_all = tiles_all.swapaxes(1,2)
        st = tiles_all.reshape((math.prod(tiles_all.shape[0:2]), math.prod(tiles_all.shape[2:4]), tiles_all.shape[4]))

        # Drop any padding added above
        st = np.ascontiguousarray(st[:height, :width])
        si = Image.fromarray(st, mode='RGBA')

        return si

    def mask_edge(self, mask: Image.Image, edge_size: int, edge_blur: int) -> Image.Image:
        """Build a mask that selects the border region of `mask`.

        Partially transparent pixels (values 1..254) and Canny edges are
        marked, dilated `edge_size // 2` times with a 3x3 kernel, optionally
        box-blurred with radius `edge_blur`, and the result is inverted
        before being returned.
        """
        npimg = np.asarray(mask, dtype=np.uint8)

        # Detect any partially transparent regions
        # (evaluates to 0 for values 0 and 255, and to 255 for anything between)
        npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))

        # Detect hard edges
        npedge = cv.Canny(npimg, threshold1=100, threshold2=200)

        # Combine. A plain uint8 `+` wraps 255 + 255 around to 254; the
        # element-wise maximum keeps pixels marked by both detectors at 255.
        npmask = np.maximum(npgradient, npedge)

        # Expand
        npmask = cv.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2))

        new_mask = Image.fromarray(npmask)

        if edge_blur > 0:
            new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur))

        return ImageOps.invert(new_mask)

    def seam_paint(self,
        im: Image.Image,
        seam_size: int,
        seam_blur: int,
        prompt,sampler,steps,cfg_scale,ddim_eta,
        conditioning,strength,
        noise,
        step_callback
    ) -> Image.Image:
        """Run a second inpainting pass over the seam of the current mask.

        The seam mask is derived with mask_edge() from the alpha channel of
        self.pil_image. `im` (the result of the first pass) is used as the
        init image.

        NOTE: this calls get_make_image(), which overwrites the generator
        state on `self` (pil_image, pil_mask, init_latent, ...); the caller
        is responsible for restoring it. `noise` is accepted but not used:
        fresh noise is drawn with self.get_noise().
        """
        hard_mask = self.pil_image.split()[-1].copy()
        mask = self.mask_edge(hard_mask, seam_size, seam_blur)

        make_image = self.get_make_image(
            prompt,
            sampler,
            steps,
            cfg_scale,
            ddim_eta,
            conditioning,
            init_image = im.copy().convert('RGBA'),
            mask_image = mask.convert('RGB'), # Code currently requires an RGB mask
            strength = strength,
            mask_blur_radius = 0,
            seam_size = 0,  # 0 disables the seam pass, so this does not recurse
            step_callback = step_callback,
            inpaint_width = im.width,
            inpaint_height = im.height
        )

        seam_noise = self.get_noise(im.width, im.height)

        result = make_image(seam_noise)

        return result

    @torch.no_grad()
    def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
                       conditioning,init_image,mask_image,strength,
                       mask_blur_radius: int = 8,
                       # Seam settings - when 0, doesn't fill seam
                       seam_size: int = 0,
                       seam_blur: int = 0,
                       seam_strength: float = 0.7,
                       seam_steps: int = 10,
                       tile_size: int = 32,
                       step_callback=None,
                       inpaint_replace=False, enable_image_debugging=False,
                       infill_method = infill_methods[0], # The infill method to use
                       inpaint_width=None,
                       inpaint_height=None,
                       **kwargs):
        """
        Returns a function returning an image derived from the prompt and
        the initial image + mask. Return value depends on the noise passed
        to that function at the time you call it.

        init_image and mask_image may be PIL images or already-converted
        values. A PIL init_image has its alpha channel read (so it is
        expected to be RGBA), is infilled, optionally resized to
        inpaint_width x inpaint_height and converted to a tensor. A PIL
        mask_image is multiplied with that alpha channel, optionally resized
        and then downsampled by `downsampling`.

        When inpaint_replace > 0, kwargs must contain 'width' and 'height';
        they are passed to self.get_noise().

        Side effects: stores pil_image, pil_mask, mask_blur_radius,
        init_latent, inpaint_width, inpaint_height and enable_image_debugging
        on self.
        """

        self.enable_image_debugging = enable_image_debugging
        self.inpaint_width = inpaint_width
        self.inpaint_height = inpaint_height

        if isinstance(init_image, PIL.Image.Image):
            self.pil_image = init_image.copy()

            # Do infill. Consult the registry rather than `patch_match`
            # itself, which is not usable when patchmatch was never loaded.
            if infill_method == 'patchmatch' and 'patchmatch' in infill_methods:
                init_filled = self.infill_patchmatch(self.pil_image.copy())
            else: # if infill_method == 'tile': # Only two methods right now, so always use 'tile' if not patchmatch
                init_filled = self.tile_fill_missing(
                    self.pil_image.copy(),
                    seed = self.seed,
                    tile_size = tile_size
                )
            # Paste the original pixels back over the fill, using the
            # original alpha channel as the paste mask.
            init_filled.paste(init_image, (0,0), init_image.split()[-1])

            # Resize if requested for inpainting
            if inpaint_width and inpaint_height:
                init_filled = init_filled.resize((inpaint_width, inpaint_height))

            debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging)

            # Create init tensor
            init_image = self._image_to_tensor(init_filled.convert('RGB'))

        if isinstance(mask_image, PIL.Image.Image):
            self.pil_mask = mask_image.copy()
            debug_image(mask_image, "mask_image BEFORE multiply with pil_image", debug_status=self.enable_image_debugging)

            mask_image = ImageChops.multiply(mask_image, self.pil_image.split()[-1].convert('RGB'))
            self.pil_mask = mask_image

            # Resize if requested for inpainting
            if inpaint_width and inpaint_height:
                mask_image = mask_image.resize((inpaint_width, inpaint_height))

            debug_image(mask_image, "mask_image AFTER multiply with pil_image", debug_status=self.enable_image_debugging)
            mask_image = mask_image.resize(
                (
                    mask_image.width // downsampling,
                    mask_image.height // downsampling
                ),
                resample=Image.Resampling.NEAREST
            )
            mask_image = self._image_to_tensor(mask_image,normalize=False)

        self.mask_blur_radius = mask_blur_radius

        # klms samplers not supported yet, so ignore previous sampler
        if isinstance(sampler,KSampler):
            print(
                ">> Using recommended DDIM sampler for inpainting."
            )
            sampler = DDIMSampler(self.model, device=self.model.device)

        sampler.make_schedule(
            ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
        )

        # Take the first channel of the first mask and repeat it over 4 channels
        mask_image = mask_image[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
        mask_image = repeat(mask_image, '1 ... -> b ...', b=1)

        scope = choose_autocast(self.precision)
        with scope(self.model.device.type):
            self.init_latent = self.model.get_first_stage_encoding(
                self.model.encode_first_stage(init_image)
            ) # move to latent space

        t_enc = int(strength * steps)
        # todo: support cross-attention control
        uc, c, _ = conditioning

        print(f">> target t_enc is {t_enc} steps")

        @torch.no_grad()
        def make_image(x_T):
            # encode (scaled latent)
            z_enc = sampler.stochastic_encode(
                self.init_latent,
                torch.tensor([t_enc]).to(self.model.device),
                noise=x_T
            )

            # to replace masked area with latent noise, weighted by inpaint_replace strength
            if inpaint_replace > 0.0:
                print(f'>> inpaint will replace what was under the mask with a strength of {inpaint_replace}')
                l_noise = self.get_noise(kwargs['width'],kwargs['height'])
                inverted_mask = 1.0-mask_image  # there will be 1s where the mask is
                masked_region = (1.0-inpaint_replace) * inverted_mask * z_enc + inpaint_replace * inverted_mask * l_noise
                z_enc = z_enc * mask_image + masked_region

            # decode it
            samples = sampler.decode(
                z_enc,
                c,
                t_enc,
                img_callback = step_callback,
                unconditional_guidance_scale = cfg_scale,
                unconditional_conditioning = uc,
                mask = mask_image,
                init_latent = self.init_latent
            )

            result = self.sample_to_image(samples)

            # Seam paint if this is our first pass (seam_size set to 0 during seam painting)
            if seam_size > 0:
                # Remember the current inputs: seam_paint() overwrites the
                # state stored on self.
                old_image = self.pil_image or init_image
                old_mask = self.pil_mask or mask_image

                result = self.seam_paint(
                    result,
                    seam_size,
                    seam_blur,
                    prompt,
                    sampler,
                    seam_steps,
                    cfg_scale,
                    ddim_eta,
                    conditioning,
                    seam_strength,
                    x_T,
                    step_callback)

                # Restore original settings
                self.get_make_image(prompt,sampler,steps,cfg_scale,ddim_eta,
                    conditioning,
                    old_image,
                    old_mask,
                    strength,
                    mask_blur_radius, seam_size, seam_blur, seam_strength,
                    seam_steps, tile_size, step_callback,
                    inpaint_replace, enable_image_debugging,
                    inpaint_width = inpaint_width,
                    inpaint_height = inpaint_height,
                    infill_method = infill_method,
                    **kwargs)

            return result

        return make_image

    def sample_to_image(self, samples)->Image.Image:
        """Convert latent samples to an RGB image and merge it with the original.

        When both a source image and a mask are stored on self, the result
        is passed through repaste_and_color_correct(); otherwise the plain
        generated image is returned.
        """
        gen_result = super().sample_to_image(samples).convert('RGB')
        debug_image(gen_result, "gen_result", debug_status=self.enable_image_debugging)

        # Resize if necessary. The target size comes from the source image,
        # so this is only possible when one has been stored.
        if self.pil_image is not None and self.inpaint_width and self.inpaint_height:
            gen_result = gen_result.resize(self.pil_image.size)

        if self.pil_image is None or self.pil_mask is None:
            return gen_result

        corrected_result = super().repaste_and_color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius)
        debug_image(corrected_result, "corrected_result", debug_status=self.enable_image_debugging)

        return corrected_result