Add personalization
@@ -7,7 +7,9 @@ https://github.com/CompVis/taming-transformers
 """
 
 import torch
+
 import torch.nn as nn
+import os
 import numpy as np
 import pytorch_lightning as pl
 from torch.optim.lr_scheduler import LambdaLR
@@ -64,6 +66,7 @@ class DDPM(pl.LightningModule):
                  cosine_s=8e-3,
                  given_betas=None,
                  original_elbo_weight=0.,
+                 embedding_reg_weight=0.,
                  v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
                  l_simple_weight=1.,
                  conditioning_key=None,
@@ -98,6 +101,7 @@ class DDPM(pl.LightningModule):
         self.v_posterior = v_posterior
         self.original_elbo_weight = original_elbo_weight
         self.l_simple_weight = l_simple_weight
+        self.embedding_reg_weight = embedding_reg_weight
 
         if monitor is not None:
             self.monitor = monitor
@@ -427,6 +431,7 @@ class LatentDiffusion(DDPM):
     def __init__(self,
                  first_stage_config,
                  cond_stage_config,
+                 personalization_config,
                  num_timesteps_cond=None,
                  cond_stage_key="image",
                  cond_stage_trainable=False,
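
The new personalization_config argument is plumbed through to instantiate_embedding_manager further down. It follows the usual ldm target/params config shape; a hypothetical Python rendering is below (the real values live in the training YAML, and every field other than "target", "params", and "embedding_manager_ckpt" is an illustrative assumption):

    from omegaconf import OmegaConf

    # Hypothetical shape of personalization_config. Only "target", "params",
    # and params.embedding_manager_ckpt are implied by this diff; the other
    # fields are illustrative placeholders.
    personalization_config = OmegaConf.create({
        "target": "ldm.modules.embedding_manager.EmbeddingManager",
        "params": {
            "placeholder_strings": ["*"],        # prompt token(s) to personalize
            "initializer_words": ["sculpture"],  # coarse words seeding the vectors
            "embedding_manager_ckpt": "",        # empty string => train from scratch
        },
    })
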
@@ -436,6 +441,7 @@ class LatentDiffusion(DDPM):
                  scale_factor=1.0,
                  scale_by_std=False,
                  *args, **kwargs):
+
         self.num_timesteps_cond = default(num_timesteps_cond, 1)
         self.scale_by_std = scale_by_std
         assert self.num_timesteps_cond <= kwargs['timesteps']
@@ -450,6 +456,7 @@ class LatentDiffusion(DDPM):
         self.concat_mode = concat_mode
         self.cond_stage_trainable = cond_stage_trainable
         self.cond_stage_key = cond_stage_key
+
         try:
             self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
         except:
@@ -460,6 +467,7 @@ class LatentDiffusion(DDPM):
         self.register_buffer('scale_factor', torch.tensor(scale_factor))
         self.instantiate_first_stage(first_stage_config)
         self.instantiate_cond_stage(cond_stage_config)
+
         self.cond_stage_forward = cond_stage_forward
         self.clip_denoised = False
         self.bbox_tokenizer = None
@@ -469,6 +477,25 @@ class LatentDiffusion(DDPM):
             self.init_from_ckpt(ckpt_path, ignore_keys)
             self.restarted_from_ckpt = True
 
+        self.cond_stage_model.train = disabled_train
+        for param in self.cond_stage_model.parameters():
+            param.requires_grad = False
+
+        self.model.eval()
+        self.model.train = disabled_train
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+        self.embedding_manager = self.instantiate_embedding_manager(personalization_config, self.cond_stage_model)
+
+        self.emb_ckpt_counter = 0
+
+        # if self.embedding_manager.is_clip:
+        #     self.cond_stage_model.update_embedding_func(self.embedding_manager)
+
+        for param in self.embedding_manager.embedding_parameters():
+            param.requires_grad = True
+
     def make_cond_schedule(self, ):
         self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
         ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
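
These constructor changes are the core of the personalization scheme: the conditioning model and the diffusion model itself are frozen (eval mode, `train` disabled, `requires_grad = False`), so only the embedding manager's vectors remain trainable. A minimal sketch of that freezing idiom, with a hypothetical `freeze` helper wrapping the inline code above:

    import torch.nn as nn

    def disabled_train(self, mode=True):
        # Replacement for nn.Module.train: keeps a frozen submodule in eval
        # mode even when the enclosing LightningModule flips to train().
        return self

    def freeze(module: nn.Module) -> nn.Module:
        # Hypothetical helper mirroring the inline pattern in the diff.
        module.eval()
        module.train = disabled_train.__get__(module)  # bind as a method
        for param in module.parameters():
            param.requires_grad = False  # excluded from autograd/optimizer
        return module
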
@@ -530,6 +557,15 @@ class LatentDiffusion(DDPM):
         except urllib.error.URLError:
             raise SystemExit("* Couldn't load a dependency. Try running scripts/preload_models.py from an internet-connected machine.")
         self.cond_stage_model = model
 
+
+    def instantiate_embedding_manager(self, config, embedder):
+        model = instantiate_from_config(config, embedder=embedder)
+
+        if config.params.get("embedding_manager_ckpt", None): # do not load if missing OR empty string
+            model.load(config.params.embedding_manager_ckpt)
+
+        return model
+
     def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
         denoise_row = []
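
instantiate_embedding_manager leans on ldm's config-driven factory, here extended to forward the text encoder as `embedder`. Roughly, `instantiate_from_config` resolves a dotted class path from the config and constructs it with the config's params; a simplified sketch of that pattern (not the exact ldm.util implementation):

    import importlib

    def instantiate_from_config(config, **extra_kwargs):
        # Resolve "pkg.module.Class" from config["target"], then construct it
        # with config["params"] plus caller-supplied kwargs (embedder=...).
        module_path, cls_name = config["target"].rsplit(".", 1)
        cls = getattr(importlib.import_module(module_path), cls_name)
        return cls(**config.get("params", dict()), **extra_kwargs)
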
@@ -555,7 +591,7 @@ class LatentDiffusion(DDPM):
     def get_learned_conditioning(self, c):
         if self.cond_stage_forward is None:
             if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
-                c = self.cond_stage_model.encode(c)
+                c = self.cond_stage_model.encode(c, embedding_manager=self.embedding_manager)
                 if isinstance(c, DiagonalGaussianDistribution):
                     c = c.mode()
             else:
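
Threading `embedding_manager` into `encode` is what lets the text encoder swap placeholder tokens for the learned vectors before the transformer runs. Conceptually it does something like the following (a hypothetical sketch, not the actual encoder code):

    import torch

    def inject_learned_embedding(token_ids, token_embs, placeholder_id, learned_vec):
        # Wherever the placeholder token id appears in the prompt, overwrite
        # its embedding with the learned vector; other tokens are untouched.
        mask = token_ids == placeholder_id          # (batch, seq_len) bool
        token_embs = token_embs.clone()
        token_embs[mask] = learned_vec.to(token_embs.dtype)
        return token_embs
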
@@ -880,6 +916,7 @@ class LatentDiffusion(DDPM):
             if self.shorten_cond_schedule: # TODO: drop this option
                 tc = self.cond_ids[t].to(self.device)
                 c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
+
         return self.p_losses(x, c, t, *args, **kwargs)
 
     def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
@@ -1046,6 +1083,14 @@ class LatentDiffusion(DDPM):
             loss += (self.original_elbo_weight * loss_vlb)
         loss_dict.update({f'{prefix}/loss': loss})
 
+        if self.embedding_reg_weight > 0:
+            loss_embedding_reg = self.embedding_manager.embedding_to_coarse_loss().mean()
+
+            loss_dict.update({f'{prefix}/loss_emb_reg': loss_embedding_reg})
+
+            loss += (self.embedding_reg_weight * loss_embedding_reg)
+            loss_dict.update({f'{prefix}/loss': loss})
+
         return loss, loss_dict
 
     def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
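
The new term adds `embedding_reg_weight * embedding_to_coarse_loss().mean()` on top of the diffusion loss, discouraging the learned vectors from drifting far from the embeddings of their coarse initializer words. A sketch of what such a coarse loss might compute (hypothetical; the real implementation lives in the embedding manager):

    import torch

    def embedding_to_coarse_loss_sketch(learned: torch.Tensor,
                                        coarse: torch.Tensor) -> torch.Tensor:
        # learned/coarse: (num_placeholders, embed_dim). Squared L2 distance
        # per placeholder; the caller takes .mean() and scales it by
        # embedding_reg_weight before adding it to the diffusion loss.
        return (learned - coarse).pow(2).sum(dim=-1)
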
@@ -1250,11 +1295,10 @@ class LatentDiffusion(DDPM):
 
         return samples, intermediates
 
-
     @torch.no_grad()
     def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
-                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
-                   plot_diffusion_rows=True, **kwargs):
+                   quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
+                   plot_diffusion_rows=False, **kwargs):
 
         use_ddim = ddim_steps is not None
 
@@ -1312,6 +1356,16 @@ class LatentDiffusion(DDPM):
             if plot_denoise_rows:
                 denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                 log["denoise_row"] = denoise_grid
+
+            uc = self.get_learned_conditioning(len(c) * [""])
+            sample_scaled, _ = self.sample_log(cond=c,
+                                               batch_size=N,
+                                               ddim=use_ddim,
+                                               ddim_steps=ddim_steps,
+                                               eta=ddim_eta,
+                                               unconditional_guidance_scale=5.0,
+                                               unconditional_conditioning=uc)
+            log["samples_scaled"] = self.decode_first_stage(sample_scaled)
 
         if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
                 self.first_stage_model, IdentityFirstStage):
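
log_images now also logs guidance-scaled samples: `uc` is the conditioning for an empty prompt, and `sample_log` runs classifier-free guidance at scale 5.0. The sampler combines the conditional and unconditional noise predictions roughly as follows (the standard CFG rule, a sketch rather than the sampler internals):

    import torch

    def cfg_combine(eps_uncond: torch.Tensor,
                    eps_cond: torch.Tensor,
                    scale: float = 5.0) -> torch.Tensor:
        # Extrapolate from the unconditional prediction toward the
        # conditional one; scale > 1 strengthens prompt adherence.
        return eps_uncond + scale * (eps_cond - eps_uncond)
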
@@ -1364,13 +1418,18 @@ class LatentDiffusion(DDPM):
 
     def configure_optimizers(self):
         lr = self.learning_rate
-        params = list(self.model.parameters())
-        if self.cond_stage_trainable:
-            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
-            params = params + list(self.cond_stage_model.parameters())
-        if self.learn_logvar:
-            print('Diffusion model optimizing logvar')
-            params.append(self.logvar)
+
+        if self.embedding_manager is not None:
+            params = list(self.embedding_manager.embedding_parameters())
+            # params = list(self.cond_stage_model.transformer.text_model.embeddings.embedding_manager.embedding_parameters())
+        else:
+            params = list(self.model.parameters())
+            if self.cond_stage_trainable:
+                print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+                params = params + list(self.cond_stage_model.parameters())
+            if self.learn_logvar:
+                print('Diffusion model optimizing logvar')
+                params.append(self.logvar)
         opt = torch.optim.AdamW(params, lr=lr)
         if self.use_scheduler:
             assert 'target' in self.scheduler_config
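
Whenever an embedding manager exists, configure_optimizers now hands AdamW only the embedding parameters, so the optimizer state is a few small tensors rather than the whole UNet; the full model still runs forward and backward, but no other weights move. A toy illustration of the resulting setup (dimensions and learning rate are illustrative, not values from this repo):

    import torch

    # One learnable token embedding (e.g. 768-dim for a CLIP text encoder);
    # everything else in the pipeline stays frozen.
    embedding = torch.nn.Parameter(torch.randn(1, 768) * 0.01)
    opt = torch.optim.AdamW([embedding], lr=5.0e-3)  # lr is illustrative
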
@@ -1395,6 +1454,18 @@ class LatentDiffusion(DDPM):
         x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
         return x
 
+    @rank_zero_only
+    def on_save_checkpoint(self, checkpoint):
+        checkpoint.clear()
+
+        if os.path.isdir(self.trainer.checkpoint_callback.dirpath):
+            self.embedding_manager.save(os.path.join(self.trainer.checkpoint_callback.dirpath, "embeddings.pt"))
+
+            if (self.global_step - self.emb_ckpt_counter) > 500:
+                self.embedding_manager.save(os.path.join(self.trainer.checkpoint_callback.dirpath, f"embeddings_gs-{self.global_step}.pt"))
+
+                self.emb_ckpt_counter += 500
+
 
 class DiffusionWrapper(pl.LightningModule):
     def __init__(self, diff_model_config, conditioning_key):
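
on_save_checkpoint empties the Lightning checkpoint dict, so the regular .ckpt file becomes a stub and the artifacts that matter are the embedding files: a rolling embeddings.pt plus a snapshot at most once per 500 steps, rate-limited by emb_ckpt_counter. The counter bookkeeping reduces to this predicate (a sketch of the logic above):

    def should_snapshot(global_step: int, last_snapshot: int, interval: int = 500) -> bool:
        # Snapshot once the run has advanced `interval` steps past the last
        # snapshot; the caller then moves last_snapshot forward by `interval`.
        return (global_step - last_snapshot) > interval
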