mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30)
c3b992db96

Commit log (all commits by Damian at mba <damian@frey.NOSPAMco.nz> unless noted otherwise):

  2022-10-15  9bb0b5d003  undo local_files_only stuff
  2022-10-15  eed93f5d30  Revert "Merge branch 'development-invoke' into fix-prompts"
                          (reverts 7c40892a9f, reversing changes made to e3f2dd62b0)
  2022-10-13  f06d6024e3  more efficiently handle multiple conditioning
  2022-10-13  5efdfcbcd9  Merge branch 'optional-disable-karras-schedule' into fix-prompts (merge of b9c0dc5 and ac08bb6)
  2022-10-13  ac08bb6fd2  append '*use_model_sigmas*' to prompt string to use model sigmas
  2022-10-13  70d8c05a3f  make karras scheduling switchable
                          (commit d60df54f69 replaced the model's own scheduling with karras scheduling;
                          this changed image generation, seemingly for the worse, so this commit wraps the change in a bool)
  2022-10-12  b9c0dc5f1a  add test of more complex conjunction
  2022-10-12  9ac0c15cc0  improve comments
  2022-10-12  ad33bce605  put back thresholding stuff
  2022-10-12  4852c698a3  notes on improving conjunction efficiency
  2022-10-12  a53bb1e5b6  optional weights support for Conjunction
  2022-10-12  fec79ab15e  fix blend error and log parsing output
  2022-10-12  1f751c2a03  fix broken euler sampler
  2022-10-12  02f8148d17  cleanup prompt parser
  2022-10-12  8028d49ae6  explicit conjunction, improve flattening logic
  2022-10-11  8a17108921  adapt multi-conditioning to also work with ddim
  2022-10-11  53802a8398  unconditioning is also fancy-prompt-syntaxable
  2022-10-11  7c40892a9f  Merge branch 'development-invoke' into fix-prompts (merge of e3f2dd6 and dbe0da4)
  2022-10-11  e3f2dd62b0  Merge remote-tracking branch 'upstream/development' into fix-prompts (merge of eef0e48 and 06f542e)
  2022-10-11  eef0e484c2  fix run-on paren-less attention, add some comments
  2022-10-11  fd29afdf0e  python 3.9 compatibility
  2022-10-11  26f7646eef  first pass connecting PromptParser to conditioning
  2022-10-11  ae53dff379  update frontend dist
  2022-10-11  9be4a59a2d  fix issues with correctness checking FlattenedPrompt
  2022-10-11  3be212323e  parsing nested seems to work pretty ok
  2022-10-11  acd73eb08c  wip introducing FlattenedPrompt class
  2022-10-11  71698d5c7c  recursive attention weighting seems to actually work
  2022-10-11  a4e1ec6b20  now apparently almost supported nested attention
  2022-10-11  da76fd1ddf  wip prompt parsing
  2022-10-10  dbe0da4572  Adding node-based invocation apps (Kyle Schouviller <kyle0654@hotmail.com>)
  2022-10-10  8f2a2ffc08  fix merge issues
  2022-10-10  73118dee2a  Merge remote-tracking branch 'upstream/development' into fix-prompts (merge of fd00844 and 12413b0)
  2022-10-10  fd00844135  wip prompt parsing
  2022-10-10  0be9363db9  better +/- attention parsing
  2022-10-10  5383f69187  prompt parser seems to work
  2022-10-09  591d098a33  supports weighting unconditioning, cross-attention with |
  2022-10-09  7a7220563a  i think cross attention might be working?
  2022-10-09  951ed391e7  weighted CFG denoiser working with a single item
  2022-10-09  ee532a0c28  wip probably doesn't work or compile
  2022-10-07  14654bcbd2  use tan() to calculate embedding weight for <1 attentions
  2022-10-07  1a8e76b31a  fix bad math.max reference
  2022-10-07  f697ff8968  respect http[s]x protocol when making socket.io middleware
  2022-10-07  41d3dd4eea  fractional weighting works, by blending with prompts excluding the word
  2022-10-07  087fb6dfb3  wip doing weights <1 by averaging with conditioning absent the lower-weighted fragment
  2022-10-07  3c49e3f3ec  notate CFGDenoiser, perhaps
  2022-10-06  d2bcf1bb52  hack blending syntax to test attention weighting more extensively
  2022-10-06  94904ef2cf  conditioning works, apparently
  2022-10-06  7c6663ddd7  attention weighting, definitely works in positive direction
  2022-10-04  5856d453a9  wip bubbling weights down
  2022-10-04  a2ed14fd9b  bring in changes from PC
305 lines · 10 KiB · Python
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
|
|
import k_diffusion as K
|
|
import torch
|
|
import torch.nn as nn
|
|
from ldm.invoke.devices import choose_torch_device
|
|
from ldm.models.diffusion.sampler import Sampler
|
|
from ldm.util import rand_perlin_2d
|
|
from ldm.modules.diffusionmodules.util import (
|
|
make_ddim_sampling_parameters,
|
|
make_ddim_timesteps,
|
|
noise_like,
|
|
extract_into_tensor,
|
|
)
|
|
|
|


def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
    if threshold <= 0.0:
        return result
    maxval = 0.0 + torch.max(result).cpu().numpy()
    minval = 0.0 + torch.min(result).cpu().numpy()
    if maxval < threshold and minval > -threshold:
        return result
    if maxval > threshold:
        maxval = min(max(1, scale*maxval), threshold)
    if minval < -threshold:
        minval = max(min(-1, scale*minval), -threshold)
    return torch.clamp(result, min=minval, max=maxval)
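
# Behaviour sketch (doctest-style, illustrative only): with the default scale of 0.7,
# the result is returned untouched while its extrema stay inside +/-threshold; once
# either side overshoots, the clamp bound is pulled back toward the threshold.
#
#   >>> t = torch.tensor([-4.0, 0.0, 4.0])
#   >>> cfg_apply_threshold(t, threshold=0.0)    # threshold disabled -> unchanged
#   tensor([-4.,  0.,  4.])
#   >>> cfg_apply_threshold(t, threshold=3.0)    # 0.7 * 4.0 = 2.8 -> clamp to [-2.8, 2.8]
#   tensor([-2.8000,  0.0000,  2.8000])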


class CFGDenoiser(nn.Module):
    def __init__(self, model, threshold = 0, warmup = 0):
        super().__init__()
        self.inner_model = model
        self.threshold = threshold
        self.warmup_max = warmup
        self.warmup = max(warmup / 10, 1)

    def forward(self, x, sigma, uncond, cond, cond_scale):
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        cond_in = torch.cat([uncond, cond])
        unconditioned_x, conditioned_x = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)

        if self.warmup < self.warmup_max:
            thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
            self.warmup += 1
        else:
            thresh = self.threshold
        if thresh > self.threshold:
            thresh = self.threshold

        # damian0815 thinking out loud notes:
        # b + (a - b)*scale
        # starting at the output that emerges from applying the negative prompt (by default ''),
        # (-> this is why the unconditioning feels like a hammer)
        # move toward the positive prompt by an amount controlled by cond_scale.
        return cfg_apply_threshold(unconditioned_x + (conditioned_x - unconditioned_x) * cond_scale, thresh)
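
# Worked example of the guidance formula above (illustrative numbers only): with
# unconditioned_x = 0.2, conditioned_x = 1.0 and cond_scale = 7.5 the combined value is
#   0.2 + (1.0 - 0.2) * 7.5 = 6.2
# i.e. the output is extrapolated well past the conditioned prediction, which is why
# large cond_scale values benefit from the soft clamp in cfg_apply_threshold().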


class ProgrammableCFGDenoiser(CFGDenoiser):
    def forward(self, x, sigma, uncond, cond, cond_scale):
        forward_lambda = lambda x, t, c: self.inner_model(x, t, cond=c)
        x_new = Sampler.apply_weighted_conditioning_list(x, sigma, forward_lambda, uncond, cond, cond_scale)

        if self.warmup < self.warmup_max:
            thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
            self.warmup += 1
        else:
            thresh = self.threshold
        if thresh > self.threshold:
            thresh = self.threshold
        return cfg_apply_threshold(x_new, threshold=thresh)


class KSampler(Sampler):
    def __init__(self, model, schedule='lms', device=None, **kwargs):
        denoiser = K.external.CompVisDenoiser(model)
        super().__init__(
            denoiser,
            schedule,
            steps=model.num_timesteps,
        )
        self.sigmas = None
        self.ds = None
        self.s_in = None

    def forward(self, x, sigma, uncond, cond, cond_scale):
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        cond_in = torch.cat([uncond, cond])
        uncond, cond = self.inner_model(
            x_in, sigma_in, cond=cond_in
        ).chunk(2)
        return uncond + (cond - uncond) * cond_scale

    def make_schedule(
            self,
            ddim_num_steps,
            ddim_discretize='uniform',
            ddim_eta=0.0,
            verbose=False,
    ):
        outer_model = self.model
        self.model = outer_model.inner_model
        super().make_schedule(
            ddim_num_steps,
            ddim_discretize='uniform',
            ddim_eta=0.0,
            verbose=False,
        )
        self.model = outer_model
        self.ddim_num_steps = ddim_num_steps
        # we don't need both of these sigmas, but storing them here to make
        # comparison easier later on
        self.model_sigmas = self.model.get_sigmas(ddim_num_steps)
        self.karras_sigmas = K.sampling.get_sigmas_karras(
            n=ddim_num_steps,
            sigma_min=self.model.sigmas[0].item(),
            sigma_max=self.model.sigmas[-1].item(),
            rho=7.,
            device=self.device,
        )
        self.sigmas = self.model_sigmas
        #self.sigmas = self.karras_sigmas
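
        # Schedule sketch (based on k-diffusion's documented behaviour; illustrative):
        # both helpers return a descending tensor of ddim_num_steps + 1 sigmas (e.g. 51
        # entries for 50 steps), running from the model's largest sigma down to its
        # smallest and ending in 0.0. get_sigmas_karras() spaces them with the rho=7
        # rule from Karras et al. (2022), while get_sigmas() follows the model's own
        # noise schedule. Only model_sigmas is assigned to self.sigmas above; the
        # karras schedule is kept for comparison.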

    # ALERT: We are completely overriding the sample() method in the base class, which
    # means that inpainting will not work. To get this to work we need to be able to
    # modify the inner loop of k_heun, k_lms, etc, as is done in an ugly way
    # in the lstein/k-diffusion branch.

    @torch.no_grad()
    def decode(
            self,
            z_enc,
            cond,
            t_enc,
            img_callback=None,
            unconditional_guidance_scale=1.0,
            unconditional_conditioning=None,
            use_original_steps=False,
            init_latent=None,
            mask=None,
    ):
        samples, _ = self.sample(
            batch_size=1,
            S=t_enc,
            x_T=z_enc,
            shape=z_enc.shape[1:],
            conditioning=cond,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,
            img_callback=img_callback,
            x0=init_latent,
            mask=mask,
        )
        return samples

    # this is a no-op, provided here for compatibility with ddim and plms samplers
    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        return x0
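
    # Rough img2img call sequence implied by decode()/stochastic_encode() above
    # (hypothetical caller; variable names are illustrative):
    #
    #   z_enc   = ksampler.stochastic_encode(init_latent, t_enc)   # no-op: returns init_latent
    #   samples = ksampler.decode(z_enc, cond, t_enc,
    #                             unconditional_guidance_scale=7.5,
    #                             unconditional_conditioning=uc,
    #                             init_latent=init_latent)
    #
    # Unlike the ddim/plms samplers, the extra noise is added inside sample()
    # (x_T + randn * sigmas[0]) rather than in stochastic_encode().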

    # Most of these arguments are ignored and are only present for compatibility with
    # other samplers
    @torch.no_grad()
    def sample(
        self,
        S,
        batch_size,
        shape,
        conditioning=None,
        callback=None,
        normals_sequence=None,
        img_callback=None,
        quantize_x0=False,
        eta=0.0,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        verbose=True,
        x_T=None,
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        threshold=0,
        perlin=0,
        # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
        **kwargs,
    ):
        def route_callback(k_callback_values):
            if img_callback is not None:
                img_callback(k_callback_values['x'], k_callback_values['i'])

        # if make_schedule() hasn't been called, we do it now
        if self.sigmas is None:
            self.make_schedule(
                ddim_num_steps=S,
                ddim_eta=eta,
                verbose=False,
            )

        # sigmas are set up in make_schedule - we take the last steps items
        sigmas = self.sigmas[-S-1:]

        # x_T is variation noise. When an init image is provided (in x0) we need to add
        # more randomness to the starting image.
        if x_T is not None:
            if x0 is not None:
                x = x_T + torch.randn_like(x0, device=self.device) * sigmas[0]
            else:
                x = x_T * sigmas[0]
        else:
            x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]

        model_wrap_cfg = ProgrammableCFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S, S-10))
        extra_args = {
            'cond': conditioning,
            'uncond': unconditional_conditioning,
            'cond_scale': unconditional_guidance_scale,
        }
        print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)')
        return (
            K.sampling.__dict__[f'sample_{self.schedule}'](
                model_wrap_cfg, x, sigmas, extra_args=extra_args,
                callback=route_callback
            ),
            None,
        )
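
    # Hypothetical txt2img usage (a sketch; names are illustrative and `c` / `uc` are
    # conditioning inputs in whatever format Sampler.apply_weighted_conditioning_list
    # expects):
    #
    #   sampler = KSampler(model, schedule='lms')
    #   samples, _ = sampler.sample(
    #       S=50, batch_size=1, shape=[4, 64, 64],
    #       conditioning=c,
    #       unconditional_guidance_scale=7.5,
    #       unconditional_conditioning=uc,
    #   )
    #
    # The schedule name selects K.sampling.sample_lms / sample_euler / sample_heun etc.
    # via the dispatch above.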

    # this code will support inpainting if and when ksampler API modified or
    # a workaround is found.
    @torch.no_grad()
    def p_sample(
            self,
            img,
            cond,
            ts,
            index,
            unconditional_guidance_scale=1.0,
            unconditional_conditioning=None,
            **kwargs,
    ):
        if self.model_wrap is None:
            self.model_wrap = CFGDenoiser(self.model)
        extra_args = {
            'cond': cond,
            'uncond': unconditional_conditioning,
            'cond_scale': unconditional_guidance_scale,
        }
        if self.s_in is None:
            self.s_in = img.new_ones([img.shape[0]])
        if self.ds is None:
            self.ds = []

        # terrible, confusing names here
        steps = self.ddim_num_steps
        t_enc = self.t_enc

        # sigmas is a full `steps` in length, but t_enc might
        # be less. We start in the middle of the sigma array
        # and work our way to the end after t_enc steps.
        # index starts at t_enc and works its way to zero,
        # so the actual formula for indexing into sigmas:
        # sigma_index = (steps-index)
        s_index = t_enc - index - 1
        img = K.sampling.__dict__[f'_{self.schedule}'](
            self.model_wrap,
            img,
            self.sigmas,
            s_index,
            s_in=self.s_in,
            ds=self.ds,
            extra_args=extra_args,
        )

        return img, None, None
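
    # Worked example of the index arithmetic above (illustrative): as `index` counts
    # down toward zero, s_index = t_enc - index - 1 counts up, so successive calls walk
    # forward through the sigma slice; e.g. with t_enc = 10, index = 9 gives s_index = 0
    # (the largest remaining sigma) and index = 0 gives s_index = 9 (the last one).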

    # REVIEW THIS METHOD: it has never been tested. In particular,
    # we should not be multiplying by self.sigmas[0] if we
    # are at an intermediate step in img2img. See similar in
    # sample() which does work.
    def get_initial_image(self, x_T, shape, steps):
        print(f'WARNING: ksampler.get_initial_image(): get_initial_image needs testing')
        x = (torch.randn(shape, device=self.device) * self.sigmas[0])
        if x_T is not None:
            return x_T + x
        else:
            return x

    def prepare_to_sample(self, t_enc):
        self.t_enc = t_enc
        self.model_wrap = None
        self.ds = None
        self.s_in = None

    def q_sample(self, x0, ts):
        '''
        Overrides parent method to return the q_sample of the inner model.
        '''
        return self.model.inner_model.q_sample(x0, ts)