restore use of sampler.decode() in img2img

This commit is contained in:
Lincoln Stein 2022-10-01 15:50:05 -04:00
parent 958d7650dd
commit a0f4af087c
2 changed files with 37 additions and 11 deletions

View File

@ -40,18 +40,16 @@ class Img2Img(Generator):
torch.tensor([t_enc]).to(self.model.device), torch.tensor([t_enc]).to(self.model.device),
noise=x_T noise=x_T
) )
samples,_ = sampler.sample( # decode it
batch_size = 1, samples = sampler.decode(
S = t_enc, z_enc,
shape = z_enc.shape[1:], c,
x_T = z_enc, t_enc,
conditioning = c, img_callback = step_callback,
unconditional_guidance_scale = cfg_scale, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning = uc, unconditional_conditioning=uc,
eta = ddim_eta,
img_callback = step_callback,
verbose = False,
) )
return self.sample_to_image(samples) return self.sample_to_image(samples)
return make_image return make_image

View File

@ -119,6 +119,7 @@ class KSampler(Sampler):
'uncond': unconditional_conditioning, 'uncond': unconditional_conditioning,
'cond_scale': unconditional_guidance_scale, 'cond_scale': unconditional_guidance_scale,
} }
print(f'>> Sampling with k__{self.schedule}')
return ( return (
K.sampling.__dict__[f'sample_{self.schedule}']( K.sampling.__dict__[f'sample_{self.schedule}'](
model_wrap_cfg, x, sigmas, extra_args=extra_args, model_wrap_cfg, x, sigmas, extra_args=extra_args,
@ -190,3 +191,30 @@ class KSampler(Sampler):
Overrides parent method to return the q_sample of the inner model. Overrides parent method to return the q_sample of the inner model.
''' '''
return self.model.inner_model.q_sample(x0,ts) return self.model.inner_model.q_sample(x0,ts)
@torch.no_grad()
def decode(
self,
z_enc,
cond,
t_enc,
img_callback=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
init_latent = None,
mask = None,
):
samples,_ = self.sample(
batch_size = 1,
S = t_enc,
x_T = z_enc,
shape = z_enc.shape[1:],
conditioning = cond,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning = unconditional_conditioning,
img_callback = img_callback,
x0 = init_latent,
mask = mask
)
return samples