2022-09-06 00:40:10 +00:00
|
|
|
'''
|
2022-10-08 15:37:23 +00:00
|
|
|
ldm.invoke.generator.inpaint descends from ldm.invoke.generator
|
2022-09-06 00:40:10 +00:00
|
|
|
'''
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import numpy as np
|
|
|
|
from einops import rearrange, repeat
|
2022-10-08 15:37:23 +00:00
|
|
|
from ldm.invoke.devices import choose_autocast
|
|
|
|
from ldm.invoke.generator.img2img import Img2Img
|
2022-09-06 00:40:10 +00:00
|
|
|
from ldm.models.diffusion.ddim import DDIMSampler
|
2022-09-25 08:03:28 +00:00
|
|
|
from ldm.models.diffusion.ksampler import KSampler
|
2022-09-06 00:40:10 +00:00
|
|
|
|
|
|
|
class Inpaint(Img2Img):
    """Img2Img subclass whose image factory additionally takes a mask image."""

    def __init__(self, model, precision):
        # Start with no cached latent; get_make_image() fills this in.
        # The attribute is assigned before the parent constructor runs.
        self.init_latent = None
        super().__init__(model, precision)

    @torch.no_grad()
    def get_make_image(self, prompt, sampler, steps, cfg_scale, ddim_eta,
                       conditioning, init_image, mask_image, strength,
                       step_callback=None, inpaint_replace=False, **kwargs):
        """
        Build and return a ``make_image(x_T)`` closure that produces an
        image from the initial image + mask. The closure's result depends
        on the noise ``x_T`` it is called with.

        Parameters, as used by this method:
          prompt          -- not referenced in this method
          sampler         -- sampler to use; a KSampler is replaced by a
                             DDIMSampler built on self.model
          steps           -- passed as ddim_num_steps and used for t_enc
          cfg_scale       -- forwarded as unconditional_guidance_scale
          ddim_eta        -- forwarded to sampler.make_schedule()
          conditioning    -- unpacked as three values: (uc, c, <ignored>)
          init_image      -- encoded by the model's first stage and stored
                             in self.init_latent
          mask_image      -- element [0][0] is taken, repeated over 4
                             channels and given a leading batch axis
                             (NOTE(review): 4 is presumably the latent
                             channel count -- confirm)
          strength        -- t_enc is int(strength * steps)
          step_callback   -- forwarded to sampler.decode() as img_callback
          inpaint_replace -- when > 0, weight used to blend fresh noise
                             into the region where (1 - mask) is non-zero
          kwargs          -- 'width' and 'height' are read, but only when
                             inpaint_replace > 0
        """
        # klms samplers not supported yet, so ignore previous sampler
        if isinstance(sampler, KSampler):
            print(">> Using recommended DDIM sampler for inpainting.")
            sampler = DDIMSampler(self.model, device=self.model.device)

        sampler.make_schedule(
            ddim_num_steps=steps,
            ddim_eta=ddim_eta,
            verbose=False,
        )

        # Take the first plane of the first mask entry, repeat it over
        # 4 channels, then add a leading batch axis.
        single_plane = mask_image[0][0].unsqueeze(0)
        latent_mask = single_plane.repeat(4, 1, 1).unsqueeze(0)
        latent_mask = repeat(latent_mask, '1 ... -> b ...', b=1)

        autocast_scope = choose_autocast(self.precision)
        with autocast_scope(self.model.device.type):
            # move to latent space
            first_stage_output = self.model.encode_first_stage(init_image)
            self.init_latent = self.model.get_first_stage_encoding(
                first_stage_output
            )

        t_enc = int(strength * steps)

        # todo: support cross-attention control
        uc, c, _ = conditioning

        print(f">> target t_enc is {t_enc} steps")

        @torch.no_grad()
        def make_image(x_T):
            """Encode the stored latent with noise x_T, decode it under the
            mask, and return the result of self.sample_to_image()."""
            # encode (scaled latent)
            timestep = torch.tensor([t_enc]).to(self.model.device)
            noised_latent = sampler.stochastic_encode(
                self.init_latent,
                timestep,
                noise=x_T,
            )

            # to replace masked area with latent noise, weighted by inpaint_replace strength
            if inpaint_replace > 0.0:
                print(f'>> inpaint will replace what was under the mask with a strength of {inpaint_replace}')
                fresh_noise = self.get_noise(kwargs['width'], kwargs['height'])
                inverted_mask = 1.0 - latent_mask  # there will be 1s where the mask is
                # Inside the inverted mask: mix the encoded latent with
                # fresh noise according to inpaint_replace.
                retained_part = (1.0 - inpaint_replace) * inverted_mask * noised_latent
                replaced_part = inpaint_replace * inverted_mask * fresh_noise
                masked_region = retained_part + replaced_part
                # Where latent_mask is non-zero the encoded latent is kept.
                noised_latent = noised_latent * latent_mask + masked_region

            # decode it
            decode_options = dict(
                img_callback=step_callback,
                unconditional_guidance_scale=cfg_scale,
                unconditional_conditioning=uc,
                mask=latent_mask,
                init_latent=self.init_latent,
            )
            samples = sampler.decode(noised_latent, c, t_enc, **decode_options)

            return self.sample_to_image(samples)

        return make_image
|
|
|
|
|
|
|
|
|
|
|
|
|