Fixed up a few merge conflicts, looks good so far

This commit is contained in:
Lincoln Stein
2022-08-24 11:29:32 -04:00
19 changed files with 1318 additions and 35 deletions

View File

@ -7,7 +7,9 @@ https://github.com/CompVis/taming-transformers
"""
import torch
import torch.nn as nn
import os
import numpy as np
import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR
@ -64,6 +66,7 @@ class DDPM(pl.LightningModule):
cosine_s=8e-3,
given_betas=None,
original_elbo_weight=0.,
embedding_reg_weight=0.,
v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
l_simple_weight=1.,
conditioning_key=None,
@ -98,6 +101,7 @@ class DDPM(pl.LightningModule):
self.v_posterior = v_posterior
self.original_elbo_weight = original_elbo_weight
self.l_simple_weight = l_simple_weight
self.embedding_reg_weight = embedding_reg_weight
if monitor is not None:
self.monitor = monitor
@ -427,6 +431,7 @@ class LatentDiffusion(DDPM):
def __init__(self,
first_stage_config,
cond_stage_config,
personalization_config,
num_timesteps_cond=None,
cond_stage_key="image",
cond_stage_trainable=False,
@ -436,6 +441,7 @@ class LatentDiffusion(DDPM):
scale_factor=1.0,
scale_by_std=False,
*args, **kwargs):
self.num_timesteps_cond = default(num_timesteps_cond, 1)
self.scale_by_std = scale_by_std
assert self.num_timesteps_cond <= kwargs['timesteps']
@ -450,6 +456,7 @@ class LatentDiffusion(DDPM):
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
except:
@ -460,6 +467,7 @@ class LatentDiffusion(DDPM):
self.register_buffer('scale_factor', torch.tensor(scale_factor))
self.instantiate_first_stage(first_stage_config)
self.instantiate_cond_stage(cond_stage_config)
self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False
self.bbox_tokenizer = None
@ -469,6 +477,25 @@ class LatentDiffusion(DDPM):
self.init_from_ckpt(ckpt_path, ignore_keys)
self.restarted_from_ckpt = True
self.cond_stage_model.train = disabled_train
for param in self.cond_stage_model.parameters():
param.requires_grad = False
self.model.eval()
self.model.train = disabled_train
for param in self.model.parameters():
param.requires_grad = False
self.embedding_manager = self.instantiate_embedding_manager(personalization_config, self.cond_stage_model)
self.emb_ckpt_counter = 0
# if self.embedding_manager.is_clip:
# self.cond_stage_model.update_embedding_func(self.embedding_manager)
for param in self.embedding_manager.embedding_parameters():
param.requires_grad = True
def make_cond_schedule(self, ):
self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
@ -530,6 +557,15 @@ class LatentDiffusion(DDPM):
except urllib.error.URLError:
raise SystemExit("* Couldn't load a dependency. Try running scripts/preload_models.py from an internet-conected machine.")
self.cond_stage_model = model
def instantiate_embedding_manager(self, config, embedder):
model = instantiate_from_config(config, embedder=embedder)
if config.params.get("embedding_manager_ckpt", None): # do not load if missing OR empty string
model.load(config.params.embedding_manager_ckpt)
return model
def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
denoise_row = []
@ -555,7 +591,7 @@ class LatentDiffusion(DDPM):
def get_learned_conditioning(self, c):
if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
c = self.cond_stage_model.encode(c)
c = self.cond_stage_model.encode(c, embedding_manager=self.embedding_manager)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
else:
@ -880,6 +916,7 @@ class LatentDiffusion(DDPM):
if self.shorten_cond_schedule: # TODO: drop this option
tc = self.cond_ids[t].to(self.device)
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs)
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
@ -1046,6 +1083,14 @@ class LatentDiffusion(DDPM):
loss += (self.original_elbo_weight * loss_vlb)
loss_dict.update({f'{prefix}/loss': loss})
if self.embedding_reg_weight > 0:
loss_embedding_reg = self.embedding_manager.embedding_to_coarse_loss().mean()
loss_dict.update({f'{prefix}/loss_emb_reg': loss_embedding_reg})
loss += (self.embedding_reg_weight * loss_embedding_reg)
loss_dict.update({f'{prefix}/loss': loss})
return loss, loss_dict
def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
@ -1250,11 +1295,10 @@ class LatentDiffusion(DDPM):
return samples, intermediates
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
plot_diffusion_rows=True, **kwargs):
quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
plot_diffusion_rows=False, **kwargs):
use_ddim = ddim_steps is not None
@ -1312,6 +1356,16 @@ class LatentDiffusion(DDPM):
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid
uc = self.get_learned_conditioning(len(c) * [""])
sample_scaled, _ = self.sample_log(cond=c,
batch_size=N,
ddim=use_ddim,
ddim_steps=ddim_steps,
eta=ddim_eta,
unconditional_guidance_scale=5.0,
unconditional_conditioning=uc)
log["samples_scaled"] = self.decode_first_stage(sample_scaled)
if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
self.first_stage_model, IdentityFirstStage):
@ -1364,13 +1418,18 @@ class LatentDiffusion(DDPM):
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.model.parameters())
if self.cond_stage_trainable:
print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
params = params + list(self.cond_stage_model.parameters())
if self.learn_logvar:
print('Diffusion model optimizing logvar')
params.append(self.logvar)
if self.embedding_manager is not None:
params = list(self.embedding_manager.embedding_parameters())
# params = list(self.cond_stage_model.transformer.text_model.embeddings.embedding_manager.embedding_parameters())
else:
params = list(self.model.parameters())
if self.cond_stage_trainable:
print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
params = params + list(self.cond_stage_model.parameters())
if self.learn_logvar:
print('Diffusion model optimizing logvar')
params.append(self.logvar)
opt = torch.optim.AdamW(params, lr=lr)
if self.use_scheduler:
assert 'target' in self.scheduler_config
@ -1395,6 +1454,18 @@ class LatentDiffusion(DDPM):
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
return x
@rank_zero_only
def on_save_checkpoint(self, checkpoint):
checkpoint.clear()
if os.path.isdir(self.trainer.checkpoint_callback.dirpath):
self.embedding_manager.save(os.path.join(self.trainer.checkpoint_callback.dirpath, "embeddings.pt"))
if (self.global_step - self.emb_ckpt_counter) > 500:
self.embedding_manager.save(os.path.join(self.trainer.checkpoint_callback.dirpath, f"embeddings_gs-{self.global_step}.pt"))
self.emb_ckpt_counter += 500
class DiffusionWrapper(pl.LightningModule):
def __init__(self, diff_model_config, conditioning_key):