import os
import torch
import pytorch_lightning as pl
from omegaconf import OmegaConf
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from copy import deepcopy
from einops import rearrange
from glob import glob
from natsort import natsorted

from ldm.modules.diffusionmodules.openaimodel import (
    EncoderUNetModel,
    UNetModel,
)
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config

__models__ = {'class_label': EncoderUNetModel, 'segmentation': UNetModel}
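# The classifier head is selected by `label_key`: EncoderUNetModel pools the UNet
# features into a single image-level class prediction, while UNetModel keeps the
# spatial resolution and is used for per-pixel segmentation logits.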


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class NoisyLatentImageClassifier(pl.LightningModule):
    """Classifier trained on noisy latents of a frozen latent diffusion model.

    Depending on `label_key`, it predicts either an image-level class label or a
    per-pixel segmentation map from the noised first-stage latent at timestep t.
    """

    def __init__(
        self,
        diffusion_path,
        num_classes,
        ckpt_path=None,
        pool='attention',
        label_key=None,
        diffusion_ckpt_path=None,
        scheduler_config=None,
        weight_decay=1.0e-2,
        log_steps=10,
        monitor='val/loss',
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.num_classes = num_classes
        # get latest config of diffusion model
        diffusion_config = natsorted(
            glob(os.path.join(diffusion_path, 'configs', '*-project.yaml'))
        )[-1]
        self.diffusion_config = OmegaConf.load(diffusion_config).model
        self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
        self.load_diffusion()

        self.monitor = monitor
        self.numd = (
            self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
        )
        self.log_time_interval = (
            self.diffusion_model.num_timesteps // log_steps
        )
        self.log_steps = log_steps

        self.label_key = (
            label_key
            if not hasattr(self.diffusion_model, 'cond_stage_key')
            else self.diffusion_model.cond_stage_key
        )

        assert (
            self.label_key is not None
        ), 'label_key neither in diffusion model nor in model.params'

        if self.label_key not in __models__:
            raise NotImplementedError()

        self.load_classifier(ckpt_path, pool)

        self.scheduler_config = scheduler_config
        self.use_scheduler = self.scheduler_config is not None
        self.weight_decay = weight_decay
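
    # Note: `diffusion_path` is expected to be a Lightning log directory containing
    # `configs/*-project.yaml` (the glob in `__init__` above picks the newest file),
    # and `self.learning_rate` is not set here -- the training script is expected to
    # assign it before `configure_optimizers` is called.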

    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        sd = torch.load(path, map_location='cpu')
        if 'state_dict' in list(sd.keys()):
            sd = sd['state_dict']
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print('Deleting key {} from state_dict.'.format(k))
                    del sd[k]
        missing, unexpected = (
            self.load_state_dict(sd, strict=False)
            if not only_model
            else self.model.load_state_dict(sd, strict=False)
        )
        print(
            f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
        )
        if len(missing) > 0:
            print(f'Missing Keys: {missing}')
        if len(unexpected) > 0:
            print(f'Unexpected Keys: {unexpected}')

    def load_diffusion(self):
        model = instantiate_from_config(self.diffusion_config)
        self.diffusion_model = model.eval()
        self.diffusion_model.train = disabled_train
        for param in self.diffusion_model.parameters():
            param.requires_grad = False

    def load_classifier(self, ckpt_path, pool):
        model_config = deepcopy(
            self.diffusion_config.params.unet_config.params
        )
        model_config.in_channels = (
            self.diffusion_config.params.unet_config.params.out_channels
        )
        model_config.out_channels = self.num_classes
        if self.label_key == 'class_label':
            model_config.pool = pool

        self.model = __models__[self.label_key](**model_config)
        if ckpt_path is not None:
            print(
                '#####################################################################'
            )
            print(f'load from ckpt "{ckpt_path}"')
            print(
                '#####################################################################'
            )
            self.init_from_ckpt(ckpt_path)

    @torch.no_grad()
    def get_x_noisy(self, x, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x))
        continuous_sqrt_alpha_cumprod = None
        if self.diffusion_model.use_continuous_noise:
            continuous_sqrt_alpha_cumprod = (
                self.diffusion_model.sample_continuous_noise_level(
                    x.shape[0], t + 1
                )
            )
            # todo: make sure t+1 is correct here

        return self.diffusion_model.q_sample(
            x_start=x,
            t=t,
            noise=noise,
            continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod,
        )
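
    # In the usual DDPM formulation, `q_sample` above draws the noisy latent as
    #     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
    # with alpha_bar_t optionally drawn at a continuous noise level when
    # `use_continuous_noise` is enabled.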

    def forward(self, x_noisy, t, *args, **kwargs):
        return self.model(x_noisy, t)

    @torch.no_grad()
    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = rearrange(x, 'b h w c -> b c h w')
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    @torch.no_grad()
    def get_conditioning(self, batch, k=None):
        if k is None:
            k = self.label_key
        assert k is not None, 'Needs to provide label key'

        targets = batch[k].to(self.device)

        if self.label_key == 'segmentation':
            targets = rearrange(targets, 'b h w c -> b c h w')
            for down in range(self.numd):
                h, w = targets.shape[-2:]
                targets = F.interpolate(
                    targets, size=(h // 2, w // 2), mode='nearest'
                )

            # targets = rearrange(targets,'b c h w -> b h w c')

        return targets

    def compute_top_k(self, logits, labels, k, reduction='mean'):
        _, top_ks = torch.topk(logits, k, dim=1)
        if reduction == 'mean':
            return (
                (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
            )
        elif reduction == 'none':
            return (top_ks == labels[:, None]).float().sum(dim=-1)
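
    # Worked example for `compute_top_k` (2 samples, 3 classes):
    #   logits = torch.tensor([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]])
    #   labels = torch.tensor([1, 2])
    #   top-1 indices -> [[1], [0]]; hits -> [1.0, 0.0]; mean top-1 accuracy = 0.5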

    def on_train_epoch_start(self):
        # save some memory
        self.diffusion_model.model.to('cpu')

    @torch.no_grad()
    def write_logs(self, loss, logits, targets):
        log_prefix = 'train' if self.training else 'val'
        log = {}
        log[f'{log_prefix}/loss'] = loss.mean()
        log[f'{log_prefix}/acc@1'] = self.compute_top_k(
            logits, targets, k=1, reduction='mean'
        )
        log[f'{log_prefix}/acc@5'] = self.compute_top_k(
            logits, targets, k=5, reduction='mean'
        )

        self.log_dict(
            log,
            prog_bar=False,
            logger=True,
            on_step=self.training,
            on_epoch=True,
        )
        self.log(
            'loss', log[f'{log_prefix}/loss'], prog_bar=True, logger=False
        )
        self.log(
            'global_step',
            self.global_step,
            logger=False,
            on_epoch=False,
            prog_bar=True,
        )
        lr = self.optimizers().param_groups[0]['lr']
        self.log(
            'lr_abs',
            lr,
            on_step=True,
            logger=True,
            on_epoch=False,
            prog_bar=True,
        )

    def shared_step(self, batch, t=None):
        x, *_ = self.diffusion_model.get_input(
            batch, k=self.diffusion_model.first_stage_key
        )
        targets = self.get_conditioning(batch)
        if targets.dim() == 4:
            targets = targets.argmax(dim=1)
        if t is None:
            t = torch.randint(
                0,
                self.diffusion_model.num_timesteps,
                (x.shape[0],),
                device=self.device,
            ).long()
        else:
            t = torch.full(
                size=(x.shape[0],), fill_value=t, device=self.device
            ).long()
        x_noisy = self.get_x_noisy(x, t)
        logits = self(x_noisy, t)

        loss = F.cross_entropy(logits, targets, reduction='none')

        self.write_logs(loss.detach(), logits.detach(), targets.detach())

        loss = loss.mean()
        return loss, logits, x_noisy, targets
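
    # `shared_step` above is the core objective: encode the batch with the frozen
    # diffusion model, pick a random (or fixed) timestep t, noise the latent with
    # `get_x_noisy`, and train the classifier with cross-entropy on the result.
    # Four-dimensional targets are one-hot segmentation maps and are reduced to
    # index maps via argmax, so F.cross_entropy receives class indices in both modes.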

    def training_step(self, batch, batch_idx):
        loss, *_ = self.shared_step(batch)
        return loss

    def reset_noise_accs(self):
        self.noisy_acc = {
            t: {'acc@1': [], 'acc@5': []}
            for t in range(
                0,
                self.diffusion_model.num_timesteps,
                self.diffusion_model.log_every_t,
            )
        }

    def on_validation_start(self):
        self.reset_noise_accs()

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        loss, *_ = self.shared_step(batch)

        for t in self.noisy_acc:
            _, logits, _, targets = self.shared_step(batch, t)
            self.noisy_acc[t]['acc@1'].append(
                self.compute_top_k(logits, targets, k=1, reduction='mean')
            )
            self.noisy_acc[t]['acc@5'].append(
                self.compute_top_k(logits, targets, k=5, reduction='mean')
            )

        return loss

    def configure_optimizers(self):
        optimizer = AdamW(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )

        if self.use_scheduler:
            scheduler = instantiate_from_config(self.scheduler_config)

            print('Setting up LambdaLR scheduler...')
            scheduler = [
                {
                    'scheduler': LambdaLR(
                        optimizer, lr_lambda=scheduler.schedule
                    ),
                    'interval': 'step',
                    'frequency': 1,
                }
            ]
            return [optimizer], scheduler

        return optimizer

    @torch.no_grad()
    def log_images(self, batch, N=8, *args, **kwargs):
        log = dict()
        x = self.get_input(batch, self.diffusion_model.first_stage_key)
        log['inputs'] = x

        y = self.get_conditioning(batch)

        if self.label_key == 'class_label':
            y = log_txt_as_img((x.shape[2], x.shape[3]), batch['human_label'])
            log['labels'] = y

        if ismap(y):
            log['labels'] = self.diffusion_model.to_rgb(y)

            for step in range(self.log_steps):
                current_time = step * self.log_time_interval

                _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)

                log[f'inputs@t{current_time}'] = x_noisy

                pred = F.one_hot(
                    logits.argmax(dim=1), num_classes=self.num_classes
                )
                pred = rearrange(pred, 'b h w c -> b c h w')

                log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(
                    pred
                )

        for key in log:
            log[key] = log[key][:N]

        return log
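

# Hypothetical usage sketch (not part of the original module): the paths and class
# count below are placeholders, and the learning rate must be attached to the module
# because `configure_optimizers` reads `self.learning_rate`.
#
#     model = NoisyLatentImageClassifier(
#         diffusion_path='logs/2022-01-01T00-00-00_my-ldm-run',  # placeholder log dir
#         diffusion_ckpt_path='logs/2022-01-01T00-00-00_my-ldm-run/checkpoints/last.ckpt',
#         num_classes=1000,
#         label_key='class_label',
#     )
#     model.learning_rate = 1.0e-4
#     trainer = pl.Trainer(gpus=1, max_epochs=10)
#     trainer.fit(model, train_dataloader, val_dataloader)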