fix broken image generation on plms and ddim samplers

This commit is contained in:
Lincoln Stein 2022-10-07 08:26:53 -04:00
parent c1230da3ab
commit 333219be35

View File

@ -39,6 +39,7 @@ class Sampler(object):
ddim_eta=0.0, ddim_eta=0.0,
verbose=False, verbose=False,
): ):
self.total_steps = ddim_num_steps
self.ddim_timesteps = make_ddim_timesteps( self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize, ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps, num_ddim_timesteps=ddim_num_steps,
@ -211,6 +212,7 @@ class Sampler(object):
if ddim_use_original_steps if ddim_use_original_steps
else np.flip(timesteps) else np.flip(timesteps)
) )
total_steps=steps total_steps=steps
iterator = tqdm( iterator = tqdm(
@ -305,7 +307,7 @@ class Sampler(object):
time_range = np.flip(timesteps) time_range = np.flip(timesteps)
total_steps = timesteps.shape[0] total_steps = timesteps.shape[0]
print(f'>> Running {self.__class__.__name__} Sampling with {total_steps} timesteps') print(f'>> Running {self.__class__.__name__} sampling starting at step {self.total_steps - t_start} of {self.total_steps} ({total_steps} new sampling steps)')
iterator = tqdm(time_range, desc='Decoding image', total=total_steps) iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent x_dec = x_latent
@ -351,11 +353,10 @@ class Sampler(object):
return x_dec return x_dec
def get_initial_image(self,x_T,shape,timesteps=None): def get_initial_image(self,x_T,shape,timesteps=None):
x = torch.randn(shape, device=self.device)
if x_T is None: if x_T is None:
return x return torch.randn(shape, device=self.device)
else: else:
return x_T + x return x_T
def p_sample( def p_sample(
self, self,