fix k_samplers in img2img - probably correct now

This commit is contained in:
Lincoln Stein 2022-10-06 18:31:04 -04:00
parent 440065f7f8
commit 7541c7cf5d

View File

@ -5,6 +5,12 @@ import torch.nn as nn
from ldm.dream.devices import choose_torch_device
from ldm.models.diffusion.sampler import Sampler
from ldm.util import rand_perlin_2d
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
extract_into_tensor,
)
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
if threshold <= 0.0:
@ -81,13 +87,53 @@ class KSampler(Sampler):
)
self.model = outer_model
self.ddim_num_steps = ddim_num_steps
sigmas = self.model.get_sigmas(ddim_num_steps)
self.sigmas = sigmas
# we don't need both of these sigmas, but storing them here to make
# comparison easier later on
self.model_sigmas = self.model.get_sigmas(ddim_num_steps)
self.karras_sigmas = K.sampling.get_sigmas_karras(
n=ddim_num_steps,
sigma_min=self.model.sigmas[0].item(),
sigma_max=self.model.sigmas[-1].item(),
rho=7.,
device=self.device,
)
# ALERT: We are completely overriding the sample() method in the base class, which
# means that inpainting will (probably?) not work correctly. To get this to work
# we need to be able to modify the inner loop of k_heun, k_lms, etc, as is done
# in an ugly way in the lstein/k-diffusion branch.
# means that inpainting will not work. To get this to work we need to be able to
# modify the inner loop of k_heun, k_lms, etc, as is done in an ugly way
# in the lstein/k-diffusion branch.
@torch.no_grad()
def decode(
    self,
    z_enc,
    cond,
    t_enc,
    img_callback=None,
    unconditional_guidance_scale=1.0,
    unconditional_conditioning=None,
    use_original_steps=False,
    init_latent=None,
    mask=None,
):
    """Run ``self.sample()`` starting from ``z_enc`` and return the samples.

    This is a thin adapter: every argument is forwarded to ``self.sample()``
    under the keyword name that method expects, and only the first element
    of the returned pair is handed back to the caller.

    Args:
        z_enc: Starting latent; passed as ``x_T``. Its ``shape[1:]`` is
            passed as ``shape``.
        cond: Passed as ``conditioning``.
        t_enc: Passed as ``S`` (the step count requested from ``sample``).
        img_callback: Forwarded unchanged.
        unconditional_guidance_scale: Forwarded unchanged.
        unconditional_conditioning: Forwarded unchanged.
        use_original_steps: Accepted but never read in this method.
        init_latent: Passed as ``x0``.
        mask: Forwarded unchanged.

    Returns:
        The first element of the tuple returned by ``self.sample()``.
    """
    # batch_size is fixed at 1 here regardless of z_enc.shape[0].
    samples, _ = self.sample(
        batch_size=1,
        S=t_enc,
        x_T=z_enc,
        shape=z_enc.shape[1:],
        conditioning=cond,
        unconditional_guidance_scale=unconditional_guidance_scale,
        unconditional_conditioning=unconditional_conditioning,
        img_callback=img_callback,
        x0=init_latent,
        mask=mask,
    )
    return samples
# this is a no-op, provided here for compatibility with ddim and plms samplers
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
    """Return ``x0`` unchanged.

    ``t``, ``use_original_steps`` and ``noise`` are accepted but never read;
    the method performs no computation on its input.
    """
    return x0
# Most of these arguments are ignored and are only present for compatibility with
# other samplers
@ -124,15 +170,14 @@ class KSampler(Sampler):
img_callback(k_callback_values['x'],k_callback_values['i'])
# sigmas are set up in make_schedule - we take the last S+1 entries
sigmas = self.sigmas[-S-1:]
total_steps = len(self.karras_sigmas)
sigmas = self.karras_sigmas[-S-1:]
if x_T is not None:
x = x_T * sigmas[0]
# x = x_T + torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
x = x_T + torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
else:
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
extra_args = {
'cond': conditioning,
@ -199,12 +244,12 @@ class KSampler(Sampler):
# are at an intermediate step in img2img. See similar in
# sample() which does work.
def get_initial_image(self, x_T, shape, steps):
    """Build the starting tensor: scaled Gaussian noise, optionally added to ``x_T``.

    Noise of the requested ``shape`` is drawn on ``self.device`` and multiplied
    by ``self.sigmas[0]``. A warning is printed on every call.

    Args:
        x_T: Optional tensor to add the noise to; ``None`` means noise only.
        shape: Shape passed to ``torch.randn``.
        steps: Accepted but never read in this method.

    Returns:
        ``x_T + noise`` when ``x_T`` is given, otherwise ``noise``.
    """
    print('WARNING: ksampler.get_initial_image(): get_initial_image needs testing')
    noise = torch.randn(shape, device=self.device) * self.sigmas[0]
    if x_T is None:
        return noise
    return x_T + noise
def prepare_to_sample(self,t_enc):
self.t_enc = t_enc
@ -218,29 +263,3 @@ class KSampler(Sampler):
'''
return self.model.inner_model.q_sample(x0,ts)
@torch.no_grad()
def decode(
    self,
    z_enc,
    cond,
    t_enc,
    img_callback=None,
    unconditional_guidance_scale=1.0,
    unconditional_conditioning=None,
    use_original_steps=False,
    init_latent=None,
    mask=None,
):
    """Delegate to ``self.sample()`` and return just the sampled result.

    Argument mapping onto ``self.sample()`` keywords:
    ``z_enc`` -> ``x_T`` (and ``z_enc.shape[1:]`` -> ``shape``),
    ``cond`` -> ``conditioning``, ``t_enc`` -> ``S``,
    ``init_latent`` -> ``x0``; ``img_callback``, ``mask`` and the two
    guidance arguments keep their names. ``use_original_steps`` is accepted
    but never read, and ``batch_size`` is always passed as 1.

    Returns:
        The first element of the pair returned by ``self.sample()``; the
        second element is discarded.
    """
    sample_kwargs = dict(
        batch_size=1,
        S=t_enc,
        x_T=z_enc,
        shape=z_enc.shape[1:],
        conditioning=cond,
        unconditional_guidance_scale=unconditional_guidance_scale,
        unconditional_conditioning=unconditional_conditioning,
        img_callback=img_callback,
        x0=init_latent,
        mask=mask,
    )
    samples, _ = self.sample(**sample_kwargs)
    return samples