2022-08-26 07:15:42 +00:00
|
|
|
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
|
2022-08-21 21:09:00 +00:00
|
|
|
import k_diffusion as K
|
2022-08-21 23:57:48 +00:00
|
|
|
import torch
|
2022-08-21 21:09:00 +00:00
|
|
|
import torch.nn as nn
|
2022-08-31 04:33:23 +00:00
|
|
|
from ldm.dream.devices import choose_torch_device
|
2022-08-26 07:15:42 +00:00
|
|
|
|
2022-08-21 21:09:00 +00:00
|
|
|
class CFGDenoiser(nn.Module):
    """Classifier-free-guidance wrapper around a k-diffusion denoiser.

    The unconditional and conditional inputs are evaluated in one doubled
    batch, and the two predictions are blended as
    ``uncond + (cond - uncond) * cond_scale``.
    """

    def __init__(self, model):
        """
        Args:
            model: callable invoked as ``model(x, sigma, cond=...)``; in this
                file it is the ``K.external.CompVisDenoiser`` built by KSampler.
        """
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cond_scale):
        """Return the guided denoiser prediction for ``x`` at noise level ``sigma``.

        Args:
            x: current latent batch.
            sigma: noise level(s) matching ``x``'s batch dimension.
            uncond: unconditional conditioning, or ``None`` to disable guidance.
            cond: conditioning for the prompt.
            cond_scale: guidance scale; ``1.0`` means no guidance.

        Returns:
            The (possibly guided) prediction returned by the wrapped model.
        """
        # With nothing to guide against (no uncond, or a scale of 1.0) a single
        # conditional pass gives an equivalent result at half the cost, and it
        # avoids torch.cat on None. Other samplers skip guidance the same way.
        if uncond is None or cond_scale == 1.0:
            return self.inner_model(x, sigma, cond=cond)

        # Stack [uncond; cond] so both predictions come from one model call.
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        cond_in = torch.cat([uncond, cond])
        uncond_out, cond_out = self.inner_model(
            x_in, sigma_in, cond=cond_in
        ).chunk(2)
        return uncond_out + (cond_out - uncond_out) * cond_scale
|
2022-08-21 21:09:00 +00:00
|
|
|
|
2022-08-26 07:15:42 +00:00
|
|
|
|
2022-08-21 21:09:00 +00:00
|
|
|
class KSampler(object):
    """Adapter exposing a k-diffusion sampler through the ldm sampler interface.

    ``sample()`` takes the same arguments as the other samplers, but only a
    subset of them is used (see ``sample``).
    """

    def __init__(self, model, schedule='lms', device=None, **kwargs):
        """
        Args:
            model: CompVis model; wrapped in ``K.external.CompVisDenoiser``.
            schedule: suffix selecting ``k_diffusion.sampling.sample_<schedule>``,
                e.g. ``'lms'``.
            device: device on which initial noise is drawn; defaults to
                ``choose_torch_device()``.
            **kwargs: accepted for interface compatibility and ignored.
        """
        super().__init__()
        self.model = K.external.CompVisDenoiser(model)
        self.schedule = schedule
        self.device = device or choose_torch_device()

    def forward(self, x, sigma, uncond, cond, cond_scale):
        """Return the classifier-free-guided prediction of ``self.model``.

        Args:
            x: current latent batch.
            sigma: noise level(s) matching ``x``'s batch dimension.
            uncond: unconditional conditioning, or ``None`` to disable guidance.
            cond: conditioning for the prompt.
            cond_scale: guidance scale; ``1.0`` means no guidance.

        Returns:
            ``uncond + (cond - uncond) * cond_scale`` computed from the model's
            two predictions, or the plain conditional prediction when guidance
            is disabled.
        """
        # KSampler has no `inner_model` attribute; the wrapped denoiser lives
        # in self.model (the same object sample() hands to CFGDenoiser).
        if uncond is None or cond_scale == 1.0:
            return self.model(x, sigma, cond=cond)

        # Evaluate [uncond; cond] as one doubled batch.
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        cond_in = torch.cat([uncond, cond])
        uncond_out, cond_out = self.model(
            x_in, sigma_in, cond=cond_in
        ).chunk(2)
        return uncond_out + (cond_out - uncond_out) * cond_scale

    # Most of these arguments are ignored and are only present for
    # compatibility with the other samplers.
    @torch.no_grad()
    def sample(
        self,
        S,
        batch_size,
        shape,
        conditioning=None,
        callback=None,
        normals_sequence=None,
        img_callback=None,
        quantize_x0=False,
        eta=0.0,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        verbose=True,
        x_T=None,
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
        **kwargs,
    ):
        """Run the configured k-diffusion sampler.

        Args:
            S: number of sampling steps, passed to ``self.model.get_sigmas``.
            batch_size: batch size of the initial noise when ``x_T`` is None.
            shape: per-sample shape of the initial noise
                (presumably ``(C, H, W)`` latents — confirm with callers).
            conditioning: passed to the denoiser as ``cond``.
            img_callback: optional ``img_callback(x, i)``, invoked through the
                k-diffusion sampler's ``callback`` with its ``'x'`` and ``'i'``.
            x_T: optional starting noise; multiplied by the largest sigma
                (``sigmas[0]``) before sampling.
            unconditional_guidance_scale: classifier-free guidance scale.
            unconditional_conditioning: conditioning for the unconditional
                branch of guidance.
            All other arguments (including ``callback``) are accepted for
            interface compatibility and ignored.

        Returns:
            ``(samples, None)``; the ``None`` presumably stands in for the
            intermediates returned by the other samplers.

        Raises:
            KeyError: if ``k_diffusion.sampling`` has no ``sample_<schedule>``.
        """

        def route_callback(k_callback_values):
            # k-diffusion reports progress as a dict; forward 'x' and 'i'.
            if img_callback is not None:
                img_callback(k_callback_values['x'], k_callback_values['i'])

        sigmas = self.model.get_sigmas(S)
        if x_T is not None:
            x = x_T * sigmas[0]
        else:
            # Draw the noise directly on the target device (e.g. the GPU).
            x = (
                torch.randn([batch_size, *shape], device=self.device)
                * sigmas[0]
            )

        model_wrap_cfg = CFGDenoiser(self.model)
        extra_args = {
            'cond': conditioning,
            'uncond': unconditional_conditioning,
            'cond_scale': unconditional_guidance_scale,
        }
        sampler_fn = K.sampling.__dict__[f'sample_{self.schedule}']
        samples = sampler_fn(
            model_wrap_cfg,
            x,
            sigmas,
            extra_args=extra_args,
            callback=route_callback,
        )
        return (samples, None)
|