restore ability of ksamplers to process -v variation options

- supersedes #977
This commit is contained in:
Lincoln Stein 2022-10-07 14:43:59 -04:00
parent b296933ba0
commit 5157cbeda1
2 changed files with 11 additions and 4 deletions

View File

@ -49,6 +49,7 @@ class Img2Img(Generator):
img_callback = step_callback, img_callback = step_callback,
unconditional_guidance_scale=cfg_scale, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc, unconditional_conditioning=uc,
init_latent = self.init_latent, # changes how noising is performed in ksampler
) )
return self.sample_to_image(samples) return self.sample_to_image(samples)

View File

@ -97,6 +97,7 @@ class KSampler(Sampler):
rho=7., rho=7.,
device=self.device, device=self.device,
) )
self.sigmas = self.karras_sigmas
# ALERT: We are completely overriding the sample() method in the base class, which # ALERT: We are completely overriding the sample() method in the base class, which
# means that inpainting will not work. To get this to work we need to be able to # means that inpainting will not work. To get this to work we need to be able to
@ -170,11 +171,16 @@ class KSampler(Sampler):
img_callback(k_callback_values['x'],k_callback_values['i']) img_callback(k_callback_values['x'],k_callback_values['i'])
# sigmas are set up in make_schedule - we take the last steps items # sigmas are set up in make_schedule - we take the last steps items
total_steps = len(self.karras_sigmas) total_steps = len(self.sigmas)
sigmas = self.karras_sigmas[-S-1:] sigmas = self.sigmas[-S-1:]
# x_T is variation noise. When an init image is provided (in x0) we need to add
# more randomness to the starting image.
if x_T is not None: if x_T is not None:
x = x_T + torch.randn([batch_size, *shape], device=self.device) * sigmas[0] if x0 is not None:
x = x_T + torch.randn_like(x0, device=self.device) * sigmas[0]
else:
x = x_T * sigmas[0]
else: else:
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]