From 958d7650dd249b37df49036ae91a40739008a7f8 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 25 Sep 2022 04:03:28 -0400 Subject: [PATCH 1/7] img2img works with all samplers, inpainting working with ddim & plms - img2img confirmed working with all samplers - inpainting working on ddim & plms. Changes to k-diffusion module seem to be needed for inpainting support. - switched k-diffuser noise schedule to original karras schedule, which reduces the step number needed for good results --- environment.yml | 2 +- ldm/dream/args.py | 2 +- ldm/dream/generator/img2img.py | 28 +-- ldm/dream/generator/inpaint.py | 19 +- ldm/dream/generator/txt2img.py | 2 + ldm/generate.py | 12 + ldm/models/diffusion/ddim.py | 380 ++--------------------------- ldm/models/diffusion/ksampler.py | 120 ++++++++- ldm/models/diffusion/plms.py | 314 ++---------------------- ldm/models/diffusion/sampler.py | 402 +++++++++++++++++++++++++++++++ scripts/dream.py | 1 + 11 files changed, 597 insertions(+), 685 deletions(-) create mode 100644 ldm/models/diffusion/sampler.py diff --git a/environment.yml b/environment.yml index eaf4d0e02a..14e97d1602 100644 --- a/environment.yml +++ b/environment.yml @@ -34,6 +34,6 @@ dependencies: - kornia==0.6.0 - -e git+https://github.com/openai/CLIP.git@main#egg=clip - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers - - -e git+https://github.com/lstein/k-diffusion.git@master#egg=k-diffusion + - -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion - -e git+https://github.com/lstein/GFPGAN@fix-dark-cast-images#egg=gfpgan - -e . diff --git a/ldm/dream/args.py b/ldm/dream/args.py index acbe83a14c..8f04cf1ea4 100644 --- a/ldm/dream/args.py +++ b/ldm/dream/args.py @@ -193,7 +193,7 @@ class Args(object): # img2img generations have parameters relevant only to them and have special handling if a['init_img'] and len(a['init_img'])>0: switches.append(f'-I {a["init_img"]}') - switches.append(f'-A ddim') # TODO: FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS + switches.append(f'-A {a["sampler_name"]}') if a['fit']: switches.append(f'--fit') if a['init_mask'] and len(a['init_mask'])>0: diff --git a/ldm/dream/generator/img2img.py b/ldm/dream/generator/img2img.py index f354b59138..dbab188c4a 100644 --- a/ldm/dream/generator/img2img.py +++ b/ldm/dream/generator/img2img.py @@ -13,7 +13,6 @@ class Img2Img(Generator): super().__init__(model, precision) self.init_latent = None # by get_noise() - @torch.no_grad() def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, conditioning,init_image,strength,step_callback=None,**kwargs): """ @@ -21,13 +20,6 @@ class Img2Img(Generator): Return value depends on the seed at the time you call it. """ - # PLMS sampler not supported yet, so ignore previous sampler - if not isinstance(sampler,DDIMSampler): - print( - f">> sampler '{sampler.__class__.__name__}' is not yet supported. 
Using DDIM sampler" - ) - sampler = DDIMSampler(self.model, device=self.model.device) - sampler.make_schedule( ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False ) @@ -41,7 +33,6 @@ class Img2Img(Generator): t_enc = int(strength * steps) uc, c = conditioning - @torch.no_grad() def make_image(x_T): # encode (scaled latent) z_enc = sampler.stochastic_encode( @@ -49,14 +40,17 @@ class Img2Img(Generator): torch.tensor([t_enc]).to(self.model.device), noise=x_T ) - # decode it - samples = sampler.decode( - z_enc, - c, - t_enc, - img_callback = step_callback, - unconditional_guidance_scale=cfg_scale, - unconditional_conditioning=uc, + samples,_ = sampler.sample( + batch_size = 1, + S = t_enc, + shape = z_enc.shape[1:], + x_T = z_enc, + conditioning = c, + unconditional_guidance_scale = cfg_scale, + unconditional_conditioning = uc, + eta = ddim_eta, + img_callback = step_callback, + verbose = False, ) return self.sample_to_image(samples) diff --git a/ldm/dream/generator/inpaint.py b/ldm/dream/generator/inpaint.py index da5411ad64..620210b118 100644 --- a/ldm/dream/generator/inpaint.py +++ b/ldm/dream/generator/inpaint.py @@ -8,6 +8,7 @@ from einops import rearrange, repeat from ldm.dream.devices import choose_autocast from ldm.dream.generator.img2img import Img2Img from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.ksampler import KSampler class Inpaint(Img2Img): def __init__(self, model, precision): @@ -23,21 +24,20 @@ class Inpaint(Img2Img): the initial image + mask. Return value depends on the seed at the time you call it. kwargs are 'init_latent' and 'strength' """ - - mask_image = mask_image[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0) - mask_image = repeat(mask_image, '1 ... -> b ...', b=1) - - # PLMS sampler not supported yet, so ignore previous sampler - if not isinstance(sampler,DDIMSampler): + # klms samplers not supported yet, so ignore previous sampler + if isinstance(sampler,KSampler): print( - f">> sampler '{sampler.__class__.__name__}' is not yet supported. Using DDIM sampler" + f">> sampler '{sampler.__class__.__name__}' is not yet supported for inpainting, using DDIMSampler instead." ) sampler = DDIMSampler(self.model, device=self.model.device) - + sampler.make_schedule( ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False ) + mask_image = mask_image[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0) + mask_image = repeat(mask_image, '1 ... 
-> b ...', b=1) + scope = choose_autocast(self.precision) with scope(self.model.device.type): self.init_latent = self.model.get_first_stage_encoding( @@ -57,7 +57,7 @@ class Inpaint(Img2Img): torch.tensor([t_enc]).to(self.model.device), noise=x_T ) - + # decode it samples = sampler.decode( z_enc, @@ -69,6 +69,7 @@ class Inpaint(Img2Img): mask = mask_image, init_latent = self.init_latent ) + return self.sample_to_image(samples) return make_image diff --git a/ldm/dream/generator/txt2img.py b/ldm/dream/generator/txt2img.py index 1ab15ba7cd..790f098f84 100644 --- a/ldm/dream/generator/txt2img.py +++ b/ldm/dream/generator/txt2img.py @@ -30,6 +30,8 @@ class Txt2Img(Generator): if self.free_gpu_mem and self.model.model.device != self.model.device: self.model.model.to(self.model.device) + + sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=True) samples, _ = sampler.sample( batch_size = 1, diff --git a/ldm/generate.py b/ldm/generate.py index c0936fdccf..f1b294c261 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -1065,3 +1065,15 @@ class Generate: f.write(hash) return hash + def write_intermediate_images(self,modulus,path): + counter = -1 + if not os.path.exists(path): + os.makedirs(path) + def callback(img): + nonlocal counter + counter += 1 + if counter % modulus != 0: + return; + image = self.sample_to_image(img) + image.save(os.path.join(path,f'{counter:03}.png'),'PNG') + return callback diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index b875aac331..7aca381c74 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -5,289 +5,31 @@ import numpy as np from tqdm import tqdm from functools import partial from ldm.dream.devices import choose_torch_device +from ldm.models.diffusion.sampler import Sampler +from ldm.modules.diffusionmodules.util import noise_like -from ldm.modules.diffusionmodules.util import ( - make_ddim_sampling_parameters, - make_ddim_timesteps, - noise_like, - extract_into_tensor, -) - - -class DDIMSampler(object): +class DDIMSampler(Sampler): def __init__(self, model, schedule='linear', device=None, **kwargs): - super().__init__() - self.model = model - self.ddpm_num_timesteps = model.num_timesteps - self.schedule = schedule - self.device = device or choose_torch_device() - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != torch.device(self.device): - attr = attr.to(dtype=torch.float32, device=self.device) - setattr(self, name, attr) - - def make_schedule( - self, - ddim_num_steps, - ddim_discretize='uniform', - ddim_eta=0.0, - verbose=True, - ): - self.ddim_timesteps = make_ddim_timesteps( - ddim_discr_method=ddim_discretize, - num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps, - verbose=verbose, - ) - alphas_cumprod = self.model.alphas_cumprod - assert ( - alphas_cumprod.shape[0] == self.ddpm_num_timesteps - ), 'alphas have to be defined for each timestep' - to_torch = ( - lambda x: x.clone() - .detach() - .to(torch.float32) - .to(self.model.device) - ) - - self.register_buffer('betas', to_torch(self.model.betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer( - 'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev) - ) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer( - 'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())) - ) - self.register_buffer( - 'sqrt_one_minus_alphas_cumprod', - to_torch(np.sqrt(1.0 - 
alphas_cumprod.cpu())), - ) - self.register_buffer( - 'log_one_minus_alphas_cumprod', - to_torch(np.log(1.0 - alphas_cumprod.cpu())), - ) - self.register_buffer( - 'sqrt_recip_alphas_cumprod', - to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())), - ) - self.register_buffer( - 'sqrt_recipm1_alphas_cumprod', - to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), - ) - - # ddim sampling parameters - ( - ddim_sigmas, - ddim_alphas, - ddim_alphas_prev, - ) = make_ddim_sampling_parameters( - alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta, - verbose=verbose, - ) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer( - 'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas) - ) - sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) - / (1 - self.alphas_cumprod) - * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) - ) - self.register_buffer( - 'ddim_sigmas_for_original_num_steps', - sigmas_for_original_sampling_steps, - ) + super().__init__(model,schedule,model.num_timesteps,device) + # This is the central routine @torch.no_grad() - def sample( - self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0.0, - mask=None, - x0=None, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs, - ): - if conditioning is not None: - if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] - if cbs != batch_size: - print( - f'Warning: Got {cbs} conditionings but batch-size is {batch_size}' - ) - else: - if conditioning.shape[0] != batch_size: - print( - f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}' - ) - - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - print(f'Data shape for DDIM sampling is {size}, eta {eta}') - - samples, intermediates = self.ddim_sampling( - conditioning, - size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, - x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - ) - return samples, intermediates - - # This routine gets called from img2img - @torch.no_grad() - def ddim_sampling( - self, - cond, - shape, - x_T=None, - ddim_use_original_steps=False, - callback=None, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - img_callback=None, - log_every_t=100, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - ): - device = self.model.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - if timesteps is None: - timesteps = ( - self.ddpm_num_timesteps - if ddim_use_original_steps - else 
self.ddim_timesteps - ) - elif timesteps is not None and not ddim_use_original_steps: - subset_end = ( - int( - min(timesteps / self.ddim_timesteps.shape[0], 1) - * self.ddim_timesteps.shape[0] - ) - - 1 - ) - timesteps = self.ddim_timesteps[:subset_end] - - intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = ( - reversed(range(0, timesteps)) - if ddim_use_original_steps - else np.flip(timesteps) - ) - total_steps = ( - timesteps if ddim_use_original_steps else timesteps.shape[0] - ) - print(f'\nRunning DDIM Sampling with {total_steps} timesteps') - - iterator = tqdm( - time_range, - desc='DDIM Sampler', - total=total_steps, - dynamic_ncols=True, - ) - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((b,), step, device=device, dtype=torch.long) - - if mask is not None: - assert x0 is not None - img_orig = self.model.q_sample( - x0, ts - ) # TODO: deterministic forward pass? - img = img_orig * mask + (1.0 - mask) * img - - outs = self.p_sample_ddim( - img, - cond, - ts, - index=index, - use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, - temperature=temperature, - noise_dropout=noise_dropout, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - ) - img, pred_x0 = outs - if callback: - callback(i) - if img_callback: - img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) - - return img, intermediates - - # This routine gets called from ddim_sampling() and decode() - @torch.no_grad() - def p_sample_ddim( - self, - x, - c, - t, - index, - repeat_noise=False, - use_original_steps=False, - quantize_denoised=False, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, + def p_sample( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + **kwargs, ): b, *_, device = *x.shape, x.device @@ -351,83 +93,5 @@ class DDIMSampler(object): if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 + return x_prev, pred_x0, None - @torch.no_grad() - def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): - # fast, but does not allow for exact reconstruction - # t serves as an index to gather the correct alphas - if use_original_steps: - sqrt_alphas_cumprod = self.sqrt_alphas_cumprod - sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod - else: - sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) - sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas - - if noise is None: - noise = torch.randn_like(x0) - return ( - extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 - + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) - * noise - ) - - @torch.no_grad() - def decode( - self, - x_latent, - cond, - t_start, - img_callback=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - use_original_steps=False, - init_latent = None, - mask = None, - ): - - timesteps = ( - 
np.arange(self.ddpm_num_timesteps) - if use_original_steps - else self.ddim_timesteps - ) - timesteps = timesteps[:t_start] - - time_range = np.flip(timesteps) - total_steps = timesteps.shape[0] - print(f'Running DDIM Sampling with {total_steps} timesteps') - - iterator = tqdm(time_range, desc='Decoding image', total=total_steps) - x_dec = x_latent - x0 = init_latent - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full( - (x_latent.shape[0],), - step, - device=x_latent.device, - dtype=torch.long, - ) - - if mask is not None: - assert x0 is not None - xdec_orig = self.model.q_sample( - x0, ts - ) # TODO: deterministic forward pass? - x_dec = xdec_orig * mask + (1.0 - mask) * x_dec - - x_dec, _ = self.p_sample_ddim( - x_dec, - cond, - ts, - index=index, - use_original_steps=use_original_steps, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - ) - - if img_callback: - img_callback(x_dec, i) - - return x_dec diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index 0f6814940e..963657f6ac 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -3,6 +3,7 @@ import k_diffusion as K import torch import torch.nn as nn from ldm.dream.devices import choose_torch_device +from ldm.models.diffusion.sampler import Sampler class CFGDenoiser(nn.Module): def __init__(self, model): @@ -17,12 +18,16 @@ class CFGDenoiser(nn.Module): return uncond + (cond - uncond) * cond_scale -class KSampler(object): +class KSampler(Sampler): def __init__(self, model, schedule='lms', device=None, **kwargs): - super().__init__() - self.model = K.external.CompVisDenoiser(model) - self.schedule = schedule - self.device = device or choose_torch_device() + denoiser = K.external.CompVisDenoiser(model) + super().__init__( + denoiser, + schedule, + steps=model.num_timesteps, + ) + self.ds = None + self.s_in = None def forward(self, x, sigma, uncond, cond, cond_scale): x_in = torch.cat([x] * 2) @@ -33,7 +38,40 @@ class KSampler(object): ).chunk(2) return uncond + (cond - uncond) * cond_scale - # most of these arguments are ignored and are only present for compatibility with + def make_schedule( + self, + ddim_num_steps, + ddim_discretize='uniform', + ddim_eta=0.0, + verbose=False, + ): + outer_model = self.model + self.model = outer_model.inner_model + super().make_schedule( + ddim_num_steps, + ddim_discretize='uniform', + ddim_eta=0.0, + verbose=False, + ) + self.model = outer_model + self.ddim_num_steps = ddim_num_steps + sigmas = K.sampling.get_sigmas_karras( + n=ddim_num_steps, + sigma_min=self.model.sigmas[0].item(), + sigma_max=self.model.sigmas[-1].item(), + rho=7., + device=self.device, + # Birch-san recommends this, but it doesn't match the call signature in his branch of k-diffusion + # concat_zero=False + ) + self.sigmas = sigmas + + # ALERT: We are completely overriding the sample() method in the base class, which + # means that inpainting will (probably?) not work correctly. To get this to work + # we need to be able to modify the inner loop of k_heun, k_lms, etc, as is done + # in an ugly way in the lstein/k-diffusion branch. 
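+
+    # For reference, get_sigmas_karras() above spaces the n noise levels
+    # (following Karras et al. 2022, arXiv:2206.00364) as
+    #
+    #    sigma_i = (sigma_max**(1/rho) + (i/(n-1))*(sigma_min**(1/rho) - sigma_max**(1/rho)))**rho
+    #
+    # for i = 0 .. n-1, plus a final sigma of zero. rho=7 concentrates the
+    # sampling steps at the low-noise end of the schedule, which is what
+    # allows good results in fewer steps than the uniform ddim spacing.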
+
+    # Most of these arguments are ignored and are only present for compatibility with
     # other samples
     @torch.no_grad()
     def sample(
@@ -63,9 +101,11 @@ class KSampler(Sampler):
     ):
         def route_callback(k_callback_values):
             if img_callback is not None:
-                img_callback(k_callback_values['x'], k_callback_values['i'])
+                img_callback(k_callback_values['x'])
 
-        sigmas = self.model.get_sigmas(S)
+        # sigmas = self.model.get_sigmas(S)
+        # sigmas are now set up in make_schedule - we take the last steps items
+        sigmas = self.sigmas[-S:]
         if x_T is not None:
             x = x_T * sigmas[0]
         else:
@@ -86,3 +126,67 @@ class KSampler(Sampler):
             ),
             None,
         )
+
+    @torch.no_grad()
+    def p_sample(
+        self,
+        img,
+        cond,
+        ts,
+        index,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        **kwargs,
+    ):
+        if self.model_wrap is None:
+            self.model_wrap = CFGDenoiser(self.model)
+        extra_args = {
+            'cond': cond,
+            'uncond': unconditional_conditioning,
+            'cond_scale': unconditional_guidance_scale,
+        }
+        if self.s_in is None:
+            self.s_in = img.new_ones([img.shape[0]])
+        if self.ds is None:
+            self.ds = []
+
+        # terrible, confusing names here
+        steps = self.ddim_num_steps
+        t_enc = self.t_enc
+
+        # sigmas holds the full schedule, but t_enc may be smaller; for
+        # img2img only the last t_enc steps of the schedule are run.
+        # index comes in counting down from t_enc-1 to zero, so the
+        # step number within this run, which is the offset used to
+        # index into the sigma schedule on each call, works out to:
+        # s_index = t_enc - index - 1
+        s_index = t_enc - index - 1
+        img = K.sampling.__dict__[f'_{self.schedule}'](
+            self.model_wrap,
+            img,
+            self.sigmas,
+            s_index,
+            s_in = self.s_in,
+            ds = self.ds,
+            extra_args=extra_args,
+        )
+
+        return img, None, None
+
+    def get_initial_image(self,x_T,shape,steps):
+        if x_T is not None:
+            return x_T + x_T * self.sigmas[0]
+        else:
+            return (torch.randn(shape, device=self.device) * self.sigmas[0])
+
+    def prepare_to_sample(self,t_enc):
+        self.t_enc = t_enc
+        self.model_wrap = None
+        self.ds = None
+        self.s_in = None
+
+    def q_sample(self,x0,ts):
+        '''
+        Overrides parent method to return the q_sample of the inner model.
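+        (Here self.model is the CompVisDenoiser wrapper created in __init__,
+        so the underlying diffusion model lives at self.model.inner_model.)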
+ ''' + return self.model.inner_model.q_sample(x0,ts) diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index 33bb7201c9..870d8739ec 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -5,302 +5,34 @@ import numpy as np from tqdm import tqdm from functools import partial from ldm.dream.devices import choose_torch_device - -from ldm.modules.diffusionmodules.util import ( - make_ddim_sampling_parameters, - make_ddim_timesteps, - noise_like, -) +from ldm.models.diffusion.sampler import Sampler +from ldm.modules.diffusionmodules.util import noise_like -class PLMSSampler(object): +class PLMSSampler(Sampler): def __init__(self, model, schedule='linear', device=None, **kwargs): - super().__init__() - self.model = model - self.ddpm_num_timesteps = model.num_timesteps - self.schedule = schedule - self.device = device if device else choose_torch_device() - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != torch.device(self.device): - attr = attr.to(torch.float32).to(torch.device(self.device)) - setattr(self, name, attr) - - def make_schedule( - self, - ddim_num_steps, - ddim_discretize='uniform', - ddim_eta=0.0, - verbose=True, - ): - if ddim_eta != 0: - raise ValueError('ddim_eta must be 0 for PLMS') - self.ddim_timesteps = make_ddim_timesteps( - ddim_discr_method=ddim_discretize, - num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps, - verbose=verbose, - ) - alphas_cumprod = self.model.alphas_cumprod - assert ( - alphas_cumprod.shape[0] == self.ddpm_num_timesteps - ), 'alphas have to be defined for each timestep' - to_torch = ( - lambda x: x.clone() - .detach() - .to(torch.float32) - .to(self.model.device) - ) - - self.register_buffer('betas', to_torch(self.model.betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer( - 'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev) - ) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer( - 'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())) - ) - self.register_buffer( - 'sqrt_one_minus_alphas_cumprod', - to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), - ) - self.register_buffer( - 'log_one_minus_alphas_cumprod', - to_torch(np.log(1.0 - alphas_cumprod.cpu())), - ) - self.register_buffer( - 'sqrt_recip_alphas_cumprod', - to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())), - ) - self.register_buffer( - 'sqrt_recipm1_alphas_cumprod', - to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), - ) - - # ddim sampling parameters - ( - ddim_sigmas, - ddim_alphas, - ddim_alphas_prev, - ) = make_ddim_sampling_parameters( - alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta, - verbose=verbose, - ) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer( - 'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas) - ) - sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) - / (1 - self.alphas_cumprod) - * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) - ) - self.register_buffer( - 'ddim_sigmas_for_original_num_steps', - sigmas_for_original_sampling_steps, - ) + super().__init__(model,schedule,model.num_timesteps, device) + # this is the essential routine @torch.no_grad() - def sample( - self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - 
normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0.0, - mask=None, - x0=None, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs, - ): - if conditioning is not None: - if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] - if cbs != batch_size: - print( - f'Warning: Got {cbs} conditionings but batch-size is {batch_size}' - ) - else: - if conditioning.shape[0] != batch_size: - print( - f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}' - ) - - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - # print(f'Data shape for PLMS sampling is {size}') - - samples, intermediates = self.plms_sampling( - conditioning, - size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, - x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - ) - return samples, intermediates - - @torch.no_grad() - def plms_sampling( - self, - cond, - shape, - x_T=None, - ddim_use_original_steps=False, - callback=None, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - img_callback=None, - log_every_t=100, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - ): - device = self.model.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - if timesteps is None: - timesteps = ( - self.ddpm_num_timesteps - if ddim_use_original_steps - else self.ddim_timesteps - ) - elif timesteps is not None and not ddim_use_original_steps: - subset_end = ( - int( - min(timesteps / self.ddim_timesteps.shape[0], 1) - * self.ddim_timesteps.shape[0] - ) - - 1 - ) - timesteps = self.ddim_timesteps[:subset_end] - - intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = ( - list(reversed(range(0, timesteps))) - if ddim_use_original_steps - else np.flip(timesteps) - ) - total_steps = ( - timesteps if ddim_use_original_steps else timesteps.shape[0] - ) - # print(f"Running PLMS Sampling with {total_steps} timesteps") - - iterator = tqdm( - time_range, - desc='PLMS Sampler', - total=total_steps, - dynamic_ncols=True, - ) - old_eps = [] - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((b,), step, device=device, dtype=torch.long) - ts_next = torch.full( - (b,), - time_range[min(i + 1, len(time_range) - 1)], - device=device, - dtype=torch.long, - ) - - if mask is not None: - assert x0 is not None - img_orig = self.model.q_sample( - x0, ts - ) # TODO: deterministic forward pass? 
- img = img_orig * mask + (1.0 - mask) * img - - outs = self.p_sample_plms( - img, - cond, - ts, - index=index, - use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, - temperature=temperature, - noise_dropout=noise_dropout, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - old_eps=old_eps, - t_next=ts_next, - ) - img, pred_x0, e_t = outs - old_eps.append(e_t) - if len(old_eps) >= 4: - old_eps.pop(0) - if callback: - callback(i) - if img_callback: - img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) - - return img, intermediates - - @torch.no_grad() - def p_sample_plms( - self, - x, - c, - t, - index, - repeat_noise=False, - use_original_steps=False, - quantize_denoised=False, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - old_eps=None, - t_next=None, + def p_sample( + self, + x, # image, called 'img' elsewhere + c, # conditioning, called 'cond' elsewhere + t, # timesteps, called 'ts' elsewhere + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + old_eps=[], + t_next=None, + **kwargs, ): b, *_, device = *x.shape, x.device diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py new file mode 100644 index 0000000000..dbde607023 --- /dev/null +++ b/ldm/models/diffusion/sampler.py @@ -0,0 +1,402 @@ +''' +ldm.models.diffusion.sampler + +Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc + +''' +import torch +import numpy as np +from tqdm import tqdm +from functools import partial +from ldm.dream.devices import choose_torch_device + +from ldm.modules.diffusionmodules.util import ( + make_ddim_sampling_parameters, + make_ddim_timesteps, + noise_like, + extract_into_tensor, +) + +class Sampler(object): + def __init__(self, model, schedule='linear', steps=None, device=None, **kwargs): + self.model = model + self.ddpm_num_timesteps = steps + self.schedule = schedule + self.device = device or choose_torch_device() + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device(self.device): + attr = attr.to(torch.float32).to(torch.device(self.device)) + setattr(self, name, attr) + + # This method was copied over from ddim.py and probably does stuff that is + # ddim-specific. Disentangle at some point. 
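+    # The schedule buffers registered below implement the closed-form DDPM
+    # forward process
+    #
+    #     q(x_t | x_0) = N( sqrt(alphas_cumprod[t]) * x_0, (1 - alphas_cumprod[t]) * I )
+    #
+    # so a noised latent at timestep t can be drawn in a single step as
+    #
+    #     x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise
+    #
+    # which is exactly what stochastic_encode() below does for img2img.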
+ def make_schedule( + self, + ddim_num_steps, + ddim_discretize='uniform', + ddim_eta=0.0, + verbose=False, + ): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.model.alphas_cumprod + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), 'alphas have to be defined for each timestep' + to_torch = ( + lambda x: x.clone() + .detach() + .to(torch.float32) + .to(self.model.device) + ) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer( + 'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + 'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())) + ) + self.register_buffer( + 'sqrt_one_minus_alphas_cumprod', + to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + 'log_one_minus_alphas_cumprod', + to_torch(np.log(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + 'sqrt_recip_alphas_cumprod', + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())), + ) + self.register_buffer( + 'sqrt_recipm1_alphas_cumprod', + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), + ) + + # ddim sampling parameters + ( + ddim_sigmas, + ddim_alphas, + ddim_alphas_prev, + ) = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer( + 'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas) + ) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + 'ddim_sigmas_for_original_num_steps', + sigmas_for_original_sampling_steps, + ) + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return ( + extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) + * noise + ) + + @torch.no_grad() + def sample( + self, + S, # S is steps + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=False, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
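+        # The guidance scale implements classifier-free guidance on the noise
+        # prediction:  e_t = e_t_uncond + scale * (e_t_cond - e_t_uncond)
+        # (the same blend CFGDenoiser applies for the k-diffusion samplers).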
+        **kwargs,
+    ):
+
+        ts = self.get_timesteps(S)
+
+        # sampling
+        C, H, W = shape
+        shape = (batch_size, C, H, W)
+        samples, intermediates = self.do_sampling(
+            conditioning,
+            shape,
+            timesteps=ts,
+            callback=callback,
+            img_callback=img_callback,
+            quantize_denoised=quantize_x0,
+            mask=mask,
+            x0=x0,
+            ddim_use_original_steps=False,
+            noise_dropout=noise_dropout,
+            temperature=temperature,
+            score_corrector=score_corrector,
+            corrector_kwargs=corrector_kwargs,
+            x_T=x_T,
+            log_every_t=log_every_t,
+            unconditional_guidance_scale=unconditional_guidance_scale,
+            unconditional_conditioning=unconditional_conditioning,
+            steps=S,
+        )
+        return samples, intermediates
+
+    @torch.no_grad()
+    def do_sampling(
+        self,
+        cond,
+        shape,
+        timesteps=None,
+        x_T=None,
+        ddim_use_original_steps=False,
+        callback=None,
+        quantize_denoised=False,
+        mask=None,
+        x0=None,
+        img_callback=None,
+        log_every_t=100,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        steps=None,
+    ):
+        b = shape[0]
+        time_range = (
+            list(reversed(range(0, timesteps)))
+            if ddim_use_original_steps
+            else np.flip(timesteps)
+        )
+        total_steps=steps
+
+        iterator = tqdm(
+            time_range,
+            desc=f'{self.__class__.__name__}',
+            total=total_steps,
+            dynamic_ncols=True,
+        )
+        old_eps = []
+        self.prepare_to_sample(t_enc=total_steps)
+        img = self.get_initial_image(x_T,shape,total_steps)
+
+        # probably don't need this at all
+        intermediates = {'x_inter': [img], 'pred_x0': [img]}
+
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full(
+                (b,),
+                step,
+                device=self.device,
+                dtype=torch.long
+            )
+            ts_next = torch.full(
+                (b,),
+                time_range[min(i + 1, len(time_range) - 1)],
+                device=self.device,
+                dtype=torch.long,
+            )
+
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.model.q_sample(
+                    x0, ts
+                )  # TODO: deterministic forward pass?
+                img = img_orig * mask + (1.0 - mask) * img
+
+            outs = self.p_sample(
+                img,
+                cond,
+                ts,
+                index=index,
+                use_original_steps=ddim_use_original_steps,
+                quantize_denoised=quantize_denoised,
+                temperature=temperature,
+                noise_dropout=noise_dropout,
+                score_corrector=score_corrector,
+                corrector_kwargs=corrector_kwargs,
+                unconditional_guidance_scale=unconditional_guidance_scale,
+                unconditional_conditioning=unconditional_conditioning,
+                old_eps=old_eps,
+                t_next=ts_next,
+            )
+            img, pred_x0, e_t = outs
+
+            old_eps.append(e_t)
+            if len(old_eps) >= 4:
+                old_eps.pop(0)
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(img)
+
+            if index % log_every_t == 0 or index == total_steps - 1:
+                intermediates['x_inter'].append(img)
+                intermediates['pred_x0'].append(pred_x0)
+
+        return img, intermediates
+
+    # NOTE that decode() and sample() are almost the same code, and do the same thing.
+    # The variable names are changed in order to be confusing.
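+    # In practice decode() is the img2img/inpainting entry point: the caller
+    # noises the init latent up to step t_start with stochastic_encode(), then
+    # decode() runs just those final denoising steps. A sketch of the calling
+    # convention (cf. ldm/dream/generator/inpaint.py; names are illustrative):
+    #
+    #     t_enc   = int(strength * steps)
+    #     z_enc   = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]), noise=x_T)
+    #     samples = sampler.decode(z_enc, c, t_enc,
+    #                              unconditional_guidance_scale=cfg_scale,
+    #                              unconditional_conditioning=uc,
+    #                              mask=mask_image, init_latent=init_latent)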
+ @torch.no_grad() + def decode( + self, + x_latent, + cond, + t_start, + img_callback=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + init_latent = None, + mask = None, + ): + + timesteps = ( + np.arange(self.ddpm_num_timesteps) + if use_original_steps + else self.ddim_timesteps + ) + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f'>> Running {self.__class__.__name__} Sampling with {total_steps} timesteps') + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + x0 = init_latent + self.prepare_to_sample(t_enc=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full( + (x_latent.shape[0],), + step, + device=x_latent.device, + dtype=torch.long, + ) + + ts_next = torch.full( + (x_latent.shape[0],), + time_range[min(i + 1, len(time_range) - 1)], + device=self.device, + dtype=torch.long, + ) + + if mask is not None: + assert x0 is not None + xdec_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass? + x_dec = xdec_orig * mask + (1.0 - mask) * x_dec + + outs = self.p_sample( + x_dec, + cond, + ts, + index=index, + use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + t_next = ts_next, + ) + + x_dec, pred_x0, e_t = outs + if img_callback: + img_callback(img) + + return x_dec + + def get_initial_image(self,x_T,shape,timesteps=None): + if x_T is None: + return torch.randn(shape, device=self.device) + else: + return x_T + + def p_sample( + self, + img, + cond, + ts, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + old_eps=None, + t_next=None, + steps=None, + ): + raise NotImplementedError("p_sample() must be implemented in a descendent class") + + def prepare_to_sample(self,t_enc,**kwargs): + ''' + Hook that will be called right before the very first invocation of p_sample() + to allow subclass to do additional initialization. t_enc corresponds to the actual + number of steps that will be run, and may be less than total steps if img2img is + active. + ''' + pass + + def get_timesteps(self,ddim_steps): + ''' + The ddim and plms samplers work on timesteps. This method is called after + ddim_timesteps are created in make_schedule(), and selects the portion of + timesteps that will be used for sampling, depending on the t_enc in img2img. + ''' + return self.ddim_timesteps[:ddim_steps] + + def q_sample(self,x0,ts): + ''' + Returns self.model.q_sample(x0,ts). 
Is overridden in the k* samplers to + return self.model.inner_model.q_sample(x0,ts) + ''' + return self.model.q_sample(x0,ts) diff --git a/scripts/dream.py b/scripts/dream.py index 84dca75e8f..7099b978aa 100644 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -324,6 +324,7 @@ def main_loop(gen, opt, infile): opt.last_operation='generate' gen.prompt2image( image_callback=image_writer, +# step_callback=gen.write_intermediate_images(5,'./outputs/img-samples/intermediates'), #DEBUGGING ONLY - DELETE catch_interrupts=catch_ctrl_c, **vars(opt) ) From a0f4af087c894a66c08feb393c05724e33b32170 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 1 Oct 2022 15:50:05 -0400 Subject: [PATCH 2/7] restore use of sampler.decode() in img2img --- ldm/dream/generator/img2img.py | 20 +++++++++----------- ldm/models/diffusion/ksampler.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/ldm/dream/generator/img2img.py b/ldm/dream/generator/img2img.py index dbab188c4a..8a2bccf055 100644 --- a/ldm/dream/generator/img2img.py +++ b/ldm/dream/generator/img2img.py @@ -40,18 +40,16 @@ class Img2Img(Generator): torch.tensor([t_enc]).to(self.model.device), noise=x_T ) - samples,_ = sampler.sample( - batch_size = 1, - S = t_enc, - shape = z_enc.shape[1:], - x_T = z_enc, - conditioning = c, - unconditional_guidance_scale = cfg_scale, - unconditional_conditioning = uc, - eta = ddim_eta, - img_callback = step_callback, - verbose = False, + # decode it + samples = sampler.decode( + z_enc, + c, + t_enc, + img_callback = step_callback, + unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=uc, ) + return self.sample_to_image(samples) return make_image diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index 963657f6ac..97cd9cacde 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -119,6 +119,7 @@ class KSampler(Sampler): 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale, } + print(f'>> Sampling with k__{self.schedule}') return ( K.sampling.__dict__[f'sample_{self.schedule}']( model_wrap_cfg, x, sigmas, extra_args=extra_args, @@ -190,3 +191,30 @@ class KSampler(Sampler): Overrides parent method to return the q_sample of the inner model. ''' return self.model.inner_model.q_sample(x0,ts) + + @torch.no_grad() + def decode( + self, + z_enc, + cond, + t_enc, + img_callback=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + init_latent = None, + mask = None, + ): + samples,_ = self.sample( + batch_size = 1, + S = t_enc, + x_T = z_enc, + shape = z_enc.shape[1:], + conditioning = cond, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning = unconditional_conditioning, + img_callback = img_callback, + x0 = init_latent, + mask = mask + ) + return samples From 8ba5e385ec979f4728c0995660e9762a478a11cd Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 2 Oct 2022 15:56:57 +1100 Subject: [PATCH 3/7] Fixes #877 --- ldm/dream/args.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ldm/dream/args.py b/ldm/dream/args.py index 8f04cf1ea4..dad20c1450 100644 --- a/ldm/dream/args.py +++ b/ldm/dream/args.py @@ -227,8 +227,8 @@ class Args(object): # 2. However, they come out of the CLI (and probably web) with the keyword "with_variations" and # in broken-out form. 
Variation (1) should be changed to comply with (2) if a['with_variations']: - formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in (a["variations"])) - switches.append(f'-V {a["formatted_variations"]}') + formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in (a["with_variations"])) + switches.append(f'-V {formatted_variations}') if 'variations' in a: switches.append(f'-V {a["variations"]}') return ' '.join(switches) From 5a88be3744deb8079d6c63e78d1f5656fef7ae60 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 2 Oct 2022 22:31:11 -0400 Subject: [PATCH 4/7] fix typo which caused crash in sampler.py --- ldm/models/diffusion/sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index dbde607023..ad2e8f789f 100644 --- a/ldm/models/diffusion/sampler.py +++ b/ldm/models/diffusion/sampler.py @@ -346,7 +346,7 @@ class Sampler(object): x_dec, pred_x0, e_t = outs if img_callback: - img_callback(img) + img_callback(x_dec) return x_dec From 8e97bc24a4fa044f703aa1016eaf09db58b70551 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 3 Oct 2022 05:38:43 -0400 Subject: [PATCH 5/7] restore step argument to step_callback --- ldm/models/diffusion/ksampler.py | 2 +- ldm/models/diffusion/sampler.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index 97cd9cacde..d1064f06b8 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -101,7 +101,7 @@ class KSampler(Sampler): ): def route_callback(k_callback_values): if img_callback is not None: - img_callback(k_callback_values['x']) + img_callback(k_callback_values['x'],k_callback_values['i']) # sigmas = self.model.get_sigmas(S) # sigmas are now set up in make_schedule - we take the last steps items diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index ad2e8f789f..78196f8891 100644 --- a/ldm/models/diffusion/sampler.py +++ b/ldm/models/diffusion/sampler.py @@ -272,7 +272,7 @@ class Sampler(object): if callback: callback(i) if img_callback: - img_callback(img) + img_callback(img,step) if index % log_every_t == 0 or index == total_steps - 1: intermediates['x_inter'].append(img) From 7c6dbcb14a75fd6cfceddf92c0336cfd4609b000 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 3 Oct 2022 05:46:23 -0400 Subject: [PATCH 6/7] use right value for step arg in img_callback --- ldm/models/diffusion/sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index 78196f8891..437b674518 100644 --- a/ldm/models/diffusion/sampler.py +++ b/ldm/models/diffusion/sampler.py @@ -272,7 +272,7 @@ class Sampler(object): if callback: callback(i) if img_callback: - img_callback(img,step) + img_callback(img,i) if index % log_every_t == 0 or index == total_steps - 1: intermediates['x_inter'].append(img) From 1ab09e7a066ef225c0c8acdca6f4214eb19992dc Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 3 Oct 2022 05:47:32 -0400 Subject: [PATCH 7/7] use right value for step arg in img_callback in decode() --- ldm/models/diffusion/sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index 437b674518..b62278c719 100644 --- a/ldm/models/diffusion/sampler.py +++ b/ldm/models/diffusion/sampler.py @@ -346,7 +346,7 @@ class Sampler(object): 
x_dec, pred_x0, e_t = outs if img_callback: - img_callback(x_dec) + img_callback(x_dec,i) return x_dec