# InvokeAI/invokeai/backend/stable_diffusion/diffusion/ksampler.py
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
import k_diffusion as K
import torch
from torch import nn
from .cross_attention_map_saving import AttentionMapSaver
from .sampler import Sampler
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
# at this threshold, the scheduler will stop using the Karras
# noise schedule and start using the model's schedule
STEP_THRESHOLD = 30


def cfg_apply_threshold(result, threshold=0.0, scale=0.7):
    if threshold <= 0.0:
        return result
    maxval = 0.0 + torch.max(result).cpu().numpy()
    minval = 0.0 + torch.min(result).cpu().numpy()
    if maxval < threshold and minval > -threshold:
        return result
    if maxval > threshold:
        maxval = min(max(1, scale * maxval), threshold)
    if minval < -threshold:
        minval = max(min(-1, scale * minval), -threshold)
    return torch.clamp(result, min=minval, max=maxval)
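
# Illustrative note (an assumption added for clarity, not part of the original code):
# with threshold=1.0 and the default scale=0.7, a result tensor whose extremes are
# +/-2.0 is clamped to the range [-1.0, 1.0], since min(max(1, 0.7 * 2.0), 1.0) == 1.0
# and max(min(-1, 0.7 * -2.0), -1.0) == -1.0; a tensor already strictly inside
# (-1.0, 1.0) is returned unchanged.
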
class CFGDenoiser(nn.Module):
    def __init__(self, model, threshold=0, warmup=0):
        super().__init__()
        self.inner_model = model
        self.threshold = threshold
        self.warmup_max = warmup
        self.warmup = max(warmup / 10, 1)
        self.invokeai_diffuser = InvokeAIDiffuserComponent(
            model,
            model_forward_callback=lambda x, sigma, cond: self.inner_model(
                x, sigma, cond=cond
            ),
        )

    def prepare_to_sample(self, t_enc, **kwargs):
        extra_conditioning_info = kwargs.get("extra_conditioning_info", None)
        if (
            extra_conditioning_info is not None
            and extra_conditioning_info.wants_cross_attention_control
        ):
            self.invokeai_diffuser.override_cross_attention(
                extra_conditioning_info, step_count=t_enc
            )
        else:
            self.invokeai_diffuser.restore_default_cross_attention()

    def forward(self, x, sigma, uncond, cond, cond_scale):
        next_x = self.invokeai_diffuser.do_diffusion_step(
            x, sigma, uncond, cond, cond_scale
        )
        if self.warmup < self.warmup_max:
            thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
            self.warmup += 1
        else:
            thresh = self.threshold
        if thresh > self.threshold:
            thresh = self.threshold
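        # Worked example (illustrative numbers, not from the original code): with
        # threshold=10 and warmup_max=16, the 8th warmup step yields
        # max(1, 1 + (10 - 1) * (8 / 16)) == 5.5, so clamping ramps up gradually
        # toward the configured threshold as sampling proceeds.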
        return cfg_apply_threshold(next_x, thresh)


class KSampler(Sampler):
    def __init__(self, model, schedule="lms", device=None, **kwargs):
        denoiser = K.external.CompVisDenoiser(model)
        super().__init__(
            denoiser,
            schedule,
            steps=model.num_timesteps,
        )
        self.sigmas = None
        self.ds = None
        self.s_in = None
        self.karras_max = kwargs.get("karras_max", STEP_THRESHOLD)
        if self.karras_max is None:
            self.karras_max = STEP_THRESHOLD

    def make_schedule(
        self,
        ddim_num_steps,
        ddim_discretize="uniform",
        ddim_eta=0.0,
        verbose=False,
    ):
        outer_model = self.model
        self.model = outer_model.inner_model
        super().make_schedule(
            ddim_num_steps,
            ddim_discretize="uniform",
            ddim_eta=0.0,
            verbose=False,
        )
        self.model = outer_model
        self.ddim_num_steps = ddim_num_steps
        # we don't need both of these sigmas, but storing them here to make
        # comparison easier later on
        self.model_sigmas = self.model.get_sigmas(ddim_num_steps)
        self.karras_sigmas = K.sampling.get_sigmas_karras(
            n=ddim_num_steps,
            sigma_min=self.model.sigmas[0].item(),
            sigma_max=self.model.sigmas[-1].item(),
            rho=7.0,
            device=self.device,
        )
        if ddim_num_steps >= self.karras_max:
            print(
                f">> Ksampler using model noise schedule (steps >= {self.karras_max})"
            )
            self.sigmas = self.model_sigmas
        else:
            print(
                f">> Ksampler using karras noise schedule (steps < {self.karras_max})"
            )
            self.sigmas = self.karras_sigmas
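        # Illustrative note (numbers assumed for clarity): with the default
        # karras_max of STEP_THRESHOLD (30), a 20-step run selects the Karras
        # sigmas above, while a 50-step run falls back to the model's own
        # schedule from get_sigmas().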

    # ALERT: We are completely overriding the sample() method in the base class, which
    # means that inpainting will not work. To get this to work we need to be able to
    # modify the inner loop of k_heun, k_lms, etc., as is done in an ugly way
    # in the lstein/k-diffusion branch.
    @torch.no_grad()
    def decode(
        self,
        z_enc,
        cond,
        t_enc,
        img_callback=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        use_original_steps=False,
        init_latent=None,
        mask=None,
        **kwargs,
    ):
        samples, _ = self.sample(
            batch_size=1,
            S=t_enc,
            x_T=z_enc,
            shape=z_enc.shape[1:],
            conditioning=cond,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,
            img_callback=img_callback,
            x0=init_latent,
            mask=mask,
            **kwargs,
        )
        return samples

    # this is a no-op, provided here for compatibility with ddim and plms samplers
    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        return x0

    # Most of these arguments are ignored and are only present for compatibility with
    # other samplers
    @torch.no_grad()
    def sample(
        self,
        S,
        batch_size,
        shape,
        conditioning=None,
        callback=None,
        normals_sequence=None,
        img_callback=None,
        attention_maps_callback=None,
        quantize_x0=False,
        eta=0.0,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        verbose=True,
        x_T=None,
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
        threshold=0,
        perlin=0,
        # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
        **kwargs,
    ):
        def route_callback(k_callback_values):
            if img_callback is not None:
                img_callback(k_callback_values["x"], k_callback_values["i"])

        # if make_schedule() hasn't been called, we do it now
        if self.sigmas is None:
            self.make_schedule(
                ddim_num_steps=S,
                ddim_eta=eta,
                verbose=False,
            )

        # sigmas are set up in make_schedule - we take the last S + 1 items
        sigmas = self.sigmas[-S - 1 :]
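        # Illustrative note (numbers assumed): if make_schedule() produced 51 sigmas
        # (50 steps plus the terminal zero) and S == 20, this slice keeps the final
        # 21 values, i.e. 20 denoising steps that end at sigma == 0.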
        # x_T is variation noise. When an init image is provided (in x0) we need to add
        # more randomness to the starting image.
        if x_T is not None:
            if x0 is not None:
                x = x_T + torch.randn_like(x0, device=self.device) * sigmas[0]
            else:
                x = x_T * sigmas[0]
        else:
            x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]

        model_wrap_cfg = CFGDenoiser(
            self.model, threshold=threshold, warmup=max(0.8 * S, S - 10)
        )
        model_wrap_cfg.prepare_to_sample(
            S, extra_conditioning_info=extra_conditioning_info
        )
        # setup attention maps saving. checks for None are because there are multiple code paths to get here.
        attention_map_saver = None
        if attention_maps_callback is not None and extra_conditioning_info is not None:
            eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
            attention_map_token_ids = range(1, eos_token_index)
            attention_map_saver = AttentionMapSaver(
                token_ids=attention_map_token_ids, latents_shape=x.shape[-2:]
            )
            model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(
                attention_map_saver
            )

        extra_args = {
            "cond": conditioning,
            "uncond": unconditional_conditioning,
            "cond_scale": unconditional_guidance_scale,
        }
        print(
            f">> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)"
        )
        sampling_result = (
            K.sampling.__dict__[f"sample_{self.schedule}"](
                model_wrap_cfg,
                x,
                sigmas,
                extra_args=extra_args,
                callback=route_callback,
            ),
            None,
        )
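        # Note (informational, not from the original comments): the lookup above
        # dispatches by name, so schedule == "lms" resolves to K.sampling.sample_lms
        # and schedule == "euler_ancestral" to K.sampling.sample_euler_ancestral.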
        if attention_map_saver is not None:
            attention_maps_callback(attention_map_saver)
        return sampling_result

    # this code will support inpainting if and when the ksampler API is modified or
    # a workaround is found.
    @torch.no_grad()
    def p_sample(
        self,
        img,
        cond,
        ts,
        index,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        extra_conditioning_info=None,
        **kwargs,
    ):
        if self.model_wrap is None:
            self.model_wrap = CFGDenoiser(self.model)
        extra_args = {
            "cond": cond,
            "uncond": unconditional_conditioning,
            "cond_scale": unconditional_guidance_scale,
        }
        if self.s_in is None:
            self.s_in = img.new_ones([img.shape[0]])
        if self.ds is None:
            self.ds = []

        # terrible, confusing names here
        steps = self.ddim_num_steps
        t_enc = self.t_enc

        # sigmas covers the full `steps` range, but t_enc might be shorter. We start
        # partway into the sigma array and work our way to the end after t_enc steps.
        # `index` counts down toward zero, so the index into sigmas is:
        s_index = t_enc - index - 1
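        # Worked example (illustrative numbers): with t_enc == 10, a call with
        # index == 9 gives s_index == 0 (the noisiest sigma in the slice), and a
        # call with index == 0 gives s_index == 9.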
        self.model_wrap.prepare_to_sample(
            s_index, extra_conditioning_info=extra_conditioning_info
        )
        img = K.sampling.__dict__[f"_{self.schedule}"](
            self.model_wrap,
            img,
            self.sigmas,
            s_index,
            s_in=self.s_in,
            ds=self.ds,
            extra_args=extra_args,
        )
        return img, None, None

    # REVIEW THIS METHOD: it has never been tested. In particular,
    # we should not be multiplying by self.sigmas[0] if we
    # are at an intermediate step in img2img. See the similar code in
    # sample(), which does work.
    def get_initial_image(self, x_T, shape, steps):
        print("WARNING: ksampler.get_initial_image(): get_initial_image needs testing")
        x = torch.randn(shape, device=self.device) * self.sigmas[0]
        if x_T is not None:
            return x_T + x
        else:
            return x

    def prepare_to_sample(self, t_enc, **kwargs):
        self.t_enc = t_enc
        self.model_wrap = None
        self.ds = None
        self.s_in = None

    def q_sample(self, x0, ts):
        """
        Overrides parent method to return the q_sample of the inner model.
        """
        return self.model.inner_model.q_sample(x0, ts)

    def conditioning_key(self) -> str:
        return self.model.inner_model.model.conditioning_key
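

# Hypothetical usage sketch (assumptions: `load_compvis_model` stands in for whatever
# loader produces a CompVis LatentDiffusion checkpoint, and `cond` / `uncond` are
# conditioning tensors built elsewhere; none of these are defined in this module):
#
#     model = load_compvis_model("path/to/checkpoint.ckpt")
#     sampler = KSampler(model, schedule="lms", karras_max=STEP_THRESHOLD)
#     samples, _ = sampler.sample(
#         S=20,
#         batch_size=1,
#         shape=(4, 64, 64),
#         conditioning=cond,
#         unconditional_conditioning=uncond,
#         unconditional_guidance_scale=7.5,
#     )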