Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
fix k_samplers in img2img - probably correct now
parent 82481a6f9c
commit d60df54f69
@@ -5,6 +5,12 @@ import torch.nn as nn
 from ldm.dream.devices import choose_torch_device
 from ldm.models.diffusion.sampler import Sampler
 from ldm.util import rand_perlin_2d
+from ldm.modules.diffusionmodules.util import (
+    make_ddim_sampling_parameters,
+    make_ddim_timesteps,
+    noise_like,
+    extract_into_tensor,
+)

 def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
     if threshold <= 0.0:
@@ -82,13 +88,53 @@ class KSampler(Sampler):
         )
         self.model = outer_model
         self.ddim_num_steps = ddim_num_steps
-        sigmas = self.model.get_sigmas(ddim_num_steps)
-        self.sigmas = sigmas
+        # we don't need both of these sigmas, but storing them here to make
+        # comparison easier later on
+        self.model_sigmas  = self.model.get_sigmas(ddim_num_steps)
+        self.karras_sigmas = K.sampling.get_sigmas_karras(
+            n=ddim_num_steps,
+            sigma_min=self.model.sigmas[0].item(),
+            sigma_max=self.model.sigmas[-1].item(),
+            rho=7.,
+            device=self.device,
+        )

     # ALERT: We are completely overriding the sample() method in the base class, which
-    # means that inpainting will (probably?) not work correctly. To get this to work
-    # we need to be able to modify the inner loop of k_heun, k_lms, etc, as is done
-    # in an ugly way in the lstein/k-diffusion branch.
+    # means that inpainting will not work. To get this to work we need to be able to
+    # modify the inner loop of k_heun, k_lms, etc, as is done in an ugly way
+    # in the lstein/k-diffusion branch.
+
+    @torch.no_grad()
+    def decode(
+        self,
+        z_enc,
+        cond,
+        t_enc,
+        img_callback=None,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        use_original_steps=False,
+        init_latent = None,
+        mask = None,
+    ):
+        samples,_ = self.sample(
+            batch_size = 1,
+            S = t_enc,
+            x_T = z_enc,
+            shape = z_enc.shape[1:],
+            conditioning = cond,
+            unconditional_guidance_scale=unconditional_guidance_scale,
+            unconditional_conditioning = unconditional_conditioning,
+            img_callback = img_callback,
+            x0 = init_latent,
+            mask = mask
+        )
+        return samples
+
+    # this is a no-op, provided here for compatibility with ddim and plms samplers
+    @torch.no_grad()
+    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+        return x0

     # Most of these arguments are ignored and are only present for compatibility with
     # other samples
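Note on the hunk above: make_schedule() now keeps two schedules side by side, the sigmas derived from the model's own alphas and a Karras-style schedule built with k-diffusion's get_sigmas_karras(). A minimal sketch of how the two are produced, assuming k-diffusion is importable as K and `model` is a loaded CompVis latent-diffusion model; the build_schedules helper name is illustrative and not part of this commit:

    import torch
    import k_diffusion as K

    def build_schedules(model, steps, device='cuda'):
        # CompVisDenoiser wraps an eps-prediction CompVis model and exposes the
        # discrete sigma schedule derived from the model's alphas (ascending .sigmas buffer).
        denoiser = K.external.CompVisDenoiser(model)
        model_sigmas = denoiser.get_sigmas(steps)   # steps+1 values, descending, ending at 0
        # Karras et al. (2022) schedule spanning the same sigma range.
        karras_sigmas = K.sampling.get_sigmas_karras(
            n=steps,
            sigma_min=denoiser.sigmas[0].item(),
            sigma_max=denoiser.sigmas[-1].item(),
            rho=7.0,
            device=device,
        )
        return model_sigmas, karras_sigmas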
@@ -124,17 +170,15 @@ class KSampler(Sampler):
             if img_callback is not None:
                 img_callback(k_callback_values['x'],k_callback_values['i'])

-        # sigmas = self.model.get_sigmas(S)
-        # sigmas are now set up in make_schedule - we take the last steps items
-        sigmas = self.sigmas[-S-1:]
+        # sigmas are set up in make_schedule - we take the last steps items
+        total_steps = len(self.karras_sigmas)
+        sigmas = self.karras_sigmas[-S-1:]

         if x_T is not None:
-            x = x_T * sigmas[0]
+            x = x_T + torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
         else:
-            x = (
-                torch.randn([batch_size, *shape], device=self.device)
-                * sigmas[0]
-            ) # for GPU draw
+            x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
+
         model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
         extra_args = {
             'cond': conditioning,
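Note on the hunk above: this is the core of the img2img fix. k-diffusion samplers expect the latent at noise level sigmas[0] to be signal plus unit-variance noise scaled by sigmas[0], not the signal multiplied by sigmas[0], and slicing sigmas[-S-1:] keeps only the last S noise levels (plus the trailing zero), which is how t_enc limits how far down the schedule sampling starts. A sketch of that starting-latent logic, using a hypothetical helper name rather than code from this commit:

    import torch

    def start_latent(init_latent, sigmas, shape, device):
        noise = torch.randn(shape, device=device)
        if init_latent is not None:
            # img2img: perturb the encoded init image with noise at the first
            # (largest) sigma of the truncated schedule
            return init_latent + noise * sigmas[0]
        # txt2img: start from pure Gaussian noise at sigmas[0]
        return noise * sigmas[0]

The resulting latent and truncated schedule can then be handed to any k-diffusion sampler, for example K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args).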
@@ -199,10 +243,12 @@ class KSampler(Sampler):
         return img, None, None

     def get_initial_image(self,x_T,shape,steps):
+        print(f'WARNING: ksampler.get_initial_image(): get_initial_image needs testing')
+        x = (torch.randn(shape, device=self.device) * self.sigmas[0])
         if x_T is not None:
             return x_T + x_T * self.sigmas[0]
         else:
-            return (torch.randn(shape, device=self.device) * self.sigmas[0])
+            return x
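Note on the hunk above: when x_T is given, get_initial_image() still scales the latent by sigma (x_T + x_T * self.sigmas[0]) rather than adding noise, and the new WARNING print flags that this path is untested. For comparison only, an additive-noise variant that would match the convention sample() now uses; this is an assumption about the intended behaviour, not code from this commit:

    import torch

    def initial_image(x_T, shape, sigma0, device):
        # fresh Gaussian noise at the starting noise level sigma0
        noise = torch.randn(shape, device=device) * sigma0
        # img2img: noised init latent; txt2img: pure noise
        return noise if x_T is None else x_T + noise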
@@ -216,29 +262,3 @@ class KSampler(Sampler):
         '''
         return self.model.inner_model.q_sample(x0,ts)

-    @torch.no_grad()
-    def decode(
-        self,
-        z_enc,
-        cond,
-        t_enc,
-        img_callback=None,
-        unconditional_guidance_scale=1.0,
-        unconditional_conditioning=None,
-        use_original_steps=False,
-        init_latent = None,
-        mask = None,
-    ):
-        samples,_ = self.sample(
-            batch_size = 1,
-            S = t_enc,
-            x_T = z_enc,
-            shape = z_enc.shape[1:],
-            conditioning = cond,
-            unconditional_guidance_scale=unconditional_guidance_scale,
-            unconditional_conditioning = unconditional_conditioning,
-            img_callback = img_callback,
-            x0 = init_latent,
-            mask = mask
-        )
-        return samples