From ca15b8b33e8b5f64416a79a2e5ac0b459dbf1c0e Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 30 Aug 2023 03:40:59 +0300 Subject: [PATCH] Fix wrong timestep selection in some cases(dpmpp_sde) --- invokeai/app/invocations/latent.py | 37 +++++++++++++----------------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 80988f3c71..9cca0bd744 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -367,36 +367,31 @@ class DenoiseLatentsInvocation(BaseInvocation): # original idea by https://github.com/AmericanPresidentJimmyCarter # TODO: research more for second order schedulers timesteps def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end): - num_inference_steps = steps if scheduler.config.get("cpu_only", False): - scheduler.set_timesteps(num_inference_steps, device="cpu") + scheduler.set_timesteps(steps, device="cpu") timesteps = scheduler.timesteps.to(device=device) else: - scheduler.set_timesteps(num_inference_steps, device=device) + scheduler.set_timesteps(steps, device=device) timesteps = scheduler.timesteps - # apply denoising_start + # skip greater order timesteps + _timesteps = timesteps[:: scheduler.order] + + # get start timestep index t_start_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_start))) - t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, timesteps))) - timesteps = timesteps[t_start_idx:] - if scheduler.order == 2 and t_start_idx > 0: - timesteps = timesteps[1:] + t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps))) - # save start timestep to apply noise - init_timestep = timesteps[:1] - - # apply denoising_end + # get end timestep index t_end_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_end))) - t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, timesteps))) - if scheduler.order == 2 and t_end_idx > 0: - t_end_idx += 1 - timesteps = timesteps[:t_end_idx] + t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:]))) - # calculate step count based on scheduler order - num_inference_steps = len(timesteps) - if scheduler.order == 2: - num_inference_steps += num_inference_steps % 2 - num_inference_steps = num_inference_steps // 2 + # apply order to indexes + t_start_idx *= scheduler.order + t_end_idx *= scheduler.order + + init_timestep = timesteps[t_start_idx : t_start_idx + 1] + timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx] + num_inference_steps = len(timesteps) // scheduler.order return num_inference_steps, timesteps, init_timestep