Use array slicing to calc ddim timesteps

This commit is contained in:
wfng92 2022-11-03 17:13:52 +08:00 committed by Lincoln Stein
parent 8648da8111
commit 1f0c5b4cf1

View File

@ -65,8 +65,10 @@ def make_ddim_timesteps(
if ddim_discr_method == 'uniform': if ddim_discr_method == 'uniform':
c = num_ddpm_timesteps // num_ddim_timesteps c = num_ddpm_timesteps // num_ddim_timesteps
if c < 1: if c < 1:
c = 1 c = 1
ddim_timesteps = (np.arange(0, num_ddim_timesteps) * c).astype(int)
# remove 1 final step to prevent index out of bound error
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))[:-1]
elif ddim_discr_method == 'quad': elif ddim_discr_method == 'quad':
ddim_timesteps = ( ddim_timesteps = (
( (
@ -84,7 +86,6 @@ def make_ddim_timesteps(
# assert ddim_timesteps.shape[0] == num_ddim_timesteps # assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling) # add one to get the final alpha values right (the ones from first scale to data during sampling)
steps_out = ddim_timesteps + 1 steps_out = ddim_timesteps + 1
# steps_out = ddim_timesteps
if verbose: if verbose:
print(f'Selected timesteps for ddim sampler: {steps_out}') print(f'Selected timesteps for ddim sampler: {steps_out}')