Merge branch 'outpaint' of https://github.com/Kyle0654/InvokeAI into Kyle0654-outpaint

This commit is contained in:
Lincoln Stein 2022-10-27 09:16:40 -04:00
commit 0eb07b7488
3 changed files with 200 additions and 30 deletions

View File

@ -295,6 +295,13 @@ class Generate:
catch_interrupts = False, catch_interrupts = False,
hires_fix = False, hires_fix = False,
use_mps_noise = False, use_mps_noise = False,
# Seam settings for outpainting
seam_size: int = 0,
seam_blur: int = 0,
seam_strength: float = 0.7,
seam_steps: int = 10,
tile_size: int = 32,
force_outpaint: bool = False,
**args, **args,
): # eat up additional cruft ): # eat up additional cruft
""" """
@ -459,7 +466,13 @@ class Generate:
embiggen_tiles=embiggen_tiles, embiggen_tiles=embiggen_tiles,
inpaint_replace=inpaint_replace, inpaint_replace=inpaint_replace,
mask_blur_radius=mask_blur_radius, mask_blur_radius=mask_blur_radius,
safety_checker=checker safety_checker=checker,
seam_size = seam_size,
seam_blur = seam_blur,
seam_strength = seam_strength,
seam_steps = seam_steps,
tile_size = tile_size,
force_outpaint = force_outpaint
) )
if init_color: if init_color:
@ -648,7 +661,7 @@ class Generate:
if inpainting_model_in_use: if inpainting_model_in_use:
return self._make_omnibus() return self._make_omnibus()
if (init_image is not None) and (mask_image is not None): if ((init_image is not None) and (mask_image is not None)) or force_outpaint:
return self._make_inpaint() return self._make_inpaint()
if init_image is not None: if init_image is not None:
@ -925,8 +938,9 @@ class Generate:
image = ImageOps.exif_transpose(image) image = ImageOps.exif_transpose(image)
return image return image
def _create_init_image(self, image, width, height, fit=True): def _create_init_image(self, image: Image.Image, width, height, fit=True):
image = image.convert('RGB') if image.mode != 'RGBA':
image = image.convert('RGB')
image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image) image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
return image return image

View File

@ -30,7 +30,7 @@ class Img2Img(Generator):
) )
if isinstance(init_image, PIL.Image.Image): if isinstance(init_image, PIL.Image.Image):
init_image = self._image_to_tensor(init_image) init_image = self._image_to_tensor(init_image.convert('RGB'))
scope = choose_autocast(self.precision) scope = choose_autocast(self.precision)
with scope(self.model.device.type): with scope(self.model.device.type):

View File

@ -2,12 +2,13 @@
ldm.invoke.generator.inpaint descends from ldm.invoke.generator ldm.invoke.generator.inpaint descends from ldm.invoke.generator
''' '''
import math
import torch import torch
import torchvision.transforms as T import torchvision.transforms as T
import numpy as np import numpy as np
import cv2 as cv import cv2 as cv
import PIL import PIL
from PIL import Image, ImageFilter from PIL import Image, ImageFilter, ImageOps
from skimage.exposure.histogram_matching import match_histograms from skimage.exposure.histogram_matching import match_histograms
from einops import rearrange, repeat from einops import rearrange, repeat
from ldm.invoke.devices import choose_autocast from ldm.invoke.devices import choose_autocast
@ -24,11 +25,128 @@ class Inpaint(Img2Img):
self.mask_blur_radius = 0 self.mask_blur_radius = 0
super().__init__(model, precision) super().__init__(model, precision)
# Outpaint support code
def get_tile_images(self, image: np.ndarray, width=8, height=8):
    """Return a read-only tiled view of a 3-D image array.

    Args:
        image: array of shape (rows, cols, depth).
        width: tile width in pixels.
        height: tile height in pixels.

    Returns:
        A read-only array of shape (rows // height, cols // width, height,
        width, depth) where element ``[r, c]`` is the tile covering
        ``image[r*height:(r+1)*height, c*width:(c+1)*width]``, or ``None``
        when the image dimensions are not exact multiples of the tile size.
    """
    # The stride arithmetic below is only valid for a C-contiguous buffer.
    # A sliced/transposed input would otherwise be flattened into a compact
    # copy whose layout no longer matches the input's strides, which makes
    # as_strided read outside the buffer.
    image = np.ascontiguousarray(image)

    total_rows, total_cols, depth = image.shape
    nrows, leftover_rows = divmod(total_rows, height)
    ncols, leftover_cols = divmod(total_cols, width)
    if leftover_rows != 0 or leftover_cols != 0:
        return None

    row_stride, col_stride, chan_stride = image.strides
    return np.lib.stride_tricks.as_strided(
        image,
        shape=(nrows, ncols, height, width, depth),
        # Stepping one tile down/right skips `height` rows / `width` columns;
        # inside a tile the original per-pixel strides apply.
        strides=(height * row_stride, width * col_stride,
                 row_stride, col_stride, chan_stride),
        writeable=False
    )
def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: int = None) -> Image.Image:
    """Fill tiles that contain transparent pixels with randomly chosen opaque tiles.

    The image is cut into ``tile_size`` x ``tile_size`` tiles. A tile counts
    as valid when every one of its pixels has alpha > 0; every other tile is
    overwritten with a copy of a randomly selected valid tile.

    Args:
        im: source image; only RGBA images are processed.
        tile_size: edge length of the square tiles, in pixels.
        seed: seed passed to ``np.random.default_rng`` for the tile choice.

    Returns:
        A new RGBA image with the invalid tiles replaced. ``im`` itself is
        returned unchanged when it has no alpha channel, when its size is not
        a multiple of ``tile_size``, or when it contains no valid tile.
    """
    # Without an alpha channel no pixel is marked as missing (and indexing
    # channel 3 below would fail).
    if im.mode != 'RGBA':
        return im

    pixels = np.asarray(im, dtype=np.uint8)

    # get_tile_images returns None when the image cannot be tiled exactly;
    # treat that like the other "nothing to do" cases instead of crashing.
    tiles = self.get_tile_images(pixels, tile_size, tile_size)
    if tiles is None:
        return im
    tiles = tiles.copy()  # the tiled view is read-only; we need to write into it

    grid_shape = tiles.shape
    tile_count = grid_shape[0] * grid_shape[1]
    flat_tiles = tiles.reshape((tile_count, *grid_shape[2:]))

    # A tile is usable as a donor only if none of its pixels is fully transparent
    alpha = flat_tiles[:, :, :, 3]
    is_complete = (alpha > 0).reshape((tile_count, grid_shape[2] * grid_shape[3])).all(axis=1)

    donors = flat_tiles[is_complete]
    if len(donors) == 0:
        return im

    # Overwrite every incomplete tile with a randomly drawn donor tile
    missing = np.logical_not(is_complete)
    rng = np.random.default_rng(seed=seed)
    picks = rng.choice(donors.shape[0], missing.sum())
    flat_tiles[missing] = donors[picks, :, :, :]

    # Stitch the tile grid back into a (rows, cols, channels) pixel array:
    # (grid_r, grid_c, h, w, d) -> (grid_r, h, grid_c, w, d) -> (grid_r*h, grid_c*w, d)
    grid_rows, grid_cols, tile_h, tile_w, depth = grid_shape
    stitched = flat_tiles.reshape(grid_shape).swapaxes(1, 2)
    stitched = stitched.reshape((grid_rows * tile_h, grid_cols * tile_w, depth))
    return Image.fromarray(stitched, mode='RGBA')
def mask_edge(self, mask: Image.Image, edge_size: int, edge_blur: int) -> Image.Image:
    """Build a mask that selects the seam region around the edges of ``mask``.

    Args:
        mask: single-channel mask image (read as uint8).
        edge_size: seam width; the edge map is dilated ``int(edge_size / 2)``
            times with a 3x3 kernel.
        edge_blur: box-blur radius applied to the dilated edge map; no blur
            when 0 or negative.

    Returns:
        The inverted edge map: seam pixels are dark, everything else is light.
    """
    npimg = np.asarray(mask, dtype=np.uint8)

    # Detect any partially transparent regions: 255 where the value is
    # strictly between 0 and 255, 0 where it is exactly 0 or 255.
    npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))

    # Detect hard edges
    npedge = cv.Canny(npimg, threshold1=100, threshold2=200)

    # Combine. Plain uint8 addition wraps around (255 + 255 -> 254) where a
    # pixel is flagged by both detectors, so take the element-wise maximum.
    npmask = np.maximum(npgradient, npedge)

    # Expand
    npmask = cv.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2))

    new_mask = Image.fromarray(npmask)

    if edge_blur > 0:
        new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur))

    return ImageOps.invert(new_mask)
def seam_paint(self,
               im: Image.Image,
               seam_size: int,
               seam_blur: int,
               prompt,sampler,steps,cfg_scale,ddim_eta,
               conditioning,strength,
               noise
               ) -> Image.Image:
    """Run a second inpainting pass restricted to the seam of the source image.

    The seam mask is derived from the last band of ``self.pil_image`` via
    ``mask_edge``. ``im`` (converted to RGBA) is used as the init image for
    a generator obtained from ``get_make_image``, which is then invoked with
    ``noise``.

    Returns:
        Whatever the generated ``make_image`` callable returns for ``noise``.
    """
    # Last band of the stored source image serves as the hard-edged mask
    source_alpha = self.pil_image.split()[-1].copy()
    seam_mask = self.mask_edge(source_alpha, seam_size, seam_blur)

    seam_init = im.copy().convert('RGBA')
    seam_mask_rgb = seam_mask.convert('RGB')  # Code currently requires an RGB mask

    # seam_size = 0 so the nested generator does not seam-paint again
    paint_seam = self.get_make_image(
        prompt,
        sampler,
        steps,
        cfg_scale,
        ddim_eta,
        conditioning,
        init_image = seam_init,
        mask_image = seam_mask_rgb,
        strength = strength,
        mask_blur_radius = 0,
        seam_size = 0
    )

    return paint_seam(noise)
@torch.no_grad() @torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,init_image,mask_image,strength, conditioning,init_image,mask_image,strength,
mask_blur_radius: int = 8, mask_blur_radius: int = 8,
step_callback=None,inpaint_replace=False, **kwargs): # Seam settings - when 0, doesn't fill seam
seam_size: int = 0,
seam_blur: int = 0,
seam_strength: float = 0.7,
seam_steps: int = 10,
tile_size: int = 32,
step_callback=None,
inpaint_replace=False, **kwargs):
""" """
Returns a function returning an image derived from the prompt and Returns a function returning an image derived from the prompt and
the initial image + mask. Return value depends on the seed at the initial image + mask. Return value depends on the seed at
@ -37,7 +155,17 @@ class Inpaint(Img2Img):
if isinstance(init_image, PIL.Image.Image): if isinstance(init_image, PIL.Image.Image):
self.pil_image = init_image self.pil_image = init_image
init_image = self._image_to_tensor(init_image)
# Fill missing areas of original image
init_filled = self.tile_fill_missing(
self.pil_image.copy(),
seed = self.seed,
tile_size = tile_size
)
init_filled.paste(init_image, (0,0), init_image.split()[-1])
# Create init tensor
init_image = self._image_to_tensor(init_filled.convert('RGB'))
if isinstance(mask_image, PIL.Image.Image): if isinstance(mask_image, PIL.Image.Image):
self.pil_mask = mask_image self.pil_mask = mask_image
@ -106,38 +234,56 @@ class Inpaint(Img2Img):
mask = mask_image, mask = mask_image,
init_latent = self.init_latent init_latent = self.init_latent
) )
return self.sample_to_image(samples)
result = self.sample_to_image(samples)
# Seam paint if this is our first pass (seam_size set to 0 during seam painting)
if seam_size > 0:
result = self.seam_paint(
result,
seam_size,
seam_blur,
prompt,
sampler,
seam_steps,
cfg_scale,
ddim_eta,
conditioning,
seam_strength,
x_T)
return result
return make_image return make_image
def sample_to_image(self, samples)->Image.Image:
gen_result = super().sample_to_image(samples).convert('RGB')
if self.pil_image is None or self.pil_mask is None:
return gen_result
pil_mask = self.pil_mask
pil_image = self.pil_image
mask_blur_radius = self.mask_blur_radius
def color_correct(self, image: Image.Image, base_image: Image.Image, mask: Image.Image, mask_blur_radius: int) -> Image.Image:
# Get the original alpha channel of the mask if there is one. # Get the original alpha channel of the mask if there is one.
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB') # Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
pil_init_mask = pil_mask.getchannel('A') if pil_mask.mode == 'RGBA' else pil_mask.convert('L') pil_init_mask = mask.getchannel('A') if mask.mode == 'RGBA' else mask.convert('L')
pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist pil_init_image = base_image.convert('RGBA') # Add an alpha channel if one doesn't exist
# Build an image with only visible pixels from source to use as reference for color-matching. # Build an image with only visible pixels from source to use as reference for color-matching.
# Note that this doesn't use the mask, which would exclude some source image pixels from the init_rgb_pixels = np.asarray(base_image.convert('RGB'), dtype=np.uint8)
# histogram and cause slight color changes. init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8)
init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3) init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8)
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram
# Get numpy version # Get numpy version of result
np_gen_result = np.asarray(gen_result, dtype=np.uint8) np_image = np.asarray(image, dtype=np.uint8)
# Mask and calculate mean and standard deviation
mask_pixels = init_a_pixels * init_mask_pixels > 0
np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
np_image_masked = np_image[mask_pixels, :]
init_means = np_init_rgb_pixels_masked.mean(axis=0)
init_std = np_init_rgb_pixels_masked.std(axis=0)
gen_means = np_image_masked.mean(axis=0)
gen_std = np_image_masked.std(axis=0)
# Color correct # Color correct
np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1) np_matched_result = np_image.copy()
np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8)
matched_result = Image.fromarray(np_matched_result, mode='RGB') matched_result = Image.fromarray(np_matched_result, mode='RGB')
# Blur the mask out (into init image) by specified amount # Blur the mask out (into init image) by specified amount
@ -150,6 +296,16 @@ class Inpaint(Img2Img):
blurred_init_mask = pil_init_mask blurred_init_mask = pil_init_mask
# Paste original on color-corrected generation (using blurred mask) # Paste original on color-corrected generation (using blurred mask)
matched_result.paste(pil_image, (0,0), mask = blurred_init_mask) matched_result.paste(base_image, (0,0), mask = blurred_init_mask)
return matched_result return matched_result
def sample_to_image(self, samples)->Image.Image:
    """Decode ``samples`` to an RGB image and color-correct it when possible.

    The parent class's ``sample_to_image`` result is converted to RGB. If
    both ``self.pil_image`` and ``self.pil_mask`` are set, the result of
    ``color_correct`` (called with those and ``self.mask_blur_radius``) is
    returned; otherwise the plain RGB image is returned.
    """
    decoded = super().sample_to_image(samples).convert('RGB')

    have_reference = self.pil_image is not None and self.pil_mask is not None
    if not have_reference:
        return decoded

    return self.color_correct(decoded, self.pil_image, self.pil_mask, self.mask_blur_radius)