add configs for training unconditional/class-conditional ldms
ldm/models/diffusion/ddpm.py

@@ -16,14 +16,14 @@ from contextlib import contextmanager
 from functools import partial
 from tqdm import tqdm
 from torchvision.utils import make_grid
-from PIL import Image
 from pytorch_lightning.utilities.distributed import rank_zero_only

 from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
 from ldm.modules.ema import LitEma
 from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
 from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
-from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from ldm.models.diffusion.ddim import DDIMSampler


 __conditioning_keys__ = {'concat': 'c_concat',
@@ -37,12 +37,6 @@ def disabled_train(self, mode=True):
     return self


-def noise_like(shape, device, repeat=False):
-    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
-    noise = lambda: torch.randn(shape, device=device)
-    return repeat_noise() if repeat else noise()
-
-
 def uniform_on_device(r1, r2, shape, device):
     return (r1 - r2) * torch.rand(*shape, device=device) + r2

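Note: noise_like is not deleted outright here; it moves into ldm.modules.diffusionmodules.util and is re-imported in the hunk above. A minimal sketch of the helper's behavior, reconstructed from the removed lines (the util version is assumed to be identical):

    import torch

    def noise_like(shape, device, repeat=False):
        # repeat=True draws one noise sample of shape (1, *shape[1:]) and
        # tiles it across the batch dimension, so every batch element gets
        # identical noise; repeat=False draws independent noise per element.
        if repeat:
            return torch.randn((1, *shape[1:]), device=device).repeat(
                shape[0], *((1,) * (len(shape) - 1)))
        return torch.randn(shape, device=device)

    # e.g. noise_like((4, 3, 64, 64), 'cpu', repeat=True) returns four
    # identical 3x64x64 noise maps.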
@@ -119,6 +113,7 @@ class DDPM(pl.LightningModule):
         if self.learn_logvar:
             self.logvar = nn.Parameter(self.logvar, requires_grad=True)

+
     def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                           linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
         if exists(given_betas):
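Note: register_schedule keeps its defaults (linear schedule, 1000 timesteps, linear_start=1e-4, linear_end=2e-2). For reference, a sketch of the "linear" branch of make_beta_schedule as it is usually implemented in this codebase (assumption: the util implementation matches upstream latent-diffusion); the betas are spaced linearly in sqrt-space and then squared:

    import torch

    def linear_betas(n_timestep=1000, linear_start=1e-4, linear_end=2e-2):
        # linear in sqrt-space, then squared -- matches the defaults that
        # register_schedule forwards to make_beta_schedule
        return torch.linspace(linear_start ** 0.5, linear_end ** 0.5,
                              n_timestep, dtype=torch.float64) ** 2

    betas = linear_betas()
    alphas_cumprod = torch.cumprod(1. - betas, dim=0)  # cumulative products the schedule registers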
@@ -1188,7 +1183,6 @@ class LatentDiffusion(DDPM):

         if start_T is not None:
             timesteps = min(timesteps, start_T)
-            print(timesteps, start_T)
         iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
             range(0, timesteps))

@@ -1222,7 +1216,7 @@ class LatentDiffusion(DDPM):
     @torch.no_grad()
     def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
                verbose=True, timesteps=None, quantize_denoised=False,
-               mask=None, x0=None, shape=None):
+               mask=None, x0=None, shape=None, **kwargs):
         if shape is None:
             shape = (batch_size, self.channels, self.image_size, self.image_size)
         if cond is not None:
@@ -1238,10 +1232,28 @@ class LatentDiffusion(DDPM):
                                   mask=mask, x0=x0)

     @torch.no_grad()
-    def log_images(self, batch, N=8, n_row=4, sample=True, sample_ddim=False, return_keys=None,
+    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+
+        if ddim:
+            ddim_sampler = DDIMSampler(self)
+            shape = (self.channels, self.image_size, self.image_size)
+            samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
+                                                         shape, cond, verbose=False, **kwargs)
+
+        else:
+            samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
+                                                 return_intermediates=True, **kwargs)
+
+        return samples, intermediates
+
+
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
                    quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
                    plot_diffusion_rows=True, **kwargs):
-        # TODO: maybe add option for ddim sampling via DDIMSampler class
+
+        use_ddim = ddim_steps is not None
+
         log = dict()
         z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
                                            return_first_stage_outputs=True,
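Note: the new sample_log helper is the single entry point log_images now uses; ddim=True routes through DDIMSampler.sample with a per-sample shape (no batch dimension), while ddim=False falls back to the model's own ancestral sample loop. A hedged usage sketch (model, c, and the batch size are illustrative assumptions):

    # `model` is a trained LatentDiffusion instance, `c` a conditioning
    # batch matching its conditioning key -- both hypothetical here.
    use_ddim = True  # mirrors `use_ddim = ddim_steps is not None` in log_images

    samples, intermediates = model.sample_log(
        cond=c,
        batch_size=8,
        ddim=use_ddim,    # True -> DDIMSampler.sample, False -> model.sample
        ddim_steps=200,   # DDIM step count; not forwarded to the ancestral branch
        eta=1.,           # DDIM eta, reaches the sampler via **kwargs
    )
    x_samples = model.decode_first_stage(samples)  # decode latents to images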
@@ -1288,7 +1300,9 @@ class LatentDiffusion(DDPM):
         if sample:
             # get denoise row
             with self.ema_scope("Plotting"):
-                samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                         ddim_steps=ddim_steps, eta=ddim_eta)
+            # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
             x_samples = self.decode_first_stage(samples)
             log["samples"] = x_samples
             if plot_denoise_rows:
@@ -1299,8 +1313,11 @@ class LatentDiffusion(DDPM):
                 self.first_stage_model, IdentityFirstStage):
                 # also display when quantizing x0 while sampling
                 with self.ema_scope("Plotting Quantized Denoised"):
-                    samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
-                                                         quantize_denoised=True)
+                    samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                             ddim_steps=ddim_steps, eta=ddim_eta,
+                                                             quantize_denoised=True)
+                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
+                #                                      quantize_denoised=True)
                 x_samples = self.decode_first_stage(samples.to(self.device))
                 log["samples_x0_quantized"] = x_samples

@@ -1312,19 +1329,17 @@ class LatentDiffusion(DDPM):
             mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
             mask = mask[:, None, ...]
             with self.ema_scope("Plotting Inpaint"):
-                samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
-                                                     quantize_denoised=False, x0=z[:N], mask=mask)
+
+                samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+                                             ddim_steps=ddim_steps, x0=z[:N], mask=mask)
             x_samples = self.decode_first_stage(samples.to(self.device))
             log["samples_inpainting"] = x_samples
             log["mask"] = mask
-            if plot_denoise_rows:
-                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
-                log["denoise_row_inpainting"] = denoise_grid

             # outpaint
             with self.ema_scope("Plotting Outpaint"):
-                samples = self.sample(cond=c, batch_size=N, return_intermediates=False,
-                                      quantize_denoised=False, x0=z[:N], mask=1. - mask)
+                samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+                                             ddim_steps=ddim_steps, x0=z[:N], mask=mask)
             x_samples = self.decode_first_stage(samples.to(self.device))
             log["samples_outpainting"] = x_samples

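Note: zeros in the mask mark the region to be synthesized, so the centered square is inpainted while the border is kept. The outpaint branch now passes mask=mask rather than the complement 1. - mask used by the call it replaces. A sketch of the mask construction, following the surrounding log_images context (N, h, w are illustrative values):

    import torch

    N, h, w = 8, 64, 64                                 # latent batch/height/width (assumed)
    mask = torch.ones(N, h, w)
    mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.  # zeros will be filled in
    mask = mask[:, None, ...]                           # add channel dim -> (N, 1, h, w)

    inpaint_mask = mask        # keep the border, synthesize the center square
    outpaint_mask = 1. - mask  # the complement the replaced outpaint call used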