Fix wrong timestep selection in some cases(dpmpp_sde)

This commit is contained in:
Sergey Borisov 2023-08-30 03:40:59 +03:00
parent 8562dbaaa8
commit ca15b8b33e

View File

@ -367,36 +367,31 @@ class DenoiseLatentsInvocation(BaseInvocation):
# original idea by https://github.com/AmericanPresidentJimmyCarter # original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps # TODO: research more for second order schedulers timesteps
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end): def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
num_inference_steps = steps
if scheduler.config.get("cpu_only", False): if scheduler.config.get("cpu_only", False):
scheduler.set_timesteps(num_inference_steps, device="cpu") scheduler.set_timesteps(steps, device="cpu")
timesteps = scheduler.timesteps.to(device=device) timesteps = scheduler.timesteps.to(device=device)
else: else:
scheduler.set_timesteps(num_inference_steps, device=device) scheduler.set_timesteps(steps, device=device)
timesteps = scheduler.timesteps timesteps = scheduler.timesteps
# apply denoising_start # skip greater order timesteps
_timesteps = timesteps[:: scheduler.order]
# get start timestep index
t_start_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_start))) t_start_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_start)))
t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, timesteps))) t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps)))
timesteps = timesteps[t_start_idx:]
if scheduler.order == 2 and t_start_idx > 0:
timesteps = timesteps[1:]
# save start timestep to apply noise # get end timestep index
init_timestep = timesteps[:1]
# apply denoising_end
t_end_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_end))) t_end_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_end)))
t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, timesteps))) t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:])))
if scheduler.order == 2 and t_end_idx > 0:
t_end_idx += 1
timesteps = timesteps[:t_end_idx]
# calculate step count based on scheduler order # apply order to indexes
num_inference_steps = len(timesteps) t_start_idx *= scheduler.order
if scheduler.order == 2: t_end_idx *= scheduler.order
num_inference_steps += num_inference_steps % 2
num_inference_steps = num_inference_steps // 2 init_timestep = timesteps[t_start_idx : t_start_idx + 1]
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
num_inference_steps = len(timesteps) // scheduler.order
return num_inference_steps, timesteps, init_timestep return num_inference_steps, timesteps, init_timestep