From 7541c7cf5d4d3d325e293c873f0904c4d0048fd0 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 6 Oct 2022 18:31:04 -0400 Subject: [PATCH] fix k_samplers in img2img - probably correct now --- ldm/models/diffusion/ksampler.py | 93 +++++++++++++++++++------------- 1 file changed, 56 insertions(+), 37 deletions(-) diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index d5a6fb67bf..6692907d2d 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -5,6 +5,12 @@ import torch.nn as nn from ldm.dream.devices import choose_torch_device from ldm.models.diffusion.sampler import Sampler from ldm.util import rand_perlin_2d +from ldm.modules.diffusionmodules.util import ( + make_ddim_sampling_parameters, + make_ddim_timesteps, + noise_like, + extract_into_tensor, +) def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7): if threshold <= 0.0: @@ -81,13 +87,53 @@ class KSampler(Sampler): ) self.model = outer_model self.ddim_num_steps = ddim_num_steps - sigmas = self.model.get_sigmas(ddim_num_steps) - self.sigmas = sigmas + # we don't need both of these sigmas, but storing them here to make + # comparison easier later on + self.model_sigmas = self.model.get_sigmas(ddim_num_steps) + self.karras_sigmas = K.sampling.get_sigmas_karras( + n=ddim_num_steps, + sigma_min=self.model.sigmas[0].item(), + sigma_max=self.model.sigmas[-1].item(), + rho=7., + device=self.device, + ) # ALERT: We are completely overriding the sample() method in the base class, which - # means that inpainting will (probably?) not work correctly. To get this to work - # we need to be able to modify the inner loop of k_heun, k_lms, etc, as is done - # in an ugly way in the lstein/k-diffusion branch. + # means that inpainting will not work. To get this to work we need to be able to + # modify the inner loop of k_heun, k_lms, etc, as is done in an ugly way + # in the lstein/k-diffusion branch. + + @torch.no_grad() + def decode( + self, + z_enc, + cond, + t_enc, + img_callback=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + init_latent = None, + mask = None, + ): + samples,_ = self.sample( + batch_size = 1, + S = t_enc, + x_T = z_enc, + shape = z_enc.shape[1:], + conditioning = cond, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning = unconditional_conditioning, + img_callback = img_callback, + x0 = init_latent, + mask = mask + ) + return samples + + # this is a no-op, provided here for compatibility with ddim and plms samplers + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + return x0 # Most of these arguments are ignored and are only present for compatibility with # other samples @@ -124,15 +170,14 @@ class KSampler(Sampler): img_callback(k_callback_values['x'],k_callback_values['i']) # sigmas are set up in make_schedule - we take the last steps items - sigmas = self.sigmas[-S-1:] - + total_steps = len(self.karras_sigmas) + sigmas = self.karras_sigmas[-S-1:] + if x_T is not None: - x = x_T * sigmas[0] -# x = x_T + torch.randn([batch_size, *shape], device=self.device) * sigmas[0] + x = x_T + torch.randn([batch_size, *shape], device=self.device) * sigmas[0] else: x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] - model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10)) extra_args = { 'cond': conditioning, @@ -199,12 +244,12 @@ class KSampler(Sampler): # are at an intermediate step in img2img. See similar in # sample() which does work. def get_initial_image(self,x_T,shape,steps): + print(f'WARNING: ksampler.get_initial_image(): get_initial_image needs testing') x = (torch.randn(shape, device=self.device) * self.sigmas[0]) if x_T is not None: return x_T + x else: return x - def prepare_to_sample(self,t_enc): self.t_enc = t_enc @@ -218,29 +263,3 @@ class KSampler(Sampler): ''' return self.model.inner_model.q_sample(x0,ts) - @torch.no_grad() - def decode( - self, - z_enc, - cond, - t_enc, - img_callback=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - use_original_steps=False, - init_latent = None, - mask = None, - ): - samples,_ = self.sample( - batch_size = 1, - S = t_enc, - x_T = z_enc, - shape = z_enc.shape[1:], - conditioning = cond, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning = unconditional_conditioning, - img_callback = img_callback, - x0 = init_latent, - mask = mask - ) - return samples