mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
* Updates for thresholding and perlin noise options, added warmup for thresholding.
This commit is contained in:
parent
1a4bed2e55
commit
0891910cac
@ -79,6 +79,8 @@ class DreamServer(BaseHTTPRequestHandler):
|
|||||||
upscale = [int(upscale_level),float(upscale_strength)] if upscale_level != '' else None
|
upscale = [int(upscale_level),float(upscale_strength)] if upscale_level != '' else None
|
||||||
progress_images = 'progress_images' in post_data
|
progress_images = 'progress_images' in post_data
|
||||||
seed = self.model.seed if int(post_data['seed']) == -1 else int(post_data['seed'])
|
seed = self.model.seed if int(post_data['seed']) == -1 else int(post_data['seed'])
|
||||||
|
threshold = float(post_data['threshold'])
|
||||||
|
perlin = float(post_data['perlin'])
|
||||||
|
|
||||||
self.canceled.clear()
|
self.canceled.clear()
|
||||||
print(f">> Request to generate with prompt: {prompt}")
|
print(f">> Request to generate with prompt: {prompt}")
|
||||||
@ -165,7 +167,9 @@ class DreamServer(BaseHTTPRequestHandler):
|
|||||||
upscale = upscale,
|
upscale = upscale,
|
||||||
sampler_name = sampler_name,
|
sampler_name = sampler_name,
|
||||||
step_callback=image_progress,
|
step_callback=image_progress,
|
||||||
image_callback=image_done)
|
image_callback=image_done,
|
||||||
|
threshold=threshold,
|
||||||
|
perlin=perlin)
|
||||||
else:
|
else:
|
||||||
# Decode initimg as base64 to temp file
|
# Decode initimg as base64 to temp file
|
||||||
with open("./img2img-tmp.png", "wb") as f:
|
with open("./img2img-tmp.png", "wb") as f:
|
||||||
@ -188,7 +192,9 @@ class DreamServer(BaseHTTPRequestHandler):
|
|||||||
gfpgan_strength=gfpgan_strength,
|
gfpgan_strength=gfpgan_strength,
|
||||||
upscale = upscale,
|
upscale = upscale,
|
||||||
step_callback=image_progress,
|
step_callback=image_progress,
|
||||||
image_callback=image_done)
|
image_callback=image_done,
|
||||||
|
threshold=threshold,
|
||||||
|
perlin=perlin)
|
||||||
finally:
|
finally:
|
||||||
# Remove the temp file
|
# Remove the temp file
|
||||||
os.remove("./img2img-tmp.png")
|
os.remove("./img2img-tmp.png")
|
||||||
|
@ -3,9 +3,9 @@ import k_diffusion as K
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from ldm.dream.devices import choose_torch_device
|
from ldm.dream.devices import choose_torch_device
|
||||||
from ldm.modules.diffusionmodules.util import rand_perlin_2d
|
from ldm.util import rand_perlin_2d
|
||||||
|
|
||||||
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.707):
|
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
|
||||||
if threshold <= 0.0:
|
if threshold <= 0.0:
|
||||||
return result
|
return result
|
||||||
maxval = 0.0 + torch.max(result).cpu().numpy()
|
maxval = 0.0 + torch.max(result).cpu().numpy()
|
||||||
@ -20,17 +20,26 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.707):
|
|||||||
|
|
||||||
|
|
||||||
class CFGDenoiser(nn.Module):
|
class CFGDenoiser(nn.Module):
|
||||||
def __init__(self, model, threshold = 0):
|
def __init__(self, model, threshold = 0, warmup = 0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.inner_model = model
|
self.inner_model = model
|
||||||
self.threshold = threshold
|
self.threshold = threshold
|
||||||
|
self.warmup_max = warmup
|
||||||
|
self.warmup = 0
|
||||||
|
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||||
x_in = torch.cat([x] * 2)
|
x_in = torch.cat([x] * 2)
|
||||||
sigma_in = torch.cat([sigma] * 2)
|
sigma_in = torch.cat([sigma] * 2)
|
||||||
cond_in = torch.cat([uncond, cond])
|
cond_in = torch.cat([uncond, cond])
|
||||||
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
||||||
return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, self.threshold)
|
if self.warmup < self.warmup_max:
|
||||||
|
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
|
||||||
|
self.warmup += 1
|
||||||
|
else:
|
||||||
|
thresh = self.threshold
|
||||||
|
if thresh > self.threshold:
|
||||||
|
thresh = self.threshold
|
||||||
|
return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, thresh)
|
||||||
|
|
||||||
|
|
||||||
class KSampler(object):
|
class KSampler(object):
|
||||||
@ -39,7 +48,6 @@ class KSampler(object):
|
|||||||
self.model = K.external.CompVisDenoiser(model)
|
self.model = K.external.CompVisDenoiser(model)
|
||||||
self.schedule = schedule
|
self.schedule = schedule
|
||||||
self.device = device or choose_torch_device()
|
self.device = device or choose_torch_device()
|
||||||
#self.threshold = threshold or 0
|
|
||||||
|
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||||
x_in = torch.cat([x] * 2)
|
x_in = torch.cat([x] * 2)
|
||||||
@ -49,7 +57,6 @@ class KSampler(object):
|
|||||||
x_in, sigma_in, cond=cond_in
|
x_in, sigma_in, cond=cond_in
|
||||||
).chunk(2)
|
).chunk(2)
|
||||||
return uncond + (cond - uncond) * cond_scale
|
return uncond + (cond - uncond) * cond_scale
|
||||||
#return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, self.threshold)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -95,12 +102,7 @@ class KSampler(object):
|
|||||||
torch.randn([batch_size, *shape], device=self.device)
|
torch.randn([batch_size, *shape], device=self.device)
|
||||||
* sigmas[0]
|
* sigmas[0]
|
||||||
) # for GPU draw
|
) # for GPU draw
|
||||||
|
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
|
||||||
if perlin > 0.0:
|
|
||||||
print(shape)
|
|
||||||
x = (1 - perlin / 2) * x + perlin * rand_perlin_2d((shape[1], shape[2]), (8, 8)).to(self.device)
|
|
||||||
|
|
||||||
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold)
|
|
||||||
extra_args = {
|
extra_args = {
|
||||||
'cond': conditioning,
|
'cond': conditioning,
|
||||||
'uncond': unconditional_conditioning,
|
'uncond': unconditional_conditioning,
|
||||||
|
@ -18,25 +18,6 @@ from einops import repeat
|
|||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
|
|
||||||
def rand_perlin_2d(shape, res, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3):
|
|
||||||
delta = (res[0] / shape[0], res[1] / shape[1])
|
|
||||||
d = (shape[0] // res[0], shape[1] // res[1])
|
|
||||||
|
|
||||||
grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim = -1) % 1
|
|
||||||
angles = 2*math.pi*torch.rand(res[0]+1, res[1]+1)
|
|
||||||
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1)
|
|
||||||
|
|
||||||
tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)
|
|
||||||
dot = lambda grad, shift: (torch.stack((grid[:shape[0],:shape[1],0] + shift[0], grid[:shape[0],:shape[1], 1] + shift[1] ), dim = -1) * grad[:shape[0], :shape[1]]).sum(dim = -1)
|
|
||||||
|
|
||||||
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
|
|
||||||
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
|
|
||||||
n01 = dot(tile_grads([0, -1],[1, None]), [0, -1])
|
|
||||||
n11 = dot(tile_grads([1, None], [1, None]), [-1,-1])
|
|
||||||
t = fade(grid[:shape[0], :shape[1]])
|
|
||||||
return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
|
|
||||||
|
|
||||||
|
|
||||||
def make_beta_schedule(
|
def make_beta_schedule(
|
||||||
schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
|
schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
|
||||||
):
|
):
|
||||||
|
@ -9,6 +9,7 @@ import numpy as np
|
|||||||
import random
|
import random
|
||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
|
from ldm.modules.diffusionmodules.util import noise_like
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
@ -23,7 +24,7 @@ import time
|
|||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config, rand_perlin_2d
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from ldm.models.diffusion.plms import PLMSSampler
|
from ldm.models.diffusion.plms import PLMSSampler
|
||||||
from ldm.models.diffusion.ksampler import KSampler
|
from ldm.models.diffusion.ksampler import KSampler
|
||||||
@ -352,13 +353,18 @@ class T2I:
|
|||||||
|
|
||||||
def get_noise():
|
def get_noise():
|
||||||
if init_img:
|
if init_img:
|
||||||
return torch.randn_like(init_latent, device=self.device)
|
x = torch.randn_like(init_latent, device=self.device)
|
||||||
else:
|
else:
|
||||||
return torch.randn([1,
|
x = torch.randn([1,
|
||||||
self.latent_channels,
|
self.latent_channels,
|
||||||
height // self.downsampling_factor,
|
height // self.downsampling_factor,
|
||||||
width // self.downsampling_factor],
|
width // self.downsampling_factor],
|
||||||
device=self.device)
|
device=self.device)
|
||||||
|
if perlin > 0.0:
|
||||||
|
shape = x.shape
|
||||||
|
perlin_noise = torch.stack([rand_perlin_2d((shape[2], shape[3]), (8, 8)).to(self.device) for _ in range(shape[1])], dim=0)
|
||||||
|
x = (1 - perlin) * x + perlin * perlin_noise
|
||||||
|
return x
|
||||||
|
|
||||||
initial_noise = None
|
initial_noise = None
|
||||||
if variation_amount > 0 or len(with_variations) > 0:
|
if variation_amount > 0 or len(with_variations) > 0:
|
||||||
@ -387,7 +393,7 @@ class T2I:
|
|||||||
x_T = initial_noise
|
x_T = initial_noise
|
||||||
else:
|
else:
|
||||||
seed_everything(seed)
|
seed_everything(seed)
|
||||||
# make_image will do the equivalent of get_noise itself
|
x_T = get_noise()
|
||||||
image = make_image(x_T)
|
image = make_image(x_T)
|
||||||
results.append([image, seed])
|
results.append([image, seed])
|
||||||
if image_callback is not None:
|
if image_callback is not None:
|
||||||
|
19
ldm/util.py
19
ldm/util.py
@ -2,6 +2,7 @@ import importlib
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import math
|
||||||
from collections import abc
|
from collections import abc
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@ -212,3 +213,21 @@ def parallel_data_prefetch(
|
|||||||
return out
|
return out
|
||||||
else:
|
else:
|
||||||
return gather_res
|
return gather_res
|
||||||
|
|
||||||
|
def rand_perlin_2d(shape, res, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3):
    """Generate one 2-D field of Perlin gradient noise.

    Args:
        shape: ``(height, width)`` of the returned tensor, in pixels.
        res: ``(cells_y, cells_x)`` number of lattice cells along each axis.
            A ``(cells_y + 1, cells_x + 1)`` grid of random unit gradient
            vectors is drawn with a single ``torch.rand`` call.
        fade: Element-wise easing function applied to the fractional
            in-cell coordinates before interpolation. It receives a tensor
            of shape ``(height, width, 2)``.

    Returns:
        A CPU tensor of shape ``(height, width)`` in torch's default dtype.
        Pixels that fall exactly on a lattice point are 0.

    Unlike the tiling-based formulation, ``shape`` does not need to be an
    exact multiple of ``res``: each pixel looks up its own lattice cell.
    """
    height, width = shape
    cells_y, cells_x = res
    dtype = torch.get_default_dtype()

    # Work in integers first: pixel i sits at lattice coordinate
    # i * cells / size. Integer floor/remainder gives an exact cell index
    # and an exact zero fraction on lattice points, which a float-step
    # arange followed by `% 1` does not guarantee.
    scaled_y = torch.arange(height) * cells_y
    scaled_x = torch.arange(width) * cells_x
    cell_y = torch.div(scaled_y, height, rounding_mode='floor')[:, None]  # (H, 1)
    cell_x = torch.div(scaled_x, width, rounding_mode='floor')[None, :]   # (1, W)
    frac_y = (torch.remainder(scaled_y, height).to(dtype) / height)[:, None]
    frac_x = (torch.remainder(scaled_x, width).to(dtype) / width)[None, :]

    # Random unit gradient at every lattice corner.
    angles = 2*math.pi*torch.rand(cells_y + 1, cells_x + 1)
    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1)

    def corner(dy, dx):
        # Dot product between the gradient at corner (cell + (dy, dx)) and
        # the offset vector from that corner to the pixel.
        grad = gradients[cell_y + dy, cell_x + dx]  # broadcasts to (H, W, 2)
        return (frac_y - dy) * grad[..., 0] + (frac_x - dx) * grad[..., 1]

    n00 = corner(0, 0)
    n10 = corner(1, 0)
    n01 = corner(0, 1)
    n11 = corner(1, 1)

    # Eased interpolation weights, one (y, x) pair per pixel.
    grid = torch.stack(
        (frac_y.expand(height, width), frac_x.expand(height, width)),
        dim = -1,
    )
    t = fade(grid)

    # Blend along y for both x-corners, then along x; sqrt(2) rescales the
    # 2-D noise amplitude.
    return math.sqrt(2) * torch.lerp(
        torch.lerp(n00, n10, t[..., 0]),
        torch.lerp(n01, n11, t[..., 0]),
        t[..., 1],
    )
|
@ -63,6 +63,10 @@
|
|||||||
<label title="Set to -1 for random seed" for="seed">Seed:</label>
|
<label title="Set to -1 for random seed" for="seed">Seed:</label>
|
||||||
<input value="-1" type="number" id="seed" name="seed">
|
<input value="-1" type="number" id="seed" name="seed">
|
||||||
<button type="button" id="reset-seed">↺</button>
|
<button type="button" id="reset-seed">↺</button>
|
||||||
|
<label title="Threshold" for="threshold">Threshold:</label>
|
||||||
|
<input value="0" type="number" id="threshold" name="threshold" step="any">
|
||||||
|
<label title="Perlin" for="perlin">Perlin:</label>
|
||||||
|
<input value="0" type="number" id="perlin" name="perlin" step="any">
|
||||||
<input type="checkbox" name="progress_images" id="progress_images">
|
<input type="checkbox" name="progress_images" id="progress_images">
|
||||||
<label for="progress_images">Display in-progress images (slows down generation):</label>
|
<label for="progress_images">Display in-progress images (slows down generation):</label>
|
||||||
<button type="button" id="reset-all">Reset to Defaults</button>
|
<button type="button" id="reset-all">Reset to Defaults</button>
|
||||||
|
Loading…
Reference in New Issue
Block a user