img2img works with all samplers; inpainting works with ddim & plms

- img2img confirmed working with all samplers
- inpainting works with ddim & plms; further changes to the k-diffusion
  module appear to be needed before the k-samplers can support it
- switched the k-diffusion noise schedule to the original Karras schedule,
  which reduces the number of steps needed for good results (a sketch of
  the schedule call follows this list)
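
The schedule change in the last bullet boils down to a single k-diffusion call. A minimal sketch, assuming denoiser is a K.external.CompVisDenoiser wrapping a loaded LDM model (the helper function name and the step count are illustrative, not from this commit):

import k_diffusion as K

def karras_sigmas(denoiser, steps, device='cuda'):
    # Build the Karras et al. (2022) noise schedule over the model's
    # trained sigma range; rho=7.0 is the value recommended in the paper.
    return K.sampling.get_sigmas_karras(
        n=steps,
        sigma_min=denoiser.sigmas[0].item(),   # smallest trained sigma
        sigma_max=denoiser.sigmas[-1].item(),  # largest trained sigma
        rho=7.0,
        device=device,
    )

# sigmas = karras_sigmas(denoiser, steps=20)  # fewer steps now suffice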
commit 958d7650dd
parent 72834ad16c
Author: Lincoln Stein
Date:   2022-09-25 04:03:28 -04:00

11 changed files with 597 additions and 685 deletions


@@ -3,6 +3,7 @@ import k_diffusion as K
 import torch
 import torch.nn as nn
 from ldm.dream.devices import choose_torch_device
+from ldm.models.diffusion.sampler import Sampler

 class CFGDenoiser(nn.Module):
     def __init__(self, model):
@@ -17,12 +18,16 @@ class CFGDenoiser(nn.Module):
         return uncond + (cond - uncond) * cond_scale

-class KSampler(object):
+class KSampler(Sampler):
     def __init__(self, model, schedule='lms', device=None, **kwargs):
-        super().__init__()
-        self.model = K.external.CompVisDenoiser(model)
-        self.schedule = schedule
-        self.device = device or choose_torch_device()
+        denoiser = K.external.CompVisDenoiser(model)
+        super().__init__(
+            denoiser,
+            schedule,
+            steps=model.num_timesteps,
+        )
+        self.ds = None
+        self.s_in = None

     def forward(self, x, sigma, uncond, cond, cond_scale):
         x_in = torch.cat([x] * 2)
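
The Sampler base class is not part of this diff. From the super() call above and the methods overridden further down, its assumed shape is roughly the following (an inferred sketch, not the actual ldm.models.diffusion.sampler source):

from ldm.dream.devices import choose_torch_device

class Sampler(object):
    # Hypothetical outline, inferred only from how KSampler uses it.
    def __init__(self, model, schedule, steps, device=None, **kwargs):
        self.model    = model      # here: the CompVisDenoiser wrapper
        self.schedule = schedule
        self.steps    = steps
        self.device   = device or choose_torch_device()

    def make_schedule(self, ddim_num_steps, ddim_discretize='uniform',
                      ddim_eta=0.0, verbose=False):
        ...  # builds the ddim alphas/sigmas from self.model

    def sample(self, S, shape, **kwargs):
        ...  # outer loop that calls self.p_sample() once per step

    def p_sample(self, img, cond, ts, index, **kwargs):
        raise NotImplementedError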
@@ -33,7 +38,40 @@ class KSampler(object):
         ).chunk(2)
         return uncond + (cond - uncond) * cond_scale

-    # most of these arguments are ignored and are only present for compatibility with
+    def make_schedule(
+        self,
+        ddim_num_steps,
+        ddim_discretize='uniform',
+        ddim_eta=0.0,
+        verbose=False,
+    ):
+        outer_model = self.model
+        self.model = outer_model.inner_model
+        super().make_schedule(
+            ddim_num_steps,
+            ddim_discretize='uniform',
+            ddim_eta=0.0,
+            verbose=False,
+        )
+        self.model = outer_model
+        self.ddim_num_steps = ddim_num_steps
+        sigmas = K.sampling.get_sigmas_karras(
+            n=ddim_num_steps,
+            sigma_min=self.model.sigmas[0].item(),
+            sigma_max=self.model.sigmas[-1].item(),
+            rho=7.,
+            device=self.device,
+            # Birch-san recommends this, but it doesn't match the call
+            # signature in his branch of k-diffusion:
+            # concat_zero=False
+        )
+        self.sigmas = sigmas
+
+    # ALERT: We are completely overriding the sample() method in the base class,
+    # which means that inpainting will (probably?) not work correctly. To get it
+    # to work we need to be able to modify the inner loop of k_heun, k_lms, etc.,
+    # as is done in an ugly way in the lstein/k-diffusion branch.
+
+    # Most of these arguments are ignored and are only present for compatibility
+    # with other samplers.
     @torch.no_grad()
     def sample(
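
The ALERT comment above is the crux of the inpainting problem: a k-diffusion sampler owns its entire denoising loop, so a mask blend has to be injected inside that loop. The following is a hypothetical Euler-style loop showing the kind of modification meant here; the mask and x0 arguments are illustrative and are not part of k-diffusion's API:

import torch

@torch.no_grad()
def euler_inpaint_loop(denoiser, x, sigmas, extra_args, mask=None, x0=None):
    # Plain Euler loop over the sigma schedule; the mask/x0 handling at
    # the bottom is the part the stock k-diffusion samplers lack.
    s_in = x.new_ones([x.shape[0]])
    for i in range(len(sigmas) - 1):
        denoised = denoiser(x, sigmas[i] * s_in, **extra_args)
        d = (x - denoised) / sigmas[i]           # dx/dsigma estimate
        x = x + d * (sigmas[i + 1] - sigmas[i])  # Euler step
        if mask is not None and sigmas[i + 1] > 0:
            # Re-noise the original latents to the new sigma level and
            # keep them wherever mask == 0 (the region to preserve).
            noised = x0 + torch.randn_like(x) * sigmas[i + 1]
            x = mask * x + (1 - mask) * noised
    return x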
@@ -63,9 +101,11 @@ class KSampler(object):
     ):
         def route_callback(k_callback_values):
             if img_callback is not None:
-                img_callback(k_callback_values['x'], k_callback_values['i'])
+                img_callback(k_callback_values['x'])

-        sigmas = self.model.get_sigmas(S)
+        # sigmas = self.model.get_sigmas(S)
+        # sigmas are now set up in make_schedule - we take the last S items
+        sigmas = self.sigmas[-S:]
         if x_T is not None:
             x = x_T * sigmas[0]
         else:
@@ -86,3 +126,67 @@ class KSampler(object):
             ),
             None,
         )
+
+    @torch.no_grad()
+    def p_sample(
+        self,
+        img,
+        cond,
+        ts,
+        index,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        **kwargs,
+    ):
+        if self.model_wrap is None:
+            self.model_wrap = CFGDenoiser(self.model)
+        extra_args = {
+            'cond': cond,
+            'uncond': unconditional_conditioning,
+            'cond_scale': unconditional_guidance_scale,
+        }
+        if self.s_in is None:
+            self.s_in = img.new_ones([img.shape[0]])
+        if self.ds is None:
+            self.ds = []
+
+        # terrible, confusing names here
+        steps = self.ddim_num_steps
+        t_enc = self.t_enc
+
+        # sigmas covers all `steps` steps, but t_enc might be less.
+        # We start partway through the sigma array and work our way
+        # to the end over t_enc steps. index counts down from t_enc
+        # toward zero, so the index into sigmas used here is:
+        #    s_index = t_enc - index - 1
+        s_index = t_enc - index - 1
+        img = K.sampling.__dict__[f'_{self.schedule}'](
+            self.model_wrap,
+            img,
+            self.sigmas,
+            s_index,
+            s_in=self.s_in,
+            ds=self.ds,
+            extra_args=extra_args,
+        )
+        return img, None, None
+
+    def get_initial_image(self, x_T, shape, steps):
+        if x_T is not None:
+            return x_T + x_T * self.sigmas[0]
+        else:
+            return torch.randn(shape, device=self.device) * self.sigmas[0]
+
+    def prepare_to_sample(self, t_enc):
+        self.t_enc = t_enc
+        self.model_wrap = None
+        self.ds = None
+        self.s_in = None
+
+    def q_sample(self, x0, ts):
+        '''
+        Overrides parent method to return the q_sample of the inner model.
+        '''
+        return self.model.inner_model.q_sample(x0, ts)
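
For orientation, a hypothetical end-to-end use of the class as it stands after this commit. The sample() argument names follow the CompVis DDIM convention that the comment above says is kept for compatibility, and c/uc stand for prompt and empty-prompt embeddings; all of these names are assumptions, as the full sample() signature is truncated from this diff:

sampler = KSampler(model, schedule='lms')
sampler.make_schedule(ddim_num_steps=50)  # also builds the Karras sigmas

samples, _ = sampler.sample(
    S=50,                              # sigmas are sliced to the last S items
    batch_size=1,
    shape=(4, 64, 64),                 # latent shape for a 512x512 image
    conditioning=c,                    # assumed prompt embedding
    unconditional_guidance_scale=7.5,
    unconditional_conditioning=uc,     # assumed empty-prompt embedding
)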