# InvokeAI/invokeai/backend/stable_diffusion/diffusion/classifier.py

import os
from copy import deepcopy
from glob import glob
import pytorch_lightning as pl
import torch
from einops import rearrange
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
from ldm.util import default, instantiate_from_config, ismap, log_txt_as_img
from natsort import natsorted
from omegaconf import OmegaConf
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
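# classifier backbones keyed by label type: image-level class labels use the
# pooled encoder UNet, per-pixel segmentation maps use the full UNet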
__models__ = {"class_label": EncoderUNetModel, "segmentation": UNetModel}
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class NoisyLatentImageClassifier(pl.LightningModule):
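    """Trains a classifier on noisy latents produced by a frozen latent-diffusion
    model, for either image-level class labels or per-pixel segmentation maps."""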
def __init__(
self,
diffusion_path,
num_classes,
ckpt_path=None,
pool="attention",
label_key=None,
diffusion_ckpt_path=None,
scheduler_config=None,
weight_decay=1.0e-2,
log_steps=10,
monitor="val/loss",
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.num_classes = num_classes
# get latest config of diffusion model
diffusion_config = natsorted(
glob(os.path.join(diffusion_path, "configs", "*-project.yaml"))
)[-1]
self.diffusion_config = OmegaConf.load(diffusion_config).model
self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
self.load_diffusion()
self.monitor = monitor
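        # numd = number of downsamplings in the first-stage encoder (segmentation
        # targets are shrunk by 2**numd to match the latent resolution)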
self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
self.log_steps = log_steps
self.label_key = (
label_key
if not hasattr(self.diffusion_model, "cond_stage_key")
else self.diffusion_model.cond_stage_key
)
assert (
self.label_key is not None
), "label_key neither in diffusion model nor in model.params"
if self.label_key not in __models__:
raise NotImplementedError()
self.load_classifier(ckpt_path, pool)
self.scheduler_config = scheduler_config
self.use_scheduler = self.scheduler_config is not None
self.weight_decay = weight_decay
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
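        """Load a state dict from `path`, dropping keys that start with any prefix in
        `ignore_keys`; with only_model=True, load into the classifier head only."""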
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = (
self.load_state_dict(sd, strict=False)
if not only_model
else self.model.load_state_dict(sd, strict=False)
)
print(
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
def load_diffusion(self):
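        """Instantiate the diffusion model from its config, put it in eval mode,
        freeze its weights, and disable train-mode switching."""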
model = instantiate_from_config(self.diffusion_config)
self.diffusion_model = model.eval()
self.diffusion_model.train = disabled_train
for param in self.diffusion_model.parameters():
param.requires_grad = False
def load_classifier(self, ckpt_path, pool):
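        """Build the classifier from the diffusion UNet config: it takes the UNet's
        output channels as input and produces num_classes output channels."""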
model_config = deepcopy(self.diffusion_config.params.unet_config.params)
model_config.in_channels = (
self.diffusion_config.params.unet_config.params.out_channels
)
model_config.out_channels = self.num_classes
if self.label_key == "class_label":
model_config.pool = pool
self.model = __models__[self.label_key](**model_config)
if ckpt_path is not None:
print(
"#####################################################################"
)
print(f'load from ckpt "{ckpt_path}"')
print(
"#####################################################################"
)
self.init_from_ckpt(ckpt_path)
@torch.no_grad()
def get_x_noisy(self, x, t, noise=None):
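        """Apply the frozen diffusion model's forward process (q_sample) to x at timestep t."""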
noise = default(noise, lambda: torch.randn_like(x))
continuous_sqrt_alpha_cumprod = None
if self.diffusion_model.use_continuous_noise:
continuous_sqrt_alpha_cumprod = (
self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
)
# todo: make sure t+1 is correct here
return self.diffusion_model.q_sample(
x_start=x,
t=t,
noise=noise,
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod,
)
def forward(self, x_noisy, t, *args, **kwargs):
return self.model(x_noisy, t)
@torch.no_grad()
def get_input(self, batch, k):
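        """Fetch an image batch and convert it to contiguous NCHW float format."""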
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = rearrange(x, "b h w c -> b c h w")
x = x.to(memory_format=torch.contiguous_format).float()
return x
@torch.no_grad()
def get_conditioning(self, batch, k=None):
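        """Fetch targets for the label key; segmentation maps are moved to NCHW and
        downsampled to the latent resolution."""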
if k is None:
k = self.label_key
assert k is not None, "Needs to provide label key"
targets = batch[k].to(self.device)
if self.label_key == "segmentation":
targets = rearrange(targets, "b h w c -> b c h w")
for down in range(self.numd):
h, w = targets.shape[-2:]
targets = F.interpolate(targets, size=(h // 2, w // 2), mode="nearest")
# targets = rearrange(targets,'b c h w -> b h w c')
return targets
def compute_top_k(self, logits, labels, k, reduction="mean"):
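        """Top-k accuracy: how often the label appears among the k largest logits
        (averaged over the batch, or per-sample with reduction="none")."""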
_, top_ks = torch.topk(logits, k, dim=1)
if reduction == "mean":
return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
elif reduction == "none":
return (top_ks == labels[:, None]).float().sum(dim=-1)
def on_train_epoch_start(self):
# save some memory
self.diffusion_model.model.to("cpu")
@torch.no_grad()
def write_logs(self, loss, logits, targets):
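        """Log loss and top-1/top-5 accuracy under a train/ or val/ prefix, plus the
        global step and current learning rate for the progress bar."""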
log_prefix = "train" if self.training else "val"
log = {}
log[f"{log_prefix}/loss"] = loss.mean()
log[f"{log_prefix}/acc@1"] = self.compute_top_k(
logits, targets, k=1, reduction="mean"
)
log[f"{log_prefix}/acc@5"] = self.compute_top_k(
logits, targets, k=5, reduction="mean"
)
self.log_dict(
log,
prog_bar=False,
logger=True,
on_step=self.training,
on_epoch=True,
)
self.log("loss", log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
self.log(
"global_step",
self.global_step,
logger=False,
on_epoch=False,
prog_bar=True,
)
lr = self.optimizers().param_groups[0]["lr"]
self.log(
"lr_abs",
lr,
on_step=True,
logger=True,
on_epoch=False,
prog_bar=True,
)
def shared_step(self, batch, t=None):
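        """Encode the batch to latents, draw (or use the given) timestep t, noise the
        latents, classify them, and return the cross-entropy loss and intermediates."""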
x, *_ = self.diffusion_model.get_input(
batch, k=self.diffusion_model.first_stage_key
)
targets = self.get_conditioning(batch)
if targets.dim() == 4:
targets = targets.argmax(dim=1)
if t is None:
t = torch.randint(
0,
self.diffusion_model.num_timesteps,
(x.shape[0],),
device=self.device,
).long()
else:
t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
x_noisy = self.get_x_noisy(x, t)
logits = self(x_noisy, t)
loss = F.cross_entropy(logits, targets, reduction="none")
self.write_logs(loss.detach(), logits.detach(), targets.detach())
loss = loss.mean()
return loss, logits, x_noisy, targets
def training_step(self, batch, batch_idx):
loss, *_ = self.shared_step(batch)
return loss
def reset_noise_accs(self):
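        """Reset the per-timestep accuracy buffers, one entry per logged noise level."""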
self.noisy_acc = {
t: {"acc@1": [], "acc@5": []}
for t in range(
0,
self.diffusion_model.num_timesteps,
self.diffusion_model.log_every_t,
)
}
def on_validation_start(self):
self.reset_noise_accs()
@torch.no_grad()
def validation_step(self, batch, batch_idx):
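        """Standard validation loss, plus top-1/top-5 accuracy tracked separately at
        each logged noise level."""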
loss, *_ = self.shared_step(batch)
for t in self.noisy_acc:
_, logits, _, targets = self.shared_step(batch, t)
self.noisy_acc[t]["acc@1"].append(
self.compute_top_k(logits, targets, k=1, reduction="mean")
)
self.noisy_acc[t]["acc@5"].append(
self.compute_top_k(logits, targets, k=5, reduction="mean")
)
return loss
def configure_optimizers(self):
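        """AdamW over the classifier parameters only, optionally wrapped in a
        per-step LambdaLR schedule."""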
optimizer = AdamW(
self.model.parameters(),
lr=self.learning_rate,
weight_decay=self.weight_decay,
)
if self.use_scheduler:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{
"scheduler": LambdaLR(optimizer, lr_lambda=scheduler.schedule),
"interval": "step",
"frequency": 1,
}
]
return [optimizer], scheduler
return optimizer
@torch.no_grad()
def log_images(self, batch, N=8, *args, **kwargs):
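        """Collect inputs, rendered labels, and (for log_steps evenly spaced timesteps)
        the noisy latents together with the classifier's predictions as RGB maps."""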
log = dict()
x = self.get_input(batch, self.diffusion_model.first_stage_key)
log["inputs"] = x
y = self.get_conditioning(batch)
if self.label_key == "class_label":
y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
log["labels"] = y
if ismap(y):
log["labels"] = self.diffusion_model.to_rgb(y)
for step in range(self.log_steps):
current_time = step * self.log_time_interval
_, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
log[f"inputs@t{current_time}"] = x_noisy
pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
pred = rearrange(pred, "b h w c -> b c h w")
log[f"pred@t{current_time}"] = self.diffusion_model.to_rgb(pred)
for key in log:
log[key] = log[key][:N]
return log