2022-08-26 07:15:42 +00:00
|
|
|
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
|
2022-10-12 21:29:48 +00:00
|
|
|
|
2022-08-21 21:09:00 +00:00
|
|
|
import k_diffusion as K
|
2022-08-21 23:57:48 +00:00
|
|
|
import torch
|
2022-10-19 16:19:55 +00:00
|
|
|
from torch import nn
|
|
|
|
|
2022-12-10 14:57:41 +00:00
|
|
|
from .cross_attention_map_saving import AttentionMapSaver
|
2022-10-19 16:19:55 +00:00
|
|
|
from .sampler import Sampler
|
|
|
|
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
2022-10-18 20:09:06 +00:00
|
|
|
|
2022-09-02 17:39:26 +00:00
|
|
|
|
2022-10-23 03:02:50 +00:00
|
|
|
# Default step-count cutoff used by KSampler.make_schedule() to pick a noise
# schedule: when the requested number of steps is below this value the Karras
# noise schedule is used; at or above it, the model's own schedule is used.
# Callers can override the cutoff per sampler via the `karras_max` keyword.
STEP_THRESHOLD = 30
|
2022-10-23 03:02:50 +00:00
|
|
|
|
2022-09-06 01:40:05 +00:00
|
|
|
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
    """Clamp the extreme values of ``result`` once they exceed ``threshold``.

    Args:
        result: tensor to limit.
        threshold: magnitude above which values are considered out of range.
            A value ``<= 0`` disables thresholding entirely.
        scale: factor applied to the observed max/min when deriving the new
            clamp bounds.

    Returns:
        ``result`` itself (same object) when thresholding is disabled or every
        value already lies strictly inside ``(-threshold, threshold)``;
        otherwise a new tensor clamped to the computed bounds.

    For each side that exceeds the threshold, the clamp bound becomes
    ``scale * extreme`` with a magnitude of at least 1, capped at
    ``threshold``. A side that does not exceed the threshold keeps its
    observed extreme as the bound, so it is left unchanged.
    """
    if threshold <= 0.0:
        return result

    # .item() gives plain Python floats. Unlike the previous
    # ``.cpu().numpy()`` round trip it also works for tensors that require
    # grad, and it does not leak numpy scalar types into torch.clamp().
    maxval = torch.max(result).item()
    minval = torch.min(result).item()

    if maxval < threshold and minval > -threshold:
        return result

    if maxval > threshold:
        maxval = min(max(1, scale*maxval), threshold)
    if minval < -threshold:
        minval = max(min(-1, scale*minval), -threshold)
    return torch.clamp(result, min=minval, max=maxval)
|
|
|
|
|
2022-08-26 07:15:42 +00:00
|
|
|
|
2022-10-19 16:19:55 +00:00
|
|
|
class CFGDenoiser(nn.Module):
    """Wraps a denoising model so each step goes through InvokeAIDiffuserComponent.

    Every forward call delegates the actual diffusion step to the
    ``InvokeAIDiffuserComponent`` and then limits the result with
    ``cfg_apply_threshold``. While the internal warmup counter is below
    ``warmup_max`` the limit is interpolated between 1 and ``threshold``;
    afterwards ``threshold`` is used directly.
    """

    def __init__(self, model, threshold = 0, warmup = 0):
        """
        Args:
            model: the denoiser to wrap; called as ``model(x, sigma, cond=cond)``.
            threshold: limit handed to ``cfg_apply_threshold`` once warmup is over.
            warmup: number of forward calls over which the limit is ramped.
        """
        super().__init__()
        self.inner_model = model
        self.threshold = threshold
        self.warmup_max = warmup
        # The counter starts at a tenth of the warmup length, but never below 1.
        self.warmup = max(warmup / 10, 1)

        def _call_inner_model(x, sigma, cond):
            # Adapter giving the diffuser component a positional-cond callback.
            return self.inner_model(x, sigma, cond=cond)

        self.invokeai_diffuser = InvokeAIDiffuserComponent(
            model,
            model_forward_callback=_call_inner_model,
        )

    def prepare_to_sample(self, t_enc, **kwargs):
        """Enable or disable cross-attention control before a sampling run.

        Cross-attention control is set up (with ``step_count=t_enc``) only when
        an ``extra_conditioning_info`` keyword is supplied and it reports
        ``wants_cross_attention_control``; in every other case it is removed.
        """
        info = kwargs.get('extra_conditioning_info')
        diffuser = self.invokeai_diffuser
        if info is None or not info.wants_cross_attention_control:
            diffuser.remove_cross_attention_control()
            return
        diffuser.setup_cross_attention_control(info, step_count=t_enc)

    def forward(self, x, sigma, uncond, cond, cond_scale):
        """Run one diffusion step and return the thresholded result."""
        denoised = self.invokeai_diffuser.do_diffusion_step(
            x, sigma, uncond, cond, cond_scale
        )

        limit = self.threshold
        if self.warmup < self.warmup_max:
            # Still warming up: ramp the limit from 1 towards the threshold,
            # advance the counter, and never let the limit exceed the threshold.
            progress = self.warmup / self.warmup_max
            ramped = max(1, 1 + (self.threshold - 1) * progress)
            self.warmup += 1
            limit = self.threshold if ramped > self.threshold else ramped

        return cfg_apply_threshold(denoised, limit)
|
2022-10-19 16:19:55 +00:00
|
|
|
|
2022-09-25 08:03:28 +00:00
|
|
|
class KSampler(Sampler):
    """Sampler that runs k-diffusion sampling routines behind the Sampler interface.

    The wrapped model is turned into a ``K.external.CompVisDenoiser`` and
    handed to the base class. ``sample()`` looks up
    ``K.sampling.sample_<schedule>`` and runs it through a ``CFGDenoiser``.
    """

    def __init__(self, model, schedule='lms', device=None, **kwargs):
        """
        Args:
            model: model to wrap; must expose ``num_timesteps``.
            schedule: suffix of the k-diffusion sampling function to use
                (``sample_<schedule>``).
            device: accepted for signature compatibility.
            **kwargs: ``karras_max`` overrides the step cutoff below which the
                Karras noise schedule is used (defaults to STEP_THRESHOLD,
                also when passed as None).
        """
        # NOTE(review): `device` is not forwarded to the base class here;
        # confirm against Sampler.__init__ whether that is intentional.
        denoiser = K.external.CompVisDenoiser(model)
        super().__init__(
            denoiser,
            schedule,
            steps=model.num_timesteps,
        )
        self.sigmas = None
        self.ds = None
        self.s_in = None
        # Initialised here (not only in prepare_to_sample) so that p_sample()
        # can test them without hitting an AttributeError.
        self.model_wrap = None
        self.t_enc = None

        karras_max = kwargs.get('karras_max', STEP_THRESHOLD)
        self.karras_max = STEP_THRESHOLD if karras_max is None else karras_max

    def make_schedule(
        self,
        ddim_num_steps,
        ddim_discretize='uniform',
        ddim_eta=0.0,
        verbose=False,
    ):
        """Build the sigma schedule for ``ddim_num_steps`` steps.

        Computes both the model's own sigmas and the Karras sigmas, then stores
        one of them in ``self.sigmas``: the model schedule when
        ``ddim_num_steps >= self.karras_max``, the Karras schedule otherwise.
        """
        # The base class is given the inner model while it builds its own
        # schedule; the denoiser wrapper is restored even if that call fails.
        # NOTE(review): the base call receives fixed values rather than this
        # method's ddim_discretize/ddim_eta/verbose arguments — confirm that
        # ignoring them is intended.
        outer_model = self.model
        self.model = outer_model.inner_model
        try:
            super().make_schedule(
                ddim_num_steps,
                ddim_discretize='uniform',
                ddim_eta=0.0,
                verbose=False,
            )
        finally:
            self.model = outer_model
        self.ddim_num_steps = ddim_num_steps

        # we don't need both of these sigmas, but storing them here to make
        # comparison easier later on
        self.model_sigmas = self.model.get_sigmas(ddim_num_steps)
        self.karras_sigmas = K.sampling.get_sigmas_karras(
            n=ddim_num_steps,
            sigma_min=self.model.sigmas[0].item(),
            sigma_max=self.model.sigmas[-1].item(),
            rho=7.,
            device=self.device,
        )

        if ddim_num_steps >= self.karras_max:
            print(f'>> Ksampler using model noise schedule (steps >= {self.karras_max})')
            self.sigmas = self.model_sigmas
        else:
            print(f'>> Ksampler using karras noise schedule (steps < {self.karras_max})')
            self.sigmas = self.karras_sigmas

    # ALERT: We are completely overriding the sample() method in the base class, which
    # means that inpainting will not work. To get this to work we need to be able to
    # modify the inner loop of k_heun, k_lms, etc, as is done in an ugly way
    # in the lstein/k-diffusion branch.

    @torch.no_grad()
    def decode(
        self,
        z_enc,
        cond,
        t_enc,
        img_callback=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        use_original_steps=False,
        init_latent = None,
        mask = None,
        **kwargs
    ):
        """Denoise ``z_enc`` for ``t_enc`` steps by delegating to ``sample()``.

        ``z_enc`` is passed as ``x_T`` and ``init_latent`` as ``x0``.
        ``use_original_steps`` is accepted but not used. Returns only the
        samples (the first element of ``sample()``'s result).
        """
        samples, _ = self.sample(
            batch_size = 1,
            S = t_enc,
            x_T = z_enc,
            shape = z_enc.shape[1:],
            conditioning = cond,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning = unconditional_conditioning,
            img_callback = img_callback,
            x0 = init_latent,
            mask = mask,
            **kwargs
        )
        return samples

    # this is a no-op, provided here for compatibility with ddim and plms samplers
    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        """Return ``x0`` unchanged."""
        return x0

    # Most of these arguments are ignored and are only present for compatibility with
    # other samples
    @torch.no_grad()
    def sample(
        self,
        S,
        batch_size,
        shape,
        conditioning=None,
        callback=None,
        normals_sequence=None,
        img_callback=None,
        attention_maps_callback=None,
        quantize_x0=False,
        eta=0.0,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        verbose=True,
        x_T=None,
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
        threshold = 0,
        perlin = 0,
        **kwargs,
    ):
        """Run ``S`` sampling steps with ``K.sampling.sample_<schedule>``.

        Arguments actually used by this method:
            S: number of sampling steps; the last ``S + 1`` entries of
                ``self.sigmas`` are used.
            batch_size, shape: size of the random starting latent when ``x_T``
                is not supplied.
            conditioning, unconditional_conditioning,
            unconditional_guidance_scale: passed to the denoiser as ``cond``,
                ``uncond`` and ``cond_scale``. The unconditional conditioning
                has to come in the same format as the conditioning, e.g. as
                encoded tokens.
            img_callback: if given, called as ``img_callback(x, i)`` from the
                k-diffusion callback.
            attention_maps_callback: if given together with
                ``extra_conditioning_info``, called with the AttentionMapSaver
                after sampling finishes.
            eta: forwarded to ``make_schedule()`` when no schedule exists yet.
            x0, x_T: optional init latent and starting/variation noise.
            threshold: forwarded to ``CFGDenoiser``.

        The remaining parameters are not read here.

        Returns:
            A tuple ``(samples, None)``.
        """
        def route_callback(k_callback_values):
            if img_callback is not None:
                img_callback(k_callback_values['x'],k_callback_values['i'])

        # if make_schedule() hasn't been called, we do it now
        if self.sigmas is None:
            self.make_schedule(
                ddim_num_steps=S,
                ddim_eta = eta,
                verbose = False,
            )

        # sigmas are set up in make_schedule - we take the last steps items
        sigmas = self.sigmas[-S-1:]

        # x_T is variation noise. When an init image is provided (in x0) we need to add
        # more randomness to the starting image.
        if x_T is None:
            x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
        elif x0 is None:
            x = x_T * sigmas[0]
        else:
            x = x_T + torch.randn_like(x0, device=self.device) * sigmas[0]

        model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
        model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)

        # setup attention maps saving. checks for None are because there are multiple code paths to get here.
        attention_maps_saver = None
        if attention_maps_callback is not None and extra_conditioning_info is not None:
            eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
            # token ids from 1 up to (not including) the last counted token
            attention_map_token_ids = range(1, eos_token_index)
            attention_maps_saver = AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
            model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_maps_saver)

        extra_args = {
            'cond': conditioning,
            'uncond': unconditional_conditioning,
            'cond_scale': unconditional_guidance_scale,
        }
        print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)')

        sample_fn = getattr(K.sampling, f'sample_{self.schedule}')
        sampling_result = (
            sample_fn(
                model_wrap_cfg, x, sigmas, extra_args=extra_args,
                callback=route_callback
            ),
            None,
        )
        if attention_maps_saver is not None:
            attention_maps_callback(attention_maps_saver)
        return sampling_result

    # this code will support inpainting if and when ksampler API modified or
    # a workaround is found.
    @torch.no_grad()
    def p_sample(
        self,
        img,
        cond,
        ts,
        index,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        extra_conditioning_info=None,
        **kwargs,
    ):
        """Advance ``img`` by one step using ``K.sampling._<schedule>``.

        ``prepare_to_sample()`` must have been called first so that ``t_enc``
        is known. ``ts`` and ``**kwargs`` are accepted but not used.

        Returns:
            A tuple ``(img, None, None)``.

        Raises:
            RuntimeError: if ``prepare_to_sample()`` has not been called.
        """
        if self.t_enc is None:
            raise RuntimeError(
                'KSampler.p_sample() called before prepare_to_sample()'
            )
        if self.model_wrap is None:
            self.model_wrap = CFGDenoiser(self.model)
        extra_args = {
            'cond': cond,
            'uncond': unconditional_conditioning,
            'cond_scale': unconditional_guidance_scale,
        }
        if self.s_in is None:
            self.s_in = img.new_ones([img.shape[0]])
        if self.ds is None:
            self.ds = []

        # sigmas is a full steps in length, but t_enc might
        # be less. We start in the middle of the sigma array
        # and work our way to the end after t_enc steps.
        # The position passed to the step function is derived from
        # t_enc and index as: s_index = t_enc - index - 1
        s_index = self.t_enc - index - 1
        self.model_wrap.prepare_to_sample(s_index, extra_conditioning_info=extra_conditioning_info)

        step_fn = getattr(K.sampling, f'_{self.schedule}')
        img = step_fn(
            self.model_wrap,
            img,
            self.sigmas,
            s_index,
            s_in = self.s_in,
            ds = self.ds,
            extra_args=extra_args,
        )

        return img, None, None

    # REVIEW THIS METHOD: it has never been tested. In particular,
    # we should not be multiplying by self.sigmas[0] if we
    # are at an intermediate step in img2img. See similar in
    # sample() which does work.
    def get_initial_image(self,x_T,shape,steps):
        """Return random noise scaled by ``self.sigmas[0]``, added to ``x_T`` if given.

        ``steps`` is accepted but not used.
        """
        print('WARNING: ksampler.get_initial_image(): get_initial_image needs testing')
        x = (torch.randn(shape, device=self.device) * self.sigmas[0])
        if x_T is not None:
            return x_T + x
        return x

    def prepare_to_sample(self,t_enc,**kwargs):
        """Record ``t_enc`` and reset the per-run state used by ``p_sample()``."""
        self.t_enc = t_enc
        self.model_wrap = None
        self.ds = None
        self.s_in = None

    def q_sample(self,x0,ts):
        """
        Overrides parent method to return the q_sample of the inner model.
        """
        return self.model.inner_model.q_sample(x0,ts)

    def conditioning_key(self)->str:
        """Return the conditioning key of the innermost wrapped model."""
        return self.model.inner_model.model.conditioning_key
|
2022-10-25 14:45:12 +00:00
|
|
|
|