add configs for training unconditional/class-conditional ldms

This commit is contained in:
ablattmann
2021-12-22 15:57:23 +01:00
parent f8b4a07105
commit 171cf29fb5
13 changed files with 562 additions and 53 deletions

View File

@ -16,14 +16,14 @@ from contextlib import contextmanager
from functools import partial
from tqdm import tqdm
from torchvision.utils import make_grid
from PIL import Image
from pytorch_lightning.utilities.distributed import rank_zero_only
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from ldm.modules.ema import LitEma
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
from ldm.models.diffusion.ddim import DDIMSampler
__conditioning_keys__ = {'concat': 'c_concat',
@ -37,12 +37,6 @@ def disabled_train(self, mode=True):
return self
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
def uniform_on_device(r1, r2, shape, device):
return (r1 - r2) * torch.rand(*shape, device=device) + r2
@ -119,6 +113,7 @@ class DDPM(pl.LightningModule):
if self.learn_logvar:
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if exists(given_betas):
@ -1188,7 +1183,6 @@ class LatentDiffusion(DDPM):
if start_T is not None:
timesteps = min(timesteps, start_T)
print(timesteps, start_T)
iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
range(0, timesteps))
@ -1222,7 +1216,7 @@ class LatentDiffusion(DDPM):
@torch.no_grad()
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
verbose=True, timesteps=None, quantize_denoised=False,
mask=None, x0=None, shape=None):
mask=None, x0=None, shape=None,**kwargs):
if shape is None:
shape = (batch_size, self.channels, self.image_size, self.image_size)
if cond is not None:
@ -1238,10 +1232,28 @@ class LatentDiffusion(DDPM):
mask=mask, x0=x0)
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, sample_ddim=False, return_keys=None,
def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
if ddim:
ddim_sampler = DDIMSampler(self)
shape = (self.channels, self.image_size, self.image_size)
samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
shape,cond,verbose=False,**kwargs)
else:
samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
return_intermediates=True,**kwargs)
return samples, intermediates
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
plot_diffusion_rows=True, **kwargs):
# TODO: maybe add option for ddim sampling via DDIMSampler class
use_ddim = ddim_steps is not None
log = dict()
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
return_first_stage_outputs=True,
@ -1288,7 +1300,9 @@ class LatentDiffusion(DDPM):
if sample:
# get denoise row
with self.ema_scope("Plotting"):
samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
ddim_steps=ddim_steps,eta=ddim_eta)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
if plot_denoise_rows:
@ -1299,8 +1313,11 @@ class LatentDiffusion(DDPM):
self.first_stage_model, IdentityFirstStage):
# also display when quantizing x0 while sampling
with self.ema_scope("Plotting Quantized Denoised"):
samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
quantize_denoised=True)
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
ddim_steps=ddim_steps,eta=ddim_eta,
quantize_denoised=True)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
# quantize_denoised=True)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_x0_quantized"] = x_samples
@ -1312,19 +1329,17 @@ class LatentDiffusion(DDPM):
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
mask = mask[:, None, ...]
with self.ema_scope("Plotting Inpaint"):
samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
quantize_denoised=False, x0=z[:N], mask=mask)
samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_inpainting"] = x_samples
log["mask"] = mask
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row_inpainting"] = denoise_grid
# outpaint
with self.ema_scope("Plotting Outpaint"):
samples = self.sample(cond=c, batch_size=N, return_intermediates=False,
quantize_denoised=False, x0=z[:N], mask=1. - mask)
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_outpainting"] = x_samples