* Add threshold and perlin noise options for Karras samplers.

This commit is contained in:
Peter Baylies 2022-09-02 13:39:26 -04:00
parent 3ee82d8a3b
commit b6cf8b9052
4 changed files with 71 additions and 4 deletions

View File

@ -3,18 +3,34 @@ import k_diffusion as K
import torch
import torch.nn as nn
from ldm.dream.devices import choose_torch_device
from ldm.modules.diffusionmodules.util import rand_perlin_2d
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.707):
if threshold <= 0.0:
return result
maxval = 0.0 + torch.max(result).cpu().numpy()
minval = 0.0 + torch.min(result).cpu().numpy()
if maxval < threshold and minval > -threshold:
return result
if maxval > threshold:
maxval = min(max(1, scale*maxval), threshold)
if minval < -threshold:
minval = max(min(-1, scale*minval), -threshold)
return torch.clamp(result, min=minval, max=maxval)
class CFGDenoiser(nn.Module):
def __init__(self, model):
def __init__(self, model, threshold = 0):
super().__init__()
self.inner_model = model
self.threshold = threshold
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, self.threshold)
class KSampler(object):
@ -23,6 +39,7 @@ class KSampler(object):
self.model = K.external.CompVisDenoiser(model)
self.schedule = schedule
self.device = device or choose_torch_device()
#self.threshold = threshold or 0
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
@ -32,6 +49,9 @@ class KSampler(object):
x_in, sigma_in, cond=cond_in
).chunk(2)
return uncond + (cond - uncond) * cond_scale
#return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, self.threshold)
# most of these arguments are ignored and are only present for compatibility with
# other samples
@ -58,6 +78,8 @@ class KSampler(object):
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
threshold = 0,
perlin = 0,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs,
):
@ -73,7 +95,12 @@ class KSampler(object):
torch.randn([batch_size, *shape], device=self.device)
* sigmas[0]
) # for GPU draw
model_wrap_cfg = CFGDenoiser(self.model)
if perlin > 0.0:
print(shape)
x = (1 - perlin / 2) * x + perlin * rand_perlin_2d((shape[1], shape[2]), (8, 8)).to(self.device)
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold)
extra_args = {
'cond': conditioning,
'uncond': unconditional_conditioning,

View File

@ -18,6 +18,25 @@ from einops import repeat
from ldm.util import instantiate_from_config
def rand_perlin_2d(shape, res, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3):
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim = -1) % 1
angles = 2*math.pi*torch.rand(res[0]+1, res[1]+1)
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1)
tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)
dot = lambda grad, shift: (torch.stack((grid[:shape[0],:shape[1],0] + shift[0], grid[:shape[0],:shape[1], 1] + shift[1] ), dim = -1) * grad[:shape[0], :shape[1]]).sum(dim = -1)
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
n01 = dot(tile_grads([0, -1],[1, None]), [0, -1])
n11 = dot(tile_grads([1, None], [1, None]), [-1,-1])
t = fade(grid[:shape[0], :shape[1]])
return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
def make_beta_schedule(
schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):

View File

@ -226,6 +226,8 @@ class T2I:
upscale = None,
sampler_name = None,
log_tokenization= False,
threshold = 0,
perlin = 0,
**args,
): # eat up additional cruft
"""
@ -319,6 +321,8 @@ class T2I:
width=width,
height=height,
callback=step_callback,
threshold=threshold,
perlin=perlin,
)
device_type = choose_autocast_device(self.device)
@ -407,6 +411,8 @@ class T2I:
width,
height,
callback,
threshold,
perlin,
):
"""
An infinite iterator of images from the prompt.
@ -430,7 +436,9 @@ class T2I:
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
eta=ddim_eta,
img_callback=callback
img_callback=callback,
threshold=threshold,
perlin=perlin,
)
yield self._sample_to_image(samples)

View File

@ -9,6 +9,7 @@ import sys
import copy
import warnings
import time
sys.path.insert(0, '.')
from ldm.dream.devices import choose_torch_device
import ldm.dream.readline
from ldm.dream.pngwriter import PngWriter, PromptFormatter
@ -546,6 +547,18 @@ def create_cmd_parser():
action='store_true',
help='shows how the prompt is split into tokens'
)
parser.add_argument(
'--threshold',
default=0.0,
type=float,
help='Add threshold value aka perform clipping.',
)
parser.add_argument(
'--perlin',
default=0.0,
type=float,
help='Add perlin noise.',
)
return parser