2023-03-03 06:02:00 +00:00
|
|
|
"""
|
|
|
|
invokeai.backend.generator.inpaint descends from .generator
|
|
|
|
"""
|
2023-02-28 05:37:13 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
import math
|
2023-07-03 16:17:45 +00:00
|
|
|
from typing import Tuple, Union, Optional
|
2023-02-28 05:37:13 +00:00
|
|
|
|
|
|
|
import cv2
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
2023-03-03 06:02:00 +00:00
|
|
|
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
2023-02-28 05:37:13 +00:00
|
|
|
|
2023-03-02 18:28:17 +00:00
|
|
|
from ..image_util import PatchMatch, debug_image
|
2023-03-03 06:02:00 +00:00
|
|
|
from ..stable_diffusion.diffusers_pipeline import (
|
|
|
|
ConditioningData,
|
|
|
|
StableDiffusionGeneratorPipeline,
|
|
|
|
image_resized_to_grid_as_tensor,
|
|
|
|
)
|
|
|
|
from .img2img import Img2Img
|
2023-02-28 05:37:13 +00:00
|
|
|
|
|
|
|
|
2023-03-03 06:02:00 +00:00
|
|
|
def infill_methods() -> list[str]:
    """Return the names of the available infill methods, preferred method first.

    "tile" and "solid" are always offered; "patchmatch" is put at the front of
    the list only when ``PatchMatch.patchmatch_available()`` reports it usable.
    """
    available = ["tile", "solid"]
    if PatchMatch.patchmatch_available():
        available = ["patchmatch", *available]
    return available
|
|
|
|
|
2023-03-03 06:02:00 +00:00
|
|
|
|
2023-02-28 05:37:13 +00:00
|
|
|
class Inpaint(Img2Img):
    """Inpainting/outpainting generator built on top of :class:`Img2Img`.

    ``get_make_image`` infills the transparent parts of the init image
    (patchmatch, random tiles or a solid colour), runs the pipeline's
    ``inpaint_from_embeddings`` and post-processes the output by resizing it
    back and repasting/colour-correcting against the original image and mask.
    When ``seam_size > 0`` it also runs a second pass over the seam.
    """

    def __init__(self, model, precision):
        # Per-call state written by get_make_image() and read by
        # postprocess_size_and_mask() / seam_paint().
        self.inpaint_height = 0
        self.inpaint_width = 0
        self.enable_image_debugging = False
        self.init_latent = None
        self.pil_image = None
        self.pil_mask = None
        self.mask_blur_radius = 0
        self.infill_method = None
        super().__init__(model, precision)

    # Outpaint support code
    def get_tile_images(self, image: np.ndarray, width=8, height=8):
        """Return a read-only ``(rows, cols, height, width, depth)`` tiled view of ``image``.

        ``image`` must be a 3-D ``(H, W, depth)`` array. Returns ``None`` when
        ``H`` is not a multiple of ``height`` or ``W`` is not a multiple of ``width``.
        """
        # as_strided() below reinterprets the flat buffer with these strides,
        # which is only valid for a C-contiguous array (np.ravel would silently
        # copy a non-contiguous one into a different layout).
        image = np.ascontiguousarray(image)
        _nrows, _ncols, depth = image.shape
        _strides = image.strides

        nrows, _m = divmod(_nrows, height)
        ncols, _n = divmod(_ncols, width)
        if _m != 0 or _n != 0:
            return None

        return np.lib.stride_tricks.as_strided(
            np.ravel(image),
            shape=(nrows, ncols, height, width, depth),
            strides=(height * _strides[0], width * _strides[1], *_strides),
            writeable=False,
        )

    def infill_patchmatch(self, im: Image.Image) -> Image.Image:
        """Fill the transparent areas of an RGBA image using PatchMatch.

        Returns ``im`` unchanged if it is not RGBA or PatchMatch is unavailable;
        otherwise returns a new RGB image.
        """
        if im.mode != "RGBA":
            return im

        # Skip patchmatch if patchmatch isn't available
        if not PatchMatch.patchmatch_available():
            return im

        # The inverted alpha makes transparent pixels 255 (presumably the region
        # PatchMatch fills). We may want to expose patch_size, but increasing it
        # significantly impacts performance.
        im_patched_np = PatchMatch.inpaint(im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3)
        im_patched = Image.fromarray(im_patched_np, mode="RGB")
        return im_patched

    def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
        """Fill transparent areas of an RGBA image with randomly chosen opaque tiles.

        The image is cut into ``tile_size`` x ``tile_size`` tiles. Every tile that
        contains at least one pixel with alpha 0 is replaced (all four channels) by
        a randomly chosen tile whose pixels all have alpha > 0; ``seed`` seeds the
        random choice. Images whose size is not a multiple of ``tile_size`` are
        padded on the right/bottom for tiling and cropped back afterwards; tiles
        that overlap the padding are never used as replacement sources.

        Returns ``im`` unchanged if it is not RGBA or if no fully opaque tile exists.
        """
        # Only fill if there's an alpha layer
        if im.mode != "RGBA":
            return im

        rgba = np.asarray(im, dtype=np.uint8)
        height, width = rgba.shape[:2]

        # Pad to a whole number of tiles (no-op when the size already divides).
        pad = ((0, -height % tile_size), (0, -width % tile_size))
        padded = np.pad(rgba, (*pad, (0, 0)))  # zero padding => alpha 0 there
        rows, cols = padded.shape[0] // tile_size, padded.shape[1] // tile_size

        def _tiles_all_opaque(alpha: np.ndarray) -> np.ndarray:
            # One bool per tile (row-major): True when every alpha value in it is > 0.
            return (alpha > 0).reshape(rows, tile_size, cols, tile_size).all(axis=(1, 3)).reshape(-1)

        # Source tiles: fully opaque, padding counts as transparent so edge tiles
        # that overlap the padding are never copied into the image.
        is_source = _tiles_all_opaque(padded[:, :, 3])
        # Hole tiles: contain a real transparent pixel; padding counts as opaque so
        # it alone never forces a tile to be replaced.
        is_hole = ~_tiles_all_opaque(np.pad(rgba[:, :, 3], pad, constant_values=255))

        # Get the image as flat list of RGBA tiles
        tiles = self.get_tile_images(padded, tile_size, tile_size).copy()
        tiles_all = tiles.reshape((rows * cols, *tiles.shape[2:]))

        sources = tiles_all[is_source]
        if len(sources) == 0:
            return im

        # Replace every hole tile with a random source tile
        rng = np.random.default_rng(seed=seed)
        tiles_all[is_hole] = sources[rng.choice(sources.shape[0], np.count_nonzero(is_hole))]

        # Reassemble the tiles into an image and drop the padding
        grid = tiles_all.reshape(tiles.shape).swapaxes(1, 2)
        st = grid.reshape((rows * tile_size, cols * tile_size, grid.shape[4]))[:height, :width]
        return Image.fromarray(np.ascontiguousarray(st), mode="RGBA")

    def mask_edge(self, mask: Image.Image, edge_size: int, edge_blur: int) -> Image.Image:
        """Build an inverted mask of the edges/seams found in ``mask``.

        Pixels that are partially transparent (strictly between 0 and 255) and
        Canny-detected hard edges are marked, dilated with a 3x3 kernel for
        ``int(edge_size / 2)`` iterations and optionally box-blurred by
        ``edge_blur``. The result is inverted, so seam pixels end up 0 and the
        rest 255 (with a soft transition when blurred).
        """
        npimg = np.asarray(mask, dtype=np.uint8)

        # Detect any partially transparent regions: 255 where 0 < value < 255, else 0
        npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))

        # Detect hard edges
        npedge = cv2.Canny(npimg, threshold1=100, threshold2=200)

        # Combine. Both maps are 0/255 uint8, so `+` would wrap 255 + 255 to 254;
        # the per-pixel maximum is the intended logical OR.
        npmask = np.maximum(npgradient, npedge)

        # Expand
        npmask = cv2.dilate(npmask, np.ones((3, 3), np.uint8), iterations=int(edge_size / 2))

        new_mask = Image.fromarray(npmask)

        if edge_blur > 0:
            new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur))

        return ImageOps.invert(new_mask)

    def seam_paint(
        self,
        im: Image.Image,
        seam_size: int,
        seam_blur: int,
        seed,
        steps,
        cfg_scale,
        ddim_eta,
        conditioning,
        strength,
        noise,
        infill_method,
        step_callback,
    ) -> Image.Image:
        """Repaint the seam between the original pixels and the generated region.

        Builds an edge mask from the alpha channel of ``self.pil_image`` (see
        ``mask_edge``) and runs a second inpainting pass over ``im`` restricted to
        that seam, with seam painting itself disabled (``seam_size=0``).

        Note: this call overwrites the per-call state on ``self`` (pil_image,
        pil_mask, ...); the caller is responsible for restoring it.

        NOTE(review): ``seed`` and ``noise`` are accepted but unused - the pass
        draws fresh noise from ``self.get_noise`` and is run with ``seed=None``.
        """
        hard_mask = self.pil_image.split()[-1].copy()
        mask = self.mask_edge(hard_mask, seam_size, seam_blur)

        make_image = self.get_make_image(
            steps,
            cfg_scale,
            ddim_eta,
            conditioning,
            init_image=im.copy().convert("RGBA"),
            mask_image=mask,
            strength=strength,
            mask_blur_radius=0,
            seam_size=0,
            step_callback=step_callback,
            inpaint_width=im.width,
            inpaint_height=im.height,
            infill_method=infill_method,
        )

        seam_noise = self.get_noise(im.width, im.height)

        result = make_image(seam_noise, seed=None)

        return result

    @torch.no_grad()
    def get_make_image(
        self,
        steps,
        cfg_scale,
        ddim_eta,
        conditioning,
        init_image: Union[Image.Image, torch.FloatTensor],
        mask_image: Union[Image.Image, torch.FloatTensor],
        strength: float,
        mask_blur_radius: int = 8,
        # Seam settings - when 0, doesn't fill seam
        seam_size: int = 96,
        seam_blur: int = 16,
        seam_strength: float = 0.7,
        seam_steps: int = 30,
        tile_size: int = 32,
        step_callback=None,
        inpaint_replace=False,
        enable_image_debugging=False,
        infill_method=None,
        inpaint_width=None,
        inpaint_height=None,
        inpaint_fill: Tuple[int, int, int, int] = (0x7F, 0x7F, 0x7F, 0xFF),
        attention_maps_callback=None,
        **kwargs,
    ):
        """Prepare an inpainting pass and return a ``make_image(x_T, seed)`` closure.

        If ``init_image`` is a PIL image, its transparent areas are infilled with
        ``infill_method`` (default: ``infill_methods()[0]``; "solid" uses
        ``inpaint_fill``), the original pixels are pasted back on top, the result is
        optionally resized to ``inpaint_width`` x ``inpaint_height`` and converted
        to a tensor.

        If ``mask_image`` is a PIL image, it is converted to single-channel,
        multiplied by the init image's alpha, optionally resized and converted to a
        tensor. In this mask white means "keep"; the pipeline receives
        ``1 - mask``.

        The returned closure runs the pipeline, post-processes the output via
        ``postprocess_size_and_mask`` and, when ``seam_size > 0``, additionally
        repaints the seam via ``seam_paint``. Its result depends on the seed and
        noise it is called with.

        Side effects: stores per-call state (pil_image, pil_mask, inpaint sizes,
        infill_method, mask_blur_radius, enable_image_debugging) on ``self``.

        Raises:
            ValueError: if ``infill_method`` is not supported.

        ``inpaint_replace`` and extra ``**kwargs`` are accepted but unused here.
        """
        self.enable_image_debugging = enable_image_debugging
        infill_method = infill_method or infill_methods()[0]
        self.infill_method = infill_method

        self.inpaint_width = inpaint_width
        self.inpaint_height = inpaint_height

        if isinstance(init_image, Image.Image):
            self.pil_image = init_image.copy()

            # Do infill
            if infill_method == "patchmatch" and PatchMatch.patchmatch_available():
                init_filled = self.infill_patchmatch(self.pil_image.copy())
            elif infill_method == "tile":
                init_filled = self.tile_fill_missing(self.pil_image.copy(), seed=self.seed, tile_size=tile_size)
            elif infill_method == "solid":
                solid_bg = Image.new("RGBA", init_image.size, inpaint_fill)
                init_filled = Image.alpha_composite(solid_bg, init_image)
            else:
                raise ValueError(f"Non-supported infill type {infill_method}", infill_method)
            # Put the original (non-transparent) pixels back over the infill.
            init_filled.paste(init_image, (0, 0), init_image.split()[-1])

            # Resize if requested for inpainting
            if inpaint_width and inpaint_height:
                init_filled = init_filled.resize((inpaint_width, inpaint_height))

            debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging)

            # Create init tensor
            init_image = image_resized_to_grid_as_tensor(init_filled.convert("RGB"))

        if isinstance(mask_image, Image.Image):
            debug_image(
                mask_image,
                "mask_image BEFORE multiply with pil_image",
                debug_status=self.enable_image_debugging,
            )

            init_alpha = self.pil_image.getchannel("A")
            if mask_image.mode != "L":
                # FIXME: why do we get passed an RGB image here? We can only use single-channel.
                mask_image = mask_image.convert("L")
            # Transparent areas of the init image are always treated as "paint here".
            mask_image = ImageChops.multiply(mask_image, init_alpha)
            self.pil_mask = mask_image

            # Resize if requested for inpainting
            if inpaint_width and inpaint_height:
                mask_image = mask_image.resize((inpaint_width, inpaint_height))

            debug_image(
                mask_image,
                "mask_image AFTER multiply with pil_image",
                debug_status=self.enable_image_debugging,
            )
            mask: torch.FloatTensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
        else:
            mask: torch.FloatTensor = mask_image

        self.mask_blur_radius = mask_blur_radius

        # noinspection PyTypeChecker
        pipeline: StableDiffusionGeneratorPipeline = self.model

        # todo: support cross-attention control
        uc, c, _ = conditioning
        conditioning_data = ConditioningData(uc, c, cfg_scale).add_scheduler_args_if_applicable(
            pipeline.scheduler, eta=ddim_eta
        )

        def make_image(x_T: torch.Tensor, seed: int):
            pipeline_output = pipeline.inpaint_from_embeddings(
                init_image=init_image,
                mask=1 - mask,  # expects white means "paint here."
                strength=strength,
                num_inference_steps=steps,
                conditioning_data=conditioning_data,
                noise_func=self.get_noise_like,
                callback=step_callback,
                seed=seed,
            )

            if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
                attention_maps_callback(pipeline_output.attention_map_saver)

            result = self.postprocess_size_and_mask(pipeline.numpy_to_pil(pipeline_output.images)[0])

            # Seam paint if this is our first pass (seam_size set to 0 during seam painting)
            if seam_size > 0:
                # seam_paint() calls get_make_image() again, which overwrites the
                # per-call state on self. Snapshot this pass's image and mask.
                old_image = self.pil_image
                old_mask = self.pil_mask

                result = self.seam_paint(
                    result,
                    seam_size,
                    seam_blur,
                    seed,
                    seam_steps,
                    cfg_scale,
                    ddim_eta,
                    conditioning,
                    seam_strength,
                    x_T,
                    infill_method,
                    step_callback,
                )

                # Restore original settings directly. Re-running get_make_image()
                # for this would redo the (possibly expensive) infill and would
                # multiply the already alpha-multiplied pil_mask by alpha again.
                self.enable_image_debugging = enable_image_debugging
                self.infill_method = infill_method
                self.inpaint_width = inpaint_width
                self.inpaint_height = inpaint_height
                self.pil_image = old_image
                self.pil_mask = old_mask
                self.mask_blur_radius = mask_blur_radius

            return result

        return make_image

    def sample_to_image(self, samples) -> Image.Image:
        """Decode ``samples`` via the base class, convert to RGB and post-process."""
        gen_result = super().sample_to_image(samples).convert("RGB")
        return self.postprocess_size_and_mask(gen_result)

    def postprocess_size_and_mask(self, gen_result: Image.Image) -> Image.Image:
        """Resize a generated image back and repaste the original pixels onto it.

        When an inpaint size was requested and an original PIL image is known,
        ``gen_result`` is resized back to ``self.pil_image.size``. If both the
        original image and mask are known, the result is passed through
        ``repaste_and_color_correct`` (defined outside this class, presumably on
        a base class); otherwise ``gen_result`` is returned as-is.
        """
        debug_image(gen_result, "gen_result", debug_status=self.enable_image_debugging)

        # Resize if necessary (only possible when the original size is known)
        if self.inpaint_width and self.inpaint_height and self.pil_image is not None:
            gen_result = gen_result.resize(self.pil_image.size)

        if self.pil_image is None or self.pil_mask is None:
            return gen_result

        corrected_result = self.repaste_and_color_correct(
            gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius
        )
        debug_image(
            corrected_result,
            "corrected_result",
            debug_status=self.enable_image_debugging,
        )

        return corrected_result
|