* Updates for thresholding and perlin noise options, added warmup for thresholding.

This commit is contained in:
Peter Baylies 2022-09-05 21:40:05 -04:00
parent 1a4bed2e55
commit 0891910cac
6 changed files with 55 additions and 37 deletions

View File

@ -79,6 +79,8 @@ class DreamServer(BaseHTTPRequestHandler):
upscale = [int(upscale_level),float(upscale_strength)] if upscale_level != '' else None upscale = [int(upscale_level),float(upscale_strength)] if upscale_level != '' else None
progress_images = 'progress_images' in post_data progress_images = 'progress_images' in post_data
seed = self.model.seed if int(post_data['seed']) == -1 else int(post_data['seed']) seed = self.model.seed if int(post_data['seed']) == -1 else int(post_data['seed'])
threshold = float(post_data['threshold'])
perlin = float(post_data['perlin'])
self.canceled.clear() self.canceled.clear()
print(f">> Request to generate with prompt: {prompt}") print(f">> Request to generate with prompt: {prompt}")
@ -165,7 +167,9 @@ class DreamServer(BaseHTTPRequestHandler):
upscale = upscale, upscale = upscale,
sampler_name = sampler_name, sampler_name = sampler_name,
step_callback=image_progress, step_callback=image_progress,
image_callback=image_done) image_callback=image_done,
threshold=threshold,
perlin=perlin)
else: else:
# Decode initimg as base64 to temp file # Decode initimg as base64 to temp file
with open("./img2img-tmp.png", "wb") as f: with open("./img2img-tmp.png", "wb") as f:
@ -188,7 +192,9 @@ class DreamServer(BaseHTTPRequestHandler):
gfpgan_strength=gfpgan_strength, gfpgan_strength=gfpgan_strength,
upscale = upscale, upscale = upscale,
step_callback=image_progress, step_callback=image_progress,
image_callback=image_done) image_callback=image_done,
threshold=threshold,
perlin=perlin)
finally: finally:
# Remove the temp file # Remove the temp file
os.remove("./img2img-tmp.png") os.remove("./img2img-tmp.png")

View File

@ -3,9 +3,9 @@ import k_diffusion as K
import torch import torch
import torch.nn as nn import torch.nn as nn
from ldm.dream.devices import choose_torch_device from ldm.dream.devices import choose_torch_device
from ldm.modules.diffusionmodules.util import rand_perlin_2d from ldm.util import rand_perlin_2d
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.707): def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
if threshold <= 0.0: if threshold <= 0.0:
return result return result
maxval = 0.0 + torch.max(result).cpu().numpy() maxval = 0.0 + torch.max(result).cpu().numpy()
@ -20,17 +20,26 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.707):
class CFGDenoiser(nn.Module): class CFGDenoiser(nn.Module):
def __init__(self, model, threshold = 0): def __init__(self, model, threshold = 0, warmup = 0):
super().__init__() super().__init__()
self.inner_model = model self.inner_model = model
self.threshold = threshold self.threshold = threshold
self.warmup_max = warmup
self.warmup = 0
def forward(self, x, sigma, uncond, cond, cond_scale): def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2) x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2) sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond]) cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, self.threshold) if self.warmup < self.warmup_max:
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
self.warmup += 1
else:
thresh = self.threshold
if thresh > self.threshold:
thresh = self.threshold
return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, thresh)
class KSampler(object): class KSampler(object):
@ -39,7 +48,6 @@ class KSampler(object):
self.model = K.external.CompVisDenoiser(model) self.model = K.external.CompVisDenoiser(model)
self.schedule = schedule self.schedule = schedule
self.device = device or choose_torch_device() self.device = device or choose_torch_device()
#self.threshold = threshold or 0
def forward(self, x, sigma, uncond, cond, cond_scale): def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2) x_in = torch.cat([x] * 2)
@ -49,7 +57,6 @@ class KSampler(object):
x_in, sigma_in, cond=cond_in x_in, sigma_in, cond=cond_in
).chunk(2) ).chunk(2)
return uncond + (cond - uncond) * cond_scale return uncond + (cond - uncond) * cond_scale
#return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, self.threshold)
@ -95,12 +102,7 @@ class KSampler(object):
torch.randn([batch_size, *shape], device=self.device) torch.randn([batch_size, *shape], device=self.device)
* sigmas[0] * sigmas[0]
) # for GPU draw ) # for GPU draw
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
if perlin > 0.0:
print(shape)
x = (1 - perlin / 2) * x + perlin * rand_perlin_2d((shape[1], shape[2]), (8, 8)).to(self.device)
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold)
extra_args = { extra_args = {
'cond': conditioning, 'cond': conditioning,
'uncond': unconditional_conditioning, 'uncond': unconditional_conditioning,

View File

@ -18,25 +18,6 @@ from einops import repeat
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
def rand_perlin_2d(shape, res, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3):
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim = -1) % 1
angles = 2*math.pi*torch.rand(res[0]+1, res[1]+1)
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1)
tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)
dot = lambda grad, shift: (torch.stack((grid[:shape[0],:shape[1],0] + shift[0], grid[:shape[0],:shape[1], 1] + shift[1] ), dim = -1) * grad[:shape[0], :shape[1]]).sum(dim = -1)
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
n01 = dot(tile_grads([0, -1],[1, None]), [0, -1])
n11 = dot(tile_grads([1, None], [1, None]), [-1,-1])
t = fade(grid[:shape[0], :shape[1]])
return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
def make_beta_schedule( def make_beta_schedule(
schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
): ):

View File

@ -9,6 +9,7 @@ import numpy as np
import random import random
import os import os
import traceback import traceback
from ldm.modules.diffusionmodules.util import noise_like
from omegaconf import OmegaConf from omegaconf import OmegaConf
from PIL import Image from PIL import Image
from tqdm import tqdm, trange from tqdm import tqdm, trange
@ -23,7 +24,7 @@ import time
import re import re
import sys import sys
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config, rand_perlin_2d
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ksampler import KSampler from ldm.models.diffusion.ksampler import KSampler
@ -352,13 +353,18 @@ class T2I:
def get_noise(): def get_noise():
if init_img: if init_img:
return torch.randn_like(init_latent, device=self.device) x = torch.randn_like(init_latent, device=self.device)
else: else:
return torch.randn([1, x = torch.randn([1,
self.latent_channels, self.latent_channels,
height // self.downsampling_factor, height // self.downsampling_factor,
width // self.downsampling_factor], width // self.downsampling_factor],
device=self.device) device=self.device)
if perlin > 0.0:
shape = x.shape
perlin_noise = torch.stack([rand_perlin_2d((shape[2], shape[3]), (8, 8)).to(self.device) for _ in range(shape[1])], dim=0)
x = (1 - perlin) * x + perlin * perlin_noise
return x
initial_noise = None initial_noise = None
if variation_amount > 0 or len(with_variations) > 0: if variation_amount > 0 or len(with_variations) > 0:
@ -387,7 +393,7 @@ class T2I:
x_T = initial_noise x_T = initial_noise
else: else:
seed_everything(seed) seed_everything(seed)
# make_image will do the equivalent of get_noise itself x_T = get_noise()
image = make_image(x_T) image = make_image(x_T)
results.append([image, seed]) results.append([image, seed])
if image_callback is not None: if image_callback is not None:

View File

@ -2,6 +2,7 @@ import importlib
import torch import torch
import numpy as np import numpy as np
import math
from collections import abc from collections import abc
from einops import rearrange from einops import rearrange
from functools import partial from functools import partial
@ -212,3 +213,21 @@ def parallel_data_prefetch(
return out return out
else: else:
return gather_res return gather_res
def rand_perlin_2d(shape, res, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3):
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim = -1) % 1
angles = 2*math.pi*torch.rand(res[0]+1, res[1]+1)
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1)
tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)
dot = lambda grad, shift: (torch.stack((grid[:shape[0],:shape[1],0] + shift[0], grid[:shape[0],:shape[1], 1] + shift[1] ), dim = -1) * grad[:shape[0], :shape[1]]).sum(dim = -1)
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
n01 = dot(tile_grads([0, -1],[1, None]), [0, -1])
n11 = dot(tile_grads([1, None], [1, None]), [-1,-1])
t = fade(grid[:shape[0], :shape[1]])
return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])

View File

@ -63,6 +63,10 @@
<label title="Set to -1 for random seed" for="seed">Seed:</label> <label title="Set to -1 for random seed" for="seed">Seed:</label>
<input value="-1" type="number" id="seed" name="seed"> <input value="-1" type="number" id="seed" name="seed">
<button type="button" id="reset-seed">&olarr;</button> <button type="button" id="reset-seed">&olarr;</button>
<label title="Threshold" for="threshold">Threshold:</label>
<input value="0" type="number" id="threshold" name="threshold" step="any">
<label title="Perlin" for="perlin">Perlin:</label>
<input value="0" type="number" id="perlin" name="perlin" step="any">
<input type="checkbox" name="progress_images" id="progress_images"> <input type="checkbox" name="progress_images" id="progress_images">
<label for="progress_images">Display in-progress images (slows down generation):</label> <label for="progress_images">Display in-progress images (slows down generation):</label>
<button type="button" id="reset-all">Reset to Defaults</button> <button type="button" id="reset-all">Reset to Defaults</button>