2022-09-06 00:40:10 +00:00
'''
2022-10-08 15:37:23 +00:00
ldm . invoke . generator . inpaint descends from ldm . invoke . generator
2022-09-06 00:40:10 +00:00
'''
import torch
2022-10-22 21:56:33 +00:00
import torchvision . transforms as T
2022-09-06 00:40:10 +00:00
import numpy as np
2022-10-22 21:56:33 +00:00
import cv2 as cv
2022-10-23 03:09:38 +00:00
import PIL
2022-10-22 21:56:33 +00:00
from PIL import Image , ImageFilter
from skimage . exposure . histogram_matching import match_histograms
2022-09-06 00:40:10 +00:00
from einops import rearrange , repeat
2022-10-08 15:37:23 +00:00
from ldm . invoke . devices import choose_autocast
from ldm . invoke . generator . img2img import Img2Img
2022-09-06 00:40:10 +00:00
from ldm . models . diffusion . ddim import DDIMSampler
2022-09-25 08:03:28 +00:00
from ldm . models . diffusion . ksampler import KSampler
2022-10-23 03:09:38 +00:00
from ldm . invoke . generator . base import downsampling
2022-09-06 00:40:10 +00:00
class Inpaint(Img2Img):
    """Inpainting generator derived from Img2Img.

    When the init image and mask are supplied as PIL images they are kept on
    the instance so that sample_to_image() can colour-match the generated
    result against the source and paste the source back through the
    (optionally eroded and blurred) mask.
    """

    def __init__(self, model, precision):
        self.init_latent = None
        # Original PIL inputs remembered for post-processing in
        # sample_to_image(); None means "return the raw generated image".
        self.pil_image = None
        self.pil_mask = None
        self.mask_blur_radius = 0
        super().__init__(model, precision)

    @torch.no_grad()
    def get_make_image(self, prompt, sampler, steps, cfg_scale, ddim_eta,
                       conditioning, init_image, mask_image, strength,
                       mask_blur_radius: int = 8,
                       step_callback=None, inpaint_replace=False, **kwargs):
        """
        Returns a function returning an image derived from the prompt and
        the initial image + mask.  Return value depends on the seed at
        the time you call it.

        init_image and mask_image may be PIL images or tensors.  PIL inputs
        are remembered for the post-processing done in sample_to_image() and
        converted to tensors here; the mask is first resized down by the
        `downsampling` factor with nearest-neighbour resampling.

        mask_blur_radius is stored for sample_to_image().  inpaint_replace,
        when greater than 0, mixes fresh noise into the masked region of the
        encoded latent with that weight; in that case kwargs must contain
        'width' and 'height', which are passed to self.get_noise().
        """
        # Forget PIL inputs from any earlier call on this instance; otherwise
        # a later call made with tensor inputs would be post-processed with
        # the previous request's image and mask.
        self.pil_image = None
        self.pil_mask = None

        if isinstance(init_image, PIL.Image.Image):
            self.pil_image = init_image
            init_image = self._image_to_tensor(init_image)

        if isinstance(mask_image, PIL.Image.Image):
            self.pil_mask = mask_image
            mask_image = mask_image.resize(
                (
                    mask_image.width // downsampling,
                    mask_image.height // downsampling
                ),
                resample=Image.Resampling.NEAREST
            )
            mask_image = self._image_to_tensor(mask_image, normalize=False)

        self.mask_blur_radius = mask_blur_radius

        # klms samplers not supported yet, so ignore previous sampler
        if isinstance(sampler, KSampler):
            print(
                ">> Using recommended DDIM sampler for inpainting."
            )
            sampler = DDIMSampler(self.model, device=self.model.device)

        sampler.make_schedule(
            ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
        )

        # Take the first channel of the first mask and replicate it across
        # 4 channels, then restore the leading batch dimension.
        mask_image = mask_image[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
        mask_image = repeat(mask_image, '1 ... -> b ...', b=1)

        scope = choose_autocast(self.precision)
        with scope(self.model.device.type):
            self.init_latent = self.model.get_first_stage_encoding(
                self.model.encode_first_stage(init_image)
            )  # move to latent space

        t_enc = int(strength * steps)
        uc, c = conditioning

        print(f">> target t_enc is {t_enc} steps")

        @torch.no_grad()
        def make_image(x_T):
            # encode (scaled latent)
            z_enc = sampler.stochastic_encode(
                self.init_latent,
                torch.tensor([t_enc]).to(self.model.device),
                noise=x_T
            )

            # to replace masked area with latent noise, weighted by inpaint_replace strength
            if inpaint_replace > 0.0:
                print(f'>> inpaint will replace what was under the mask with a strength of {inpaint_replace}')
                l_noise = self.get_noise(kwargs['width'], kwargs['height'])
                inverted_mask = 1.0 - mask_image  # there will be 1s where the mask is
                masked_region = (1.0 - inpaint_replace) * inverted_mask * z_enc + inpaint_replace * inverted_mask * l_noise
                z_enc = z_enc * mask_image + masked_region

            # decode it
            samples = sampler.decode(
                z_enc,
                c,
                t_enc,
                img_callback=step_callback,
                unconditional_guidance_scale=cfg_scale,
                unconditional_conditioning=uc,
                mask=mask_image,
                init_latent=self.init_latent
            )

            return self.sample_to_image(samples)

        return make_image

    def sample_to_image(self, samples) -> Image.Image:
        """
        Convert samples to an RGB PIL image via the parent class and, when the
        original PIL image and mask are available, colour-match the result to
        the visible pixels of the original and paste the original back on top
        through the (eroded and blurred) mask.
        """
        gen_result = super().sample_to_image(samples).convert('RGB')

        if self.pil_image is None or self.pil_mask is None:
            return gen_result

        pil_mask = self.pil_mask
        pil_image = self.pil_image
        mask_blur_radius = self.mask_blur_radius

        # Get the original alpha channel of the mask if there is one.
        # Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
        pil_init_mask = pil_mask.getchannel('A') if pil_mask.mode == 'RGBA' else pil_mask.convert('L')
        pil_init_image = pil_image.convert('RGBA')  # Add an alpha channel if one doesn't exist

        # Build an image with only visible pixels from source to use as reference for color-matching.
        # Note that this doesn't use the mask, which would exclude some source image pixels from the
        # histogram and cause slight color changes.
        # Both arrays come from the init image, so flatten them by their own
        # size rather than by the mask's dimensions.
        init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(-1, 3)
        init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(-1)
        # Filter to just pixels that have any alpha, this is now our histogram
        init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]

        if init_rgb_pixels.shape[0] > 0:
            reference_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1])

            # Get numpy version
            np_gen_result = np.asarray(gen_result, dtype=np.uint8)

            # Color correct
            np_matched_result = match_histograms(np_gen_result, reference_pixels, channel_axis=-1)
            matched_result = Image.fromarray(np_matched_result, mode='RGB')
        else:
            # A fully transparent init image leaves no reference pixels to
            # match against, so keep the generated colours as they are.
            matched_result = gen_result

        # Blur the mask out (into init image) by specified amount
        if mask_blur_radius > 0:
            nm = np.asarray(pil_init_mask, dtype=np.uint8)
            # Erode first so the blur fades inside the kept region.
            nmd = cv.erode(nm, kernel=np.ones((3, 3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
            pmd = Image.fromarray(nmd, mode='L')
            blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
        else:
            blurred_init_mask = pil_init_mask

        # Paste original on color-corrected generation (using blurred mask)
        matched_result.paste(pil_image, (0, 0), mask=blurred_init_mask)
        return matched_result
2022-10-22 21:56:33 +00:00