mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
* Updates for thresholding and perlin noise options, added warmup for thresholding.
This commit is contained in:
@ -3,9 +3,9 @@ import k_diffusion as K
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ldm.dream.devices import choose_torch_device
|
||||
from ldm.modules.diffusionmodules.util import rand_perlin_2d
|
||||
from ldm.util import rand_perlin_2d
|
||||
|
||||
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.707):
|
||||
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
|
||||
if threshold <= 0.0:
|
||||
return result
|
||||
maxval = 0.0 + torch.max(result).cpu().numpy()
|
||||
@ -20,17 +20,26 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.707):
|
||||
|
||||
|
||||
class CFGDenoiser(nn.Module):
|
||||
def __init__(self, model, threshold = 0):
|
||||
def __init__(self, model, threshold = 0, warmup = 0):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.threshold = threshold
|
||||
self.warmup_max = warmup
|
||||
self.warmup = 0
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
x_in = torch.cat([x] * 2)
|
||||
sigma_in = torch.cat([sigma] * 2)
|
||||
cond_in = torch.cat([uncond, cond])
|
||||
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
||||
return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, self.threshold)
|
||||
if self.warmup < self.warmup_max:
|
||||
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
|
||||
self.warmup += 1
|
||||
else:
|
||||
thresh = self.threshold
|
||||
if thresh > self.threshold:
|
||||
thresh = self.threshold
|
||||
return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, thresh)
|
||||
|
||||
|
||||
class KSampler(object):
|
||||
@ -39,7 +48,6 @@ class KSampler(object):
|
||||
self.model = K.external.CompVisDenoiser(model)
|
||||
self.schedule = schedule
|
||||
self.device = device or choose_torch_device()
|
||||
#self.threshold = threshold or 0
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
x_in = torch.cat([x] * 2)
|
||||
@ -49,7 +57,6 @@ class KSampler(object):
|
||||
x_in, sigma_in, cond=cond_in
|
||||
).chunk(2)
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
#return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, self.threshold)
|
||||
|
||||
|
||||
|
||||
@ -95,12 +102,7 @@ class KSampler(object):
|
||||
torch.randn([batch_size, *shape], device=self.device)
|
||||
* sigmas[0]
|
||||
) # for GPU draw
|
||||
|
||||
if perlin > 0.0:
|
||||
print(shape)
|
||||
x = (1 - perlin / 2) * x + perlin * rand_perlin_2d((shape[1], shape[2]), (8, 8)).to(self.device)
|
||||
|
||||
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold)
|
||||
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
|
||||
extra_args = {
|
||||
'cond': conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
|
Reference in New Issue
Block a user