From 0891910cac7f646fc089ccc138fd1f8d9c4b6483 Mon Sep 17 00:00:00 2001 From: Peter Baylies Date: Mon, 5 Sep 2022 21:40:05 -0400 Subject: [PATCH] * Updates for thresholding and perlin noise options, added warmup for thresholding. --- ldm/dream/server.py | 10 ++++++++-- ldm/models/diffusion/ksampler.py | 26 ++++++++++++++------------ ldm/modules/diffusionmodules/util.py | 19 ------------------- ldm/simplet2i.py | 14 ++++++++++---- ldm/util.py | 19 +++++++++++++++++++ static/dream_web/index.html | 4 ++++ 6 files changed, 55 insertions(+), 37 deletions(-) diff --git a/ldm/dream/server.py b/ldm/dream/server.py index f592457e4c..c28c38b7dc 100644 --- a/ldm/dream/server.py +++ b/ldm/dream/server.py @@ -79,6 +79,8 @@ class DreamServer(BaseHTTPRequestHandler): upscale = [int(upscale_level),float(upscale_strength)] if upscale_level != '' else None progress_images = 'progress_images' in post_data seed = self.model.seed if int(post_data['seed']) == -1 else int(post_data['seed']) + threshold = float(post_data['threshold']) + perlin = float(post_data['perlin']) self.canceled.clear() print(f">> Request to generate with prompt: {prompt}") @@ -165,7 +167,9 @@ class DreamServer(BaseHTTPRequestHandler): upscale = upscale, sampler_name = sampler_name, step_callback=image_progress, - image_callback=image_done) + image_callback=image_done, + threshold=threshold, + perlin=perlin) else: # Decode initimg as base64 to temp file with open("./img2img-tmp.png", "wb") as f: @@ -188,7 +192,9 @@ class DreamServer(BaseHTTPRequestHandler): gfpgan_strength=gfpgan_strength, upscale = upscale, step_callback=image_progress, - image_callback=image_done) + image_callback=image_done, + threshold=threshold, + perlin=perlin) finally: # Remove the temp file os.remove("./img2img-tmp.png") diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index d8f255f5a6..45e1c47ba3 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -3,9 +3,9 @@ import k_diffusion as K import torch import torch.nn as nn from ldm.dream.devices import choose_torch_device -from ldm.modules.diffusionmodules.util import rand_perlin_2d +from ldm.util import rand_perlin_2d -def cfg_apply_threshold(result, threshold = 0.0, scale = 0.707): +def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7): if threshold <= 0.0: return result maxval = 0.0 + torch.max(result).cpu().numpy() @@ -20,17 +20,26 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.707): class CFGDenoiser(nn.Module): - def __init__(self, model, threshold = 0): + def __init__(self, model, threshold = 0, warmup = 0): super().__init__() self.inner_model = model self.threshold = threshold + self.warmup_max = warmup + self.warmup = 0 def forward(self, x, sigma, uncond, cond, cond_scale): x_in = torch.cat([x] * 2) sigma_in = torch.cat([sigma] * 2) cond_in = torch.cat([uncond, cond]) uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) - return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, self.threshold) + if self.warmup < self.warmup_max: + thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max)) + self.warmup += 1 + else: + thresh = self.threshold + if thresh > self.threshold: + thresh = self.threshold + return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, thresh) class KSampler(object): @@ -39,7 +48,6 @@ class KSampler(object): self.model = K.external.CompVisDenoiser(model) self.schedule = schedule self.device = device or choose_torch_device() - #self.threshold = threshold or 0 def forward(self, x, sigma, uncond, cond, cond_scale): x_in = torch.cat([x] * 2) @@ -49,7 +57,6 @@ class KSampler(object): x_in, sigma_in, cond=cond_in ).chunk(2) return uncond + (cond - uncond) * cond_scale - #return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, self.threshold) @@ -95,12 +102,7 @@ class KSampler(object): torch.randn([batch_size, *shape], device=self.device) * sigmas[0] ) # for GPU draw - - if perlin > 0.0: - print(shape) - x = (1 - perlin / 2) * x + perlin * rand_perlin_2d((shape[1], shape[2]), (8, 8)).to(self.device) - - model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold) + model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10)) extra_args = { 'cond': conditioning, 'uncond': unconditional_conditioning, diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py index 2a92152145..197b42b2bc 100644 --- a/ldm/modules/diffusionmodules/util.py +++ b/ldm/modules/diffusionmodules/util.py @@ -18,25 +18,6 @@ from einops import repeat from ldm.util import instantiate_from_config -def rand_perlin_2d(shape, res, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3): - delta = (res[0] / shape[0], res[1] / shape[1]) - d = (shape[0] // res[0], shape[1] // res[1]) - - grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim = -1) % 1 - angles = 2*math.pi*torch.rand(res[0]+1, res[1]+1) - gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1) - - tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1) - dot = lambda grad, shift: (torch.stack((grid[:shape[0],:shape[1],0] + shift[0], grid[:shape[0],:shape[1], 1] + shift[1] ), dim = -1) * grad[:shape[0], :shape[1]]).sum(dim = -1) - - n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) - n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) - n01 = dot(tile_grads([0, -1],[1, None]), [0, -1]) - n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]) - t = fade(grid[:shape[0], :shape[1]]) - return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]) - - def make_beta_schedule( schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 ): diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index 5306aa4785..0eee80918d 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -9,6 +9,7 @@ import numpy as np import random import os import traceback +from ldm.modules.diffusionmodules.util import noise_like from omegaconf import OmegaConf from PIL import Image from tqdm import tqdm, trange @@ -23,7 +24,7 @@ import time import re import sys -from ldm.util import instantiate_from_config +from ldm.util import instantiate_from_config, rand_perlin_2d from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.ksampler import KSampler @@ -352,13 +353,18 @@ class T2I: def get_noise(): if init_img: - return torch.randn_like(init_latent, device=self.device) + x = torch.randn_like(init_latent, device=self.device) else: - return torch.randn([1, + x = torch.randn([1, self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor], device=self.device) + if perlin > 0.0: + shape = x.shape + perlin_noise = torch.stack([rand_perlin_2d((shape[2], shape[3]), (8, 8)).to(self.device) for _ in range(shape[1])], dim=0) + x = (1 - perlin) * x + perlin * perlin_noise + return x initial_noise = None if variation_amount > 0 or len(with_variations) > 0: @@ -387,7 +393,7 @@ class T2I: x_T = initial_noise else: seed_everything(seed) - # make_image will do the equivalent of get_noise itself + x_T = get_noise() image = make_image(x_T) results.append([image, seed]) if image_callback is not None: diff --git a/ldm/util.py b/ldm/util.py index d1379cae2b..459816cf30 100644 --- a/ldm/util.py +++ b/ldm/util.py @@ -2,6 +2,7 @@ import importlib import torch import numpy as np +import math from collections import abc from einops import rearrange from functools import partial @@ -212,3 +213,21 @@ def parallel_data_prefetch( return out else: return gather_res + +def rand_perlin_2d(shape, res, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3): + delta = (res[0] / shape[0], res[1] / shape[1]) + d = (shape[0] // res[0], shape[1] // res[1]) + + grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim = -1) % 1 + angles = 2*math.pi*torch.rand(res[0]+1, res[1]+1) + gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1) + + tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1) + dot = lambda grad, shift: (torch.stack((grid[:shape[0],:shape[1],0] + shift[0], grid[:shape[0],:shape[1], 1] + shift[1] ), dim = -1) * grad[:shape[0], :shape[1]]).sum(dim = -1) + + n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) + n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) + n01 = dot(tile_grads([0, -1],[1, None]), [0, -1]) + n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]) + t = fade(grid[:shape[0], :shape[1]]) + return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]) \ No newline at end of file diff --git a/static/dream_web/index.html b/static/dream_web/index.html index bf57afae3f..6e81fd9c83 100644 --- a/static/dream_web/index.html +++ b/static/dream_web/index.html @@ -63,6 +63,10 @@ + + + +