revert to original k* noise schedule

This commit is contained in:
Lincoln Stein 2022-10-04 17:48:16 -04:00 committed by Any-Winter-4079
parent feb405f19a
commit 7a3eae4572
2 changed files with 16 additions and 12 deletions

View File

@ -31,7 +31,7 @@ class Txt2Img(Generator):
if self.free_gpu_mem and self.model.model.device != self.model.device: if self.free_gpu_mem and self.model.model.device != self.model.device:
self.model.model.to(self.model.device) self.model.model.to(self.model.device)
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=True) sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
samples, _ = sampler.sample( samples, _ = sampler.sample(
batch_size = 1, batch_size = 1,

View File

@ -45,6 +45,7 @@ class KSampler(Sampler):
ddim_eta=0.0, ddim_eta=0.0,
verbose=False, verbose=False,
): ):
ddim_num_steps += 1
outer_model = self.model outer_model = self.model
self.model = outer_model.inner_model self.model = outer_model.inner_model
super().make_schedule( super().make_schedule(
@ -53,17 +54,19 @@ class KSampler(Sampler):
ddim_eta=0.0, ddim_eta=0.0,
verbose=False, verbose=False,
) )
self.model = outer_model self.model = outer_model
self.ddim_num_steps = ddim_num_steps self.ddim_num_steps = ddim_num_steps
sigmas = K.sampling.get_sigmas_karras( # not working quite right
n=ddim_num_steps, # sigmas = K.sampling.get_sigmas_karras(
sigma_min=self.model.sigmas[0].item(), # n=ddim_num_steps,
sigma_max=self.model.sigmas[-1].item(), # sigma_min=self.model.sigmas[0].item(),
rho=7., # sigma_max=self.model.sigmas[-1].item(),
device=self.device, # rho=7.,
# Birch-san recommends this, but it doesn't match the call signature in his branch of k-diffusion # device=self.device,
# concat_zero=False # # Birch-san recommends this, but it doesn't match the call signature in his branch of k-diffusion
) # # concat_zero=False
# )
sigmas = self.model.get_sigmas(ddim_num_steps)
self.sigmas = sigmas self.sigmas = sigmas
# ALERT: We are completely overriding the sample() method in the base class, which # ALERT: We are completely overriding the sample() method in the base class, which
@ -99,6 +102,7 @@ class KSampler(Sampler):
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs, **kwargs,
): ):
S += 1
def route_callback(k_callback_values): def route_callback(k_callback_values):
if img_callback is not None: if img_callback is not None:
img_callback(k_callback_values['x'],k_callback_values['i']) img_callback(k_callback_values['x'],k_callback_values['i'])
@ -119,7 +123,7 @@ class KSampler(Sampler):
'uncond': unconditional_conditioning, 'uncond': unconditional_conditioning,
'cond_scale': unconditional_guidance_scale, 'cond_scale': unconditional_guidance_scale,
} }
print(f'>> Sampling with k__{self.schedule}') print(f'>> Sampling with k_{self.schedule}')
return ( return (
K.sampling.__dict__[f'sample_{self.schedule}']( K.sampling.__dict__[f'sample_{self.schedule}'](
model_wrap_cfg, x, sigmas, extra_args=extra_args, model_wrap_cfg, x, sigmas, extra_args=extra_args,