InvokeAI/ldm/dream/generator/inpaint.py
Lincoln Stein 958d7650dd img2img works with all samplers, inpainting working with ddim & plms
- img2img confirmed working with all samplers
- inpainting working on ddim & plms. Changes to k-diffusion
  module seem to be needed for inpainting support.
- switched k-diffuser noise schedule to original karras schedule,
  which reduces the step number needed for good results
2022-10-02 14:58:21 -04:00

79 lines
2.6 KiB
Python

'''
ldm.dream.generator.inpaint defines Inpaint, which descends from
ldm.dream.generator.img2img.Img2Img
'''
import torch
import numpy as np
from einops import rearrange, repeat
from ldm.dream.devices import choose_autocast
from ldm.dream.generator.img2img import Img2Img
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.ksampler import KSampler
class Inpaint(Img2Img):
    """
    Image generator driven by a prompt, an initial image and a mask.

    Extends Img2Img; the encoded initial image is kept on the instance
    as `init_latent` so the image-making closure can reuse it.
    """

    def __init__(self, model, precision):
        self.init_latent = None
        super().__init__(model, precision)

    @torch.no_grad()
    def get_make_image(self, prompt, sampler, steps, cfg_scale, ddim_eta,
                       conditioning, init_image, mask_image, strength,
                       step_callback=None, **kwargs):
        """
        Build and return a `make_image(x_T)` function that produces an
        image from the prompt conditioning, the initial image and the mask.
        The result depends on the noise `x_T` supplied at call time.

        Parameters as used by this method:
          prompt        -- not referenced here
          sampler       -- sampler to use; a KSampler is swapped for a
                           DDIMSampler built on self.model
          steps         -- passed as ddim_num_steps and used for t_enc
          cfg_scale     -- passed as unconditional_guidance_scale
          ddim_eta      -- passed to the sampler's make_schedule
          conditioning  -- pair unpacked as (unconditional, conditional)
          init_image    -- encoded through the model's first stage
          mask_image    -- indexed as [0][0] and tiled to 4 channels
                           (assumes a batch/channel-leading tensor --
                           TODO confirm against caller)
          strength      -- t_enc is int(strength * steps)
          step_callback -- passed as img_callback to the sampler's decode
          kwargs        -- accepted and ignored

        Side effects: stores the encoded image in self.init_latent and
        prints progress messages.
        """
        # K samplers cannot be used here yet; fall back to DDIM.
        if isinstance(sampler, KSampler):
            notice = f">> sampler '{sampler.__class__.__name__}' is not yet supported for inpainting, using DDIMSampler instead."
            print(notice)
            sampler = DDIMSampler(self.model, device=self.model.device)

        sampler.make_schedule(
            ddim_num_steps=steps,
            ddim_eta=ddim_eta,
            verbose=False,
        )

        # Take the first plane of the first mask entry, tile it across
        # four channels and restore a leading axis of length one.
        first_plane = mask_image[0][0]
        latent_mask = first_plane.unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
        # With b=1 the leading axis keeps length one.
        latent_mask = repeat(latent_mask, '1 ... -> b ...', b=1)

        # Encode the initial image under the autocast context selected
        # for this generator's precision.
        autocast_ctx = choose_autocast(self.precision)
        with autocast_ctx(self.model.device.type):
            first_stage = self.model.encode_first_stage(init_image)
            self.init_latent = self.model.get_first_stage_encoding(first_stage)

        t_enc = int(strength * steps)
        uncond, cond = conditioning

        print(f">> target t_enc is {t_enc} steps")

        @torch.no_grad()
        def make_image(x_T):
            # stochastic_encode gets the stored latent, the target step as
            # a one-element tensor on the model device, and x_T as noise.
            step_index = torch.tensor([t_enc]).to(self.model.device)
            noised = sampler.stochastic_encode(
                self.init_latent,
                step_index,
                noise=x_T,
            )

            # decode gets the mask and the stored latent alongside the
            # guidance settings.
            decoded = sampler.decode(
                noised,
                cond,
                t_enc,
                img_callback=step_callback,
                unconditional_guidance_scale=cfg_scale,
                unconditional_conditioning=uncond,
                mask=latent_mask,
                init_latent=self.init_latent,
            )
            return self.sample_to_image(decoded)

        return make_image