InvokeAI/ldm/dream/generator/img2img.py
Lincoln Stein 958d7650dd img2img works with all samplers, inpainting working with ddim & plms
- img2img confirmed working with all samplers
- inpainting working on ddim & plms. Changes to k-diffusion
  module seem to be needed for inpainting support.
- switched k-diffuser noise schedule to original karras schedule,
  which reduces the step number needed for good results
2022-10-02 14:58:21 -04:00

67 lines
2.4 KiB
Python

'''
ldm.dream.generator.img2img descends from ldm.dream.generator
'''
import torch
import numpy as np
from ldm.dream.devices import choose_autocast
from ldm.dream.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler
class Img2Img(Generator):
    """Generator that derives an image from a prompt and an initial image."""

    def __init__(self, model, precision):
        """Forward `model` and `precision` to the base Generator."""
        super().__init__(model, precision)
        # Latent encoding of the init image. Assigned by get_make_image()
        # and read back by get_noise().
        self.init_latent = None

    def get_make_image(self, prompt, sampler, steps, cfg_scale, ddim_eta,
                       conditioning, init_image, strength,
                       step_callback=None, **kwargs):
        """
        Returns a function returning an image derived from the prompt and
        the initial image.  Return value depends on the seed at the time
        you call it.

        As a side effect, stores the latent encoding of `init_image` in
        `self.init_latent`, which get_noise() requires.

        Parameters:
            prompt:        not referenced by this method.
            sampler:       object providing make_schedule(),
                           stochastic_encode() and sample().
            steps:         passed to make_schedule() as ddim_num_steps and
                           scaled by `strength` to get the encode step.
            cfg_scale:     passed on as unconditional_guidance_scale.
            ddim_eta:      passed to make_schedule() and to sample() as eta.
            conditioning:  a pair unpacked as (uc, c) and passed on as the
                           unconditional / conditional conditioning.
            init_image:    input handed to the model's first-stage encoder.
            strength:      multiplied by `steps` and truncated to an int.
            step_callback: passed to sample() as img_callback.
            **kwargs:      accepted and ignored.
        """
        sampler.make_schedule(
            ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
        )

        scope = choose_autocast(self.precision)
        with scope(self.model.device.type):
            self.init_latent = self.model.get_first_stage_encoding(
                self.model.encode_first_stage(init_image)
            )  # move to latent space

        # NOTE(review): int() truncates, so a small strength gives t_enc == 0
        # and strength >= 1.0 gives t_enc >= steps; confirm the sampler
        # accepts both extremes.
        t_enc = int(strength * steps)
        uc, c = conditioning

        def make_image(x_T):
            """Noise the stored init latent with `x_T`, then sample an image."""
            # encode (scaled latent)
            z_enc = sampler.stochastic_encode(
                self.init_latent,
                torch.tensor([t_enc]).to(self.model.device),
                noise=x_T
            )
            samples, _ = sampler.sample(
                batch_size=1,
                S=t_enc,
                shape=z_enc.shape[1:],  # drop the leading (batch) dimension
                x_T=z_enc,
                conditioning=c,
                unconditional_guidance_scale=cfg_scale,
                unconditional_conditioning=uc,
                eta=ddim_eta,
                img_callback=step_callback,
                verbose=False,
            )
            return self.sample_to_image(samples)

        return make_image

    def get_noise(self, width, height):
        """
        Return standard-normal noise shaped like `self.init_latent`.

        `width` and `height` are not used; the noise takes its shape from
        the stored init latent, so get_make_image() must have been called
        first.

        Raises:
            AssertionError: if `self.init_latent` has not been set.
        """
        device = self.model.device
        init_latent = self.init_latent
        if init_latent is None:
            # Raised explicitly rather than via `assert` so the check still
            # fires (with this message) when Python runs with -O.
            raise AssertionError('call to get_noise() when init_latent not set')
        if device.type == 'mps':
            # Noise is drawn on the CPU and then moved to the MPS device.
            # NOTE(review): presumably a workaround for random generation on
            # MPS -- confirm the exact reason.
            return torch.randn_like(init_latent, device='cpu').to(device)
        return torch.randn_like(init_latent, device=device)