remove legacy ldm code

This commit is contained in:
Kevin Turner
2023-03-04 18:16:59 -08:00
parent c4e6d4b348
commit c703b60986
33 changed files with 24 additions and 9957 deletions

View File

@@ -1,330 +0,0 @@
import os
from copy import deepcopy
from glob import glob
import pytorch_lightning as pl
import torch
from einops import rearrange
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
from ldm.util import default, instantiate_from_config, ismap, log_txt_as_img
from natsort import natsorted
from omegaconf import OmegaConf
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
__models__ = {"class_label": EncoderUNetModel, "segmentation": UNetModel}
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class NoisyLatentImageClassifier(pl.LightningModule):
def __init__(
self,
diffusion_path,
num_classes,
ckpt_path=None,
pool="attention",
label_key=None,
diffusion_ckpt_path=None,
scheduler_config=None,
weight_decay=1.0e-2,
log_steps=10,
monitor="val/loss",
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.num_classes = num_classes
# get latest config of diffusion model
diffusion_config = natsorted(
glob(os.path.join(diffusion_path, "configs", "*-project.yaml"))
)[-1]
self.diffusion_config = OmegaConf.load(diffusion_config).model
self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
self.load_diffusion()
self.monitor = monitor
self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
self.log_steps = log_steps
self.label_key = (
label_key
if not hasattr(self.diffusion_model, "cond_stage_key")
else self.diffusion_model.cond_stage_key
)
assert (
self.label_key is not None
), "label_key neither in diffusion model nor in model.params"
if self.label_key not in __models__:
raise NotImplementedError()
self.load_classifier(ckpt_path, pool)
self.scheduler_config = scheduler_config
self.use_scheduler = self.scheduler_config is not None
self.weight_decay = weight_decay
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = (
self.load_state_dict(sd, strict=False)
if not only_model
else self.model.load_state_dict(sd, strict=False)
)
print(
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
def load_diffusion(self):
model = instantiate_from_config(self.diffusion_config)
self.diffusion_model = model.eval()
self.diffusion_model.train = disabled_train
for param in self.diffusion_model.parameters():
param.requires_grad = False
def load_classifier(self, ckpt_path, pool):
model_config = deepcopy(self.diffusion_config.params.unet_config.params)
model_config.in_channels = (
self.diffusion_config.params.unet_config.params.out_channels
)
model_config.out_channels = self.num_classes
if self.label_key == "class_label":
model_config.pool = pool
self.model = __models__[self.label_key](**model_config)
if ckpt_path is not None:
print(
"#####################################################################"
)
print(f'load from ckpt "{ckpt_path}"')
print(
"#####################################################################"
)
self.init_from_ckpt(ckpt_path)
@torch.no_grad()
def get_x_noisy(self, x, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x))
continuous_sqrt_alpha_cumprod = None
if self.diffusion_model.use_continuous_noise:
continuous_sqrt_alpha_cumprod = (
self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
)
# todo: make sure t+1 is correct here
return self.diffusion_model.q_sample(
x_start=x,
t=t,
noise=noise,
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod,
)
def forward(self, x_noisy, t, *args, **kwargs):
return self.model(x_noisy, t)
@torch.no_grad()
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = rearrange(x, "b h w c -> b c h w")
x = x.to(memory_format=torch.contiguous_format).float()
return x
@torch.no_grad()
def get_conditioning(self, batch, k=None):
if k is None:
k = self.label_key
assert k is not None, "Needs to provide label key"
targets = batch[k].to(self.device)
if self.label_key == "segmentation":
targets = rearrange(targets, "b h w c -> b c h w")
for down in range(self.numd):
h, w = targets.shape[-2:]
targets = F.interpolate(targets, size=(h // 2, w // 2), mode="nearest")
# targets = rearrange(targets,'b c h w -> b h w c')
return targets
def compute_top_k(self, logits, labels, k, reduction="mean"):
_, top_ks = torch.topk(logits, k, dim=1)
if reduction == "mean":
return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
elif reduction == "none":
return (top_ks == labels[:, None]).float().sum(dim=-1)
def on_train_epoch_start(self):
# save some memory
self.diffusion_model.model.to("cpu")
@torch.no_grad()
def write_logs(self, loss, logits, targets):
log_prefix = "train" if self.training else "val"
log = {}
log[f"{log_prefix}/loss"] = loss.mean()
log[f"{log_prefix}/acc@1"] = self.compute_top_k(
logits, targets, k=1, reduction="mean"
)
log[f"{log_prefix}/acc@5"] = self.compute_top_k(
logits, targets, k=5, reduction="mean"
)
self.log_dict(
log,
prog_bar=False,
logger=True,
on_step=self.training,
on_epoch=True,
)
self.log("loss", log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
self.log(
"global_step",
self.global_step,
logger=False,
on_epoch=False,
prog_bar=True,
)
lr = self.optimizers().param_groups[0]["lr"]
self.log(
"lr_abs",
lr,
on_step=True,
logger=True,
on_epoch=False,
prog_bar=True,
)
def shared_step(self, batch, t=None):
x, *_ = self.diffusion_model.get_input(
batch, k=self.diffusion_model.first_stage_key
)
targets = self.get_conditioning(batch)
if targets.dim() == 4:
targets = targets.argmax(dim=1)
if t is None:
t = torch.randint(
0,
self.diffusion_model.num_timesteps,
(x.shape[0],),
device=self.device,
).long()
else:
t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
x_noisy = self.get_x_noisy(x, t)
logits = self(x_noisy, t)
loss = F.cross_entropy(logits, targets, reduction="none")
self.write_logs(loss.detach(), logits.detach(), targets.detach())
loss = loss.mean()
return loss, logits, x_noisy, targets
def training_step(self, batch, batch_idx):
loss, *_ = self.shared_step(batch)
return loss
def reset_noise_accs(self):
self.noisy_acc = {
t: {"acc@1": [], "acc@5": []}
for t in range(
0,
self.diffusion_model.num_timesteps,
self.diffusion_model.log_every_t,
)
}
def on_validation_start(self):
self.reset_noise_accs()
@torch.no_grad()
def validation_step(self, batch, batch_idx):
loss, *_ = self.shared_step(batch)
for t in self.noisy_acc:
_, logits, _, targets = self.shared_step(batch, t)
self.noisy_acc[t]["acc@1"].append(
self.compute_top_k(logits, targets, k=1, reduction="mean")
)
self.noisy_acc[t]["acc@5"].append(
self.compute_top_k(logits, targets, k=5, reduction="mean")
)
return loss
def configure_optimizers(self):
optimizer = AdamW(
self.model.parameters(),
lr=self.learning_rate,
weight_decay=self.weight_decay,
)
if self.use_scheduler:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{
"scheduler": LambdaLR(optimizer, lr_lambda=scheduler.schedule),
"interval": "step",
"frequency": 1,
}
]
return [optimizer], scheduler
return optimizer
@torch.no_grad()
def log_images(self, batch, N=8, *args, **kwargs):
log = dict()
x = self.get_input(batch, self.diffusion_model.first_stage_key)
log["inputs"] = x
y = self.get_conditioning(batch)
if self.label_key == "class_label":
y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
log["labels"] = y
if ismap(y):
log["labels"] = self.diffusion_model.to_rgb(y)
for step in range(self.log_steps):
current_time = step * self.log_time_interval
_, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
log[f"inputs@t{current_time}"] = x_noisy
pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
pred = rearrange(pred, "b h w c -> b c h w")
log[f"pred@t{current_time}"] = self.diffusion_model.to_rgb(pred)
for key in log:
log[key] = log[key][:N]
return log
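
Reviewer note: the noising that get_x_noisy() delegates to through q_sample() is the standard closed-form forward process; in the usual DDPM notation (symbols below are conventional, not names from this file),

    x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)

which is also what stochastic_encode() in the base Sampler class further down computes.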

View File

@@ -1,113 +0,0 @@
"""SAMPLING ONLY."""
import torch
from ..diffusionmodules.util import noise_like
from .sampler import Sampler
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
class DDIMSampler(Sampler):
def __init__(self, model, schedule="linear", device=None, **kwargs):
super().__init__(model, schedule, model.num_timesteps, device)
self.invokeai_diffuser = InvokeAIDiffuserComponent(
self.model,
model_forward_callback=lambda x, sigma, cond: self.model.apply_model(
x, sigma, cond
),
)
def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs)
extra_conditioning_info = kwargs.get("extra_conditioning_info", None)
all_timesteps_count = kwargs.get("all_timesteps_count", t_enc)
if (
extra_conditioning_info is not None
and extra_conditioning_info.wants_cross_attention_control
):
self.invokeai_diffuser.override_cross_attention(
extra_conditioning_info, step_count=all_timesteps_count
)
else:
self.invokeai_diffuser.restore_default_cross_attention()
# This is the central routine
@torch.no_grad()
def p_sample(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
step_count: int = 1000, # total number of steps
**kwargs,
):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
# damian0815 would like to know when/if this code path is used
e_t = self.model.apply_model(x, t, c)
else:
# step_index counts in the opposite direction to index
step_index = step_count - (index + 1)
e_t = self.invokeai_diffuser.do_diffusion_step(
x,
t,
unconditional_conditioning,
c,
unconditional_guidance_scale,
step_index=step_index,
)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0, None
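
Reviewer note: the last few lines of p_sample() above are the standard DDIM update, written here in LaTeX with \bar\alpha denoting the cumulative alphas the code calls a_t and a_prev:

    \hat{x}_0 = \frac{x_t - \sqrt{1 - \bar\alpha_t}\, \epsilon_\theta}{\sqrt{\bar\alpha_t}}, \qquad
    x_{t-1} = \sqrt{\bar\alpha_{t-1}}\, \hat{x}_0 + \sqrt{1 - \bar\alpha_{t-1} - \sigma_t^2}\, \epsilon_\theta + \sigma_t z

where z is the (optionally repeated, dropout-masked) noise and the temperature is applied as a scalar on the noise term.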

File diff suppressed because it is too large

View File

@@ -1,339 +0,0 @@
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
import k_diffusion as K
import torch
from torch import nn
from .cross_attention_map_saving import AttentionMapSaver
from .sampler import Sampler
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
# at or above this step count, the sampler stops using the Karras
# noise schedule and uses the model's own schedule instead
STEP_THRESHOLD = 30
def cfg_apply_threshold(result, threshold=0.0, scale=0.7):
if threshold <= 0.0:
return result
maxval = 0.0 + torch.max(result).cpu().numpy()
minval = 0.0 + torch.min(result).cpu().numpy()
if maxval < threshold and minval > -threshold:
return result
if maxval > threshold:
maxval = min(max(1, scale * maxval), threshold)
if minval < -threshold:
minval = max(min(-1, scale * minval), -threshold)
return torch.clamp(result, min=minval, max=maxval)
class CFGDenoiser(nn.Module):
def __init__(self, model, threshold=0, warmup=0):
super().__init__()
self.inner_model = model
self.threshold = threshold
self.warmup_max = warmup
self.warmup = max(warmup / 10, 1)
self.invokeai_diffuser = InvokeAIDiffuserComponent(
model,
model_forward_callback=lambda x, sigma, cond: self.inner_model(
x, sigma, cond=cond
),
)
def prepare_to_sample(self, t_enc, **kwargs):
extra_conditioning_info = kwargs.get("extra_conditioning_info", None)
if (
extra_conditioning_info is not None
and extra_conditioning_info.wants_cross_attention_control
):
self.invokeai_diffuser.override_cross_attention(
extra_conditioning_info, step_count=t_enc
)
else:
self.invokeai_diffuser.restore_default_cross_attention()
def forward(self, x, sigma, uncond, cond, cond_scale):
next_x = self.invokeai_diffuser.do_diffusion_step(
x, sigma, uncond, cond, cond_scale
)
if self.warmup < self.warmup_max:
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
self.warmup += 1
else:
thresh = self.threshold
if thresh > self.threshold:
thresh = self.threshold
return cfg_apply_threshold(next_x, thresh)
class KSampler(Sampler):
def __init__(self, model, schedule="lms", device=None, **kwargs):
denoiser = K.external.CompVisDenoiser(model)
super().__init__(
denoiser,
schedule,
steps=model.num_timesteps,
)
self.sigmas = None
self.ds = None
self.s_in = None
self.karras_max = kwargs.get("karras_max", STEP_THRESHOLD)
if self.karras_max is None:
self.karras_max = STEP_THRESHOLD
def make_schedule(
self,
ddim_num_steps,
ddim_discretize="uniform",
ddim_eta=0.0,
verbose=False,
):
outer_model = self.model
self.model = outer_model.inner_model
super().make_schedule(
ddim_num_steps,
ddim_discretize="uniform",
ddim_eta=0.0,
verbose=False,
)
self.model = outer_model
self.ddim_num_steps = ddim_num_steps
# we don't need both of these sigmas, but storing them here to make
# comparison easier later on
self.model_sigmas = self.model.get_sigmas(ddim_num_steps)
self.karras_sigmas = K.sampling.get_sigmas_karras(
n=ddim_num_steps,
sigma_min=self.model.sigmas[0].item(),
sigma_max=self.model.sigmas[-1].item(),
rho=7.0,
device=self.device,
)
if ddim_num_steps >= self.karras_max:
print(
f">> Ksampler using model noise schedule (steps >= {self.karras_max})"
)
self.sigmas = self.model_sigmas
else:
print(
f">> Ksampler using karras noise schedule (steps < {self.karras_max})"
)
self.sigmas = self.karras_sigmas
# ALERT: We are completely overriding the sample() method in the base class, which
# means that inpainting will not work. To get this to work we need to be able to
# modify the inner loop of k_heun, k_lms, etc, as is done in an ugly way
# in the lstein/k-diffusion branch.
@torch.no_grad()
def decode(
self,
z_enc,
cond,
t_enc,
img_callback=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
init_latent=None,
mask=None,
**kwargs,
):
samples, _ = self.sample(
batch_size=1,
S=t_enc,
x_T=z_enc,
shape=z_enc.shape[1:],
conditioning=cond,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
img_callback=img_callback,
x0=init_latent,
mask=mask,
**kwargs,
)
return samples
# this is a no-op, provided here for compatibility with ddim and plms samplers
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
return x0
# Most of these arguments are ignored and are only present for compatibility with
# other samplers
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
attention_maps_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
threshold=0,
perlin=0,
# this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
**kwargs,
):
def route_callback(k_callback_values):
if img_callback is not None:
img_callback(k_callback_values["x"], k_callback_values["i"])
# if make_schedule() hasn't been called, we do it now
if self.sigmas is None:
self.make_schedule(
ddim_num_steps=S,
ddim_eta=eta,
verbose=False,
)
# sigmas are set up in make_schedule - we take the last S+1 items
sigmas = self.sigmas[-S - 1 :]
# x_T is variation noise. When an init image is provided (in x0) we need to add
# more randomness to the starting image.
if x_T is not None:
if x0 is not None:
x = x_T + torch.randn_like(x0, device=self.device) * sigmas[0]
else:
x = x_T * sigmas[0]
else:
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
model_wrap_cfg = CFGDenoiser(
self.model, threshold=threshold, warmup=max(0.8 * S, S - 10)
)
model_wrap_cfg.prepare_to_sample(
S, extra_conditioning_info=extra_conditioning_info
)
# set up attention map saving. the None checks are because there are multiple code paths to get here.
attention_map_saver = None
if attention_maps_callback is not None and extra_conditioning_info is not None:
eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
attention_map_token_ids = range(1, eos_token_index)
attention_map_saver = AttentionMapSaver(
token_ids=attention_map_token_ids, latents_shape=x.shape[-2:]
)
model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(
attention_map_saver
)
extra_args = {
"cond": conditioning,
"uncond": unconditional_conditioning,
"cond_scale": unconditional_guidance_scale,
}
print(
f">> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)"
)
sampling_result = (
K.sampling.__dict__[f"sample_{self.schedule}"](
model_wrap_cfg,
x,
sigmas,
extra_args=extra_args,
callback=route_callback,
),
None,
)
if attention_map_saver is not None:
attention_maps_callback(attention_map_saver)
return sampling_result
# this code will support inpainting if and when the ksampler API is modified or
# a workaround is found.
@torch.no_grad()
def p_sample(
self,
img,
cond,
ts,
index,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
extra_conditioning_info=None,
**kwargs,
):
if self.model_wrap is None:
self.model_wrap = CFGDenoiser(self.model)
extra_args = {
"cond": cond,
"uncond": unconditional_conditioning,
"cond_scale": unconditional_guidance_scale,
}
if self.s_in is None:
self.s_in = img.new_ones([img.shape[0]])
if self.ds is None:
self.ds = []
# terrible, confusing names here
steps = self.ddim_num_steps
t_enc = self.t_enc
# sigmas covers the full number of steps, but t_enc might
# be less. We start in the middle of the sigma array
# and work our way to the end after t_enc steps.
# index starts at t_enc and works its way to zero,
# so the actual formula for indexing into sigmas:
# sigma_index = (steps-index)
s_index = t_enc - index - 1
self.model_wrap.prepare_to_sample(
s_index, extra_conditioning_info=extra_conditioning_info
)
img = K.sampling.__dict__[f"_{self.schedule}"](
self.model_wrap,
img,
self.sigmas,
s_index,
s_in=self.s_in,
ds=self.ds,
extra_args=extra_args,
)
return img, None, None
# REVIEW THIS METHOD: it has never been tested. In particular,
# we should not be multiplying by self.sigmas[0] if we
# are at an intermediate step in img2img. See similar in
# sample() which does work.
def get_initial_image(self, x_T, shape, steps):
print(f"WARNING: ksampler.get_initial_image(): get_initial_image needs testing")
x = torch.randn(shape, device=self.device) * self.sigmas[0]
if x_T is not None:
return x_T + x
else:
return x
def prepare_to_sample(self, t_enc, **kwargs):
self.t_enc = t_enc
self.model_wrap = None
self.ds = None
self.s_in = None
def q_sample(self, x0, ts):
"""
Overrides parent method to return the q_sample of the inner model.
"""
return self.model.inner_model.q_sample(x0, ts)
def conditioning_key(self) -> str:
return self.model.inner_model.model.conditioning_key
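
Reviewer note: a minimal, self-contained sketch (illustrative names only, not the removed module's API) of what CFGDenoiser.forward() amounts to when no cross-attention control is active: classifier-free guidance followed by the same soft clamp as cfg_apply_threshold() above.

import torch

def guided_denoise(denoiser, x, sigma, cond, uncond, cond_scale, threshold=0.0, scale=0.7):
    # classifier-free guidance: move the unconditioned prediction toward the
    # conditioned one by cond_scale
    e_uncond = denoiser(x, sigma, cond=uncond)
    e_cond = denoiser(x, sigma, cond=cond)
    result = e_uncond + cond_scale * (e_cond - e_uncond)
    if threshold <= 0.0:
        return result
    # soft clamp, mirroring cfg_apply_threshold(): only rescale a bound that
    # actually crosses the threshold
    maxval = result.max().item()
    minval = result.min().item()
    if maxval < threshold and minval > -threshold:
        return result
    if maxval > threshold:
        maxval = min(max(1, scale * maxval), threshold)
    if minval < -threshold:
        minval = max(min(-1, scale * minval), -threshold)
    return torch.clamp(result, min=minval, max=maxval)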

View File

@@ -1,143 +0,0 @@
"""SAMPLING ONLY."""
from functools import partial
import numpy as np
import torch
from tqdm import tqdm
from ...util import choose_torch_device
from ..diffusionmodules.util import noise_like
from .sampler import Sampler
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
class PLMSSampler(Sampler):
def __init__(self, model, schedule="linear", device=None, **kwargs):
super().__init__(model, schedule, model.num_timesteps, device)
def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs)
extra_conditioning_info = kwargs.get("extra_conditioning_info", None)
all_timesteps_count = kwargs.get("all_timesteps_count", t_enc)
if (
extra_conditioning_info is not None
and extra_conditioning_info.wants_cross_attention_control
):
self.invokeai_diffuser.override_cross_attention(
extra_conditioning_info, step_count=all_timesteps_count
)
else:
self.invokeai_diffuser.restore_default_cross_attention()
# this is the essential routine
@torch.no_grad()
def p_sample(
self,
x, # image, called 'img' elsewhere
c, # conditioning, called 'cond' elsewhere
t, # timesteps, called 'ts' elsewhere
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
old_eps=[],
t_next=None,
step_count: int = 1000, # total number of steps
**kwargs,
):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if (
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
# damian0815 would like to know when/if this code path is used
e_t = self.model.apply_model(x, t, c)
else:
# step_index counts in the opposite direction to index
step_index = step_count - (index + 1)
e_t = self.invokeai_diffuser.do_diffusion_step(
x,
t,
unconditional_conditioning,
c,
unconditional_guidance_scale,
step_index=step_index,
)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
e_t = get_model_output(x, t)
if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
e_t_next = get_model_output(x_prev, t_next)
e_t_prime = (e_t + e_t_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2
elif len(old_eps) == 2:
# 3rd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) >= 3:
# 4th order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (
55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
return x_prev, pred_x0, e_t
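
Reviewer note: the multistep combinations above use the standard Adams-Bashforth coefficients; written out in LaTeX, with \epsilon_t the current model output and \epsilon_{t-k} the entries of old_eps:

    \epsilon' = \tfrac{1}{2}\left(3\epsilon_t - \epsilon_{t-1}\right), \qquad
    \epsilon' = \tfrac{1}{12}\left(23\epsilon_t - 16\epsilon_{t-1} + 5\epsilon_{t-2}\right), \qquad
    \epsilon' = \tfrac{1}{24}\left(55\epsilon_t - 59\epsilon_{t-1} + 37\epsilon_{t-2} - 9\epsilon_{t-3}\right)

The first step, which has no history, instead averages \epsilon_t with a re-evaluation at x_{t-1} (the pseudo improved Euler branch).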

View File

@@ -1,454 +0,0 @@
"""
invokeai.models.diffusion.sampler
Base class for invokeai.models.diffusion.ddim, invokeai.models.diffusion.ksampler, etc
"""
from functools import partial
import numpy as np
import torch
from tqdm import tqdm
from ...util import choose_torch_device
from ..diffusionmodules.util import (
extract_into_tensor,
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
)
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
class Sampler(object):
def __init__(self, model, schedule="linear", steps=None, device=None, **kwargs):
self.model = model
self.ddim_timesteps = None
self.ddpm_num_timesteps = steps
self.schedule = schedule
self.device = device or choose_torch_device()
self.invokeai_diffuser = InvokeAIDiffuserComponent(
self.model,
model_forward_callback=lambda x, sigma, cond: self.model.apply_model(
x, sigma, cond
),
)
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device(self.device):
attr = attr.to(torch.float32).to(torch.device(self.device))
setattr(self, name, attr)
# This method was copied over from ddim.py and probably does stuff that is
# ddim-specific. Disentangle at some point.
def make_schedule(
self,
ddim_num_steps,
ddim_discretize="uniform",
ddim_eta=0.0,
verbose=False,
):
self.total_steps = ddim_num_steps
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), "alphas have to be defined for each timestep"
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer("betas", to_torch(self.model.betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer(
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_one_minus_alphas_cumprod",
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
"log_one_minus_alphas_cumprod",
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
"sqrt_recip_alphas_cumprod",
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod",
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
(
ddim_sigmas,
ddim_alphas,
ddim_alphas_prev,
) = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer("ddim_sigmas", ddim_sigmas)
self.register_buffer("ddim_alphas", ddim_alphas)
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
"ddim_sigmas_for_original_num_steps",
sigmas_for_original_sampling_steps,
)
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
)
@torch.no_grad()
def sample(
self,
S, # S is steps
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None, # TODO: this is very confusing because it is called "step_callback" elsewhere. Change.
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=False,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
)
else:
if conditioning.shape[0] != batch_size:
print(
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
# check to see if make_schedule() has run, and if not, run it
if self.ddim_timesteps is None:
self.make_schedule(
ddim_num_steps=S,
ddim_eta=eta,
verbose=False,
)
ts = self.get_timesteps(S)
# sampling
C, H, W = shape
shape = (batch_size, C, H, W)
samples, intermediates = self.do_sampling(
conditioning,
shape,
timesteps=ts,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
steps=S,
**kwargs,
)
return samples, intermediates
@torch.no_grad()
def do_sampling(
self,
cond,
shape,
timesteps=None,
x_T=None,
ddim_use_original_steps=False,
callback=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
steps=None,
**kwargs,
):
b = shape[0]
time_range = (
list(reversed(range(0, timesteps)))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = steps
iterator = tqdm(
time_range,
desc=f"{self.__class__.__name__}",
total=total_steps,
dynamic_ncols=True,
)
old_eps = []
self.prepare_to_sample(t_enc=total_steps, all_timesteps_count=steps, **kwargs)
img = self.get_initial_image(x_T, shape, total_steps)
# probably don't need this at all
intermediates = {"x_inter": [img], "pred_x0": [img]}
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=self.device, dtype=torch.long)
ts_next = torch.full(
(b,),
time_range[min(i + 1, len(time_range) - 1)],
device=self.device,
dtype=torch.long,
)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps,
t_next=ts_next,
step_count=steps,
)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback:
callback(i)
if img_callback:
img_callback(img, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates["x_inter"].append(img)
intermediates["pred_x0"].append(pred_x0)
return img, intermediates
# NOTE that decode() and sample() are almost the same code, and do the same thing.
# The variable names are changed in order to be confusing.
@torch.no_grad()
def decode(
self,
x_latent,
cond,
t_start,
img_callback=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
init_latent=None,
mask=None,
all_timesteps_count=None,
**kwargs,
):
timesteps = (
np.arange(self.ddpm_num_timesteps)
if use_original_steps
else self.ddim_timesteps
)
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(
f">> Running {self.__class__.__name__} sampling starting at step {self.total_steps - t_start} of {self.total_steps} ({total_steps} new sampling steps)"
)
iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
x_dec = x_latent
x0 = init_latent
self.prepare_to_sample(
t_enc=total_steps, all_timesteps_count=all_timesteps_count, **kwargs
)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full(
(x_latent.shape[0],),
step,
device=x_latent.device,
dtype=torch.long,
)
ts_next = torch.full(
(x_latent.shape[0],),
time_range[min(i + 1, len(time_range) - 1)],
device=self.device,
dtype=torch.long,
)
if mask is not None:
assert x0 is not None
xdec_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass?
x_dec = xdec_orig * mask + (1.0 - mask) * x_dec
outs = self.p_sample(
x_dec,
cond,
ts,
index=index,
use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
t_next=ts_next,
step_count=len(self.ddim_timesteps),
)
x_dec, pred_x0, e_t = outs
if img_callback:
img_callback(x_dec, i)
return x_dec
def get_initial_image(self, x_T, shape, timesteps=None):
if x_T is None:
return torch.randn(shape, device=self.device)
else:
return x_T
def p_sample(
self,
img,
cond,
ts,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
old_eps=None,
t_next=None,
steps=None,
):
raise NotImplementedError(
"p_sample() must be implemented in a descendent class"
)
def prepare_to_sample(self, t_enc, **kwargs):
"""
Hook that will be called right before the very first invocation of p_sample()
to allow the subclass to do additional initialization. t_enc corresponds to the actual
number of steps that will be run, and may be less than total steps if img2img is
active.
"""
pass
def get_timesteps(self, ddim_steps):
"""
The ddim and plms samplers work on timesteps. This method is called after
ddim_timesteps are created in make_schedule(), and selects the portion of
timesteps that will be used for sampling, depending on the t_enc in img2img.
"""
return self.ddim_timesteps[:ddim_steps]
def q_sample(self, x0, ts):
"""
Returns self.model.q_sample(x0, ts). It is overridden in the k* samplers to
return self.model.inner_model.q_sample(x0,ts)
"""
return self.model.q_sample(x0, ts)
def conditioning_key(self) -> str:
return self.model.model.conditioning_key
def uses_inpainting_model(self) -> bool:
return self.conditioning_key() in ("hybrid", "concat")
def adjust_settings(self, **kwargs):
"""
This is a catch-all method for adjusting any instance variables
after the sampler is instantiated. No type-checking performed
here, so use with care!
"""
for k in kwargs.keys():
try:
setattr(self, k, kwargs[k])
except AttributeError:
print(
f"** Warning: attempt to set unknown attribute {k} in sampler of type {type(self)}"
)
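
Reviewer note: the sigmas_for_original_sampling_steps registered in make_schedule() above follow the usual DDIM definition, with \eta the ddim_eta argument and \bar\alpha the cumulative alphas:

    \sigma_t = \eta \sqrt{\frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\left(1 - \frac{\bar\alpha_t}{\bar\alpha_{t-1}}\right)}

Setting \eta = 0 makes the reverse process deterministic.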