mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
* Add threshold and perlin noise options for Karras samplers.
This commit is contained in:
parent
3ee82d8a3b
commit
b6cf8b9052
@ -3,18 +3,34 @@ import k_diffusion as K
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from ldm.dream.devices import choose_torch_device
|
from ldm.dream.devices import choose_torch_device
|
||||||
|
from ldm.modules.diffusionmodules.util import rand_perlin_2d
|
||||||
|
|
||||||
|
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.707):
|
||||||
|
if threshold <= 0.0:
|
||||||
|
return result
|
||||||
|
maxval = 0.0 + torch.max(result).cpu().numpy()
|
||||||
|
minval = 0.0 + torch.min(result).cpu().numpy()
|
||||||
|
if maxval < threshold and minval > -threshold:
|
||||||
|
return result
|
||||||
|
if maxval > threshold:
|
||||||
|
maxval = min(max(1, scale*maxval), threshold)
|
||||||
|
if minval < -threshold:
|
||||||
|
minval = max(min(-1, scale*minval), -threshold)
|
||||||
|
return torch.clamp(result, min=minval, max=maxval)
|
||||||
|
|
||||||
|
|
||||||
class CFGDenoiser(nn.Module):
|
class CFGDenoiser(nn.Module):
|
||||||
def __init__(self, model):
|
def __init__(self, model, threshold = 0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.inner_model = model
|
self.inner_model = model
|
||||||
|
self.threshold = threshold
|
||||||
|
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||||
x_in = torch.cat([x] * 2)
|
x_in = torch.cat([x] * 2)
|
||||||
sigma_in = torch.cat([sigma] * 2)
|
sigma_in = torch.cat([sigma] * 2)
|
||||||
cond_in = torch.cat([uncond, cond])
|
cond_in = torch.cat([uncond, cond])
|
||||||
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
||||||
return uncond + (cond - uncond) * cond_scale
|
return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, self.threshold)
|
||||||
|
|
||||||
|
|
||||||
class KSampler(object):
|
class KSampler(object):
|
||||||
@ -23,6 +39,7 @@ class KSampler(object):
|
|||||||
self.model = K.external.CompVisDenoiser(model)
|
self.model = K.external.CompVisDenoiser(model)
|
||||||
self.schedule = schedule
|
self.schedule = schedule
|
||||||
self.device = device or choose_torch_device()
|
self.device = device or choose_torch_device()
|
||||||
|
#self.threshold = threshold or 0
|
||||||
|
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||||
x_in = torch.cat([x] * 2)
|
x_in = torch.cat([x] * 2)
|
||||||
@ -32,6 +49,9 @@ class KSampler(object):
|
|||||||
x_in, sigma_in, cond=cond_in
|
x_in, sigma_in, cond=cond_in
|
||||||
).chunk(2)
|
).chunk(2)
|
||||||
return uncond + (cond - uncond) * cond_scale
|
return uncond + (cond - uncond) * cond_scale
|
||||||
|
#return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, self.threshold)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# most of these arguments are ignored and are only present for compatibility with
|
# most of these arguments are ignored and are only present for compatibility with
|
||||||
# other samples
|
# other samples
|
||||||
@ -58,6 +78,8 @@ class KSampler(object):
|
|||||||
log_every_t=100,
|
log_every_t=100,
|
||||||
unconditional_guidance_scale=1.0,
|
unconditional_guidance_scale=1.0,
|
||||||
unconditional_conditioning=None,
|
unconditional_conditioning=None,
|
||||||
|
threshold = 0,
|
||||||
|
perlin = 0,
|
||||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@ -73,7 +95,12 @@ class KSampler(object):
|
|||||||
torch.randn([batch_size, *shape], device=self.device)
|
torch.randn([batch_size, *shape], device=self.device)
|
||||||
* sigmas[0]
|
* sigmas[0]
|
||||||
) # for GPU draw
|
) # for GPU draw
|
||||||
model_wrap_cfg = CFGDenoiser(self.model)
|
|
||||||
|
if perlin > 0.0:
|
||||||
|
print(shape)
|
||||||
|
x = (1 - perlin / 2) * x + perlin * rand_perlin_2d((shape[1], shape[2]), (8, 8)).to(self.device)
|
||||||
|
|
||||||
|
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold)
|
||||||
extra_args = {
|
extra_args = {
|
||||||
'cond': conditioning,
|
'cond': conditioning,
|
||||||
'uncond': unconditional_conditioning,
|
'uncond': unconditional_conditioning,
|
||||||
|
@ -18,6 +18,25 @@ from einops import repeat
|
|||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
|
|
||||||
|
def rand_perlin_2d(shape, res, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3):
|
||||||
|
delta = (res[0] / shape[0], res[1] / shape[1])
|
||||||
|
d = (shape[0] // res[0], shape[1] // res[1])
|
||||||
|
|
||||||
|
grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim = -1) % 1
|
||||||
|
angles = 2*math.pi*torch.rand(res[0]+1, res[1]+1)
|
||||||
|
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1)
|
||||||
|
|
||||||
|
tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)
|
||||||
|
dot = lambda grad, shift: (torch.stack((grid[:shape[0],:shape[1],0] + shift[0], grid[:shape[0],:shape[1], 1] + shift[1] ), dim = -1) * grad[:shape[0], :shape[1]]).sum(dim = -1)
|
||||||
|
|
||||||
|
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
|
||||||
|
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
|
||||||
|
n01 = dot(tile_grads([0, -1],[1, None]), [0, -1])
|
||||||
|
n11 = dot(tile_grads([1, None], [1, None]), [-1,-1])
|
||||||
|
t = fade(grid[:shape[0], :shape[1]])
|
||||||
|
return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
|
||||||
|
|
||||||
|
|
||||||
def make_beta_schedule(
|
def make_beta_schedule(
|
||||||
schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
|
schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
|
||||||
):
|
):
|
||||||
|
@ -226,6 +226,8 @@ class T2I:
|
|||||||
upscale = None,
|
upscale = None,
|
||||||
sampler_name = None,
|
sampler_name = None,
|
||||||
log_tokenization= False,
|
log_tokenization= False,
|
||||||
|
threshold = 0,
|
||||||
|
perlin = 0,
|
||||||
**args,
|
**args,
|
||||||
): # eat up additional cruft
|
): # eat up additional cruft
|
||||||
"""
|
"""
|
||||||
@ -319,6 +321,8 @@ class T2I:
|
|||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
callback=step_callback,
|
callback=step_callback,
|
||||||
|
threshold=threshold,
|
||||||
|
perlin=perlin,
|
||||||
)
|
)
|
||||||
|
|
||||||
device_type = choose_autocast_device(self.device)
|
device_type = choose_autocast_device(self.device)
|
||||||
@ -407,6 +411,8 @@ class T2I:
|
|||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
callback,
|
callback,
|
||||||
|
threshold,
|
||||||
|
perlin,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
An infinite iterator of images from the prompt.
|
An infinite iterator of images from the prompt.
|
||||||
@ -430,7 +436,9 @@ class T2I:
|
|||||||
unconditional_guidance_scale=cfg_scale,
|
unconditional_guidance_scale=cfg_scale,
|
||||||
unconditional_conditioning=uc,
|
unconditional_conditioning=uc,
|
||||||
eta=ddim_eta,
|
eta=ddim_eta,
|
||||||
img_callback=callback
|
img_callback=callback,
|
||||||
|
threshold=threshold,
|
||||||
|
perlin=perlin,
|
||||||
)
|
)
|
||||||
yield self._sample_to_image(samples)
|
yield self._sample_to_image(samples)
|
||||||
|
|
||||||
|
@ -9,6 +9,7 @@ import sys
|
|||||||
import copy
|
import copy
|
||||||
import warnings
|
import warnings
|
||||||
import time
|
import time
|
||||||
|
sys.path.insert(0, '.')
|
||||||
from ldm.dream.devices import choose_torch_device
|
from ldm.dream.devices import choose_torch_device
|
||||||
import ldm.dream.readline
|
import ldm.dream.readline
|
||||||
from ldm.dream.pngwriter import PngWriter, PromptFormatter
|
from ldm.dream.pngwriter import PngWriter, PromptFormatter
|
||||||
@ -546,6 +547,18 @@ def create_cmd_parser():
|
|||||||
action='store_true',
|
action='store_true',
|
||||||
help='shows how the prompt is split into tokens'
|
help='shows how the prompt is split into tokens'
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--threshold',
|
||||||
|
default=0.0,
|
||||||
|
type=float,
|
||||||
|
help='Add threshold value aka perform clipping.',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--perlin',
|
||||||
|
default=0.0,
|
||||||
|
type=float,
|
||||||
|
help='Add perlin noise.',
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user