Apply denoising_start/end according on timestep value

This commit is contained in:
Sergey Borisov 2023-08-12 03:19:49 +03:00
parent 8acd7eeca5
commit ce3675fc14

View File

@ -316,26 +316,36 @@ class DenoiseLatentsInvocation(BaseInvocation):
# MultiControlNetModel has been refactored out, just need list[ControlNetData] # MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data return control_data
# original idea by https://github.com/AmericanPresidentJimmyCarter
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end): def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
if scheduler.config.get("cpu_only", False): if scheduler.config.get("cpu_only", False):
device = torch.device("cpu") device = torch.device("cpu")
# apply denoising_start
num_inference_steps = steps num_inference_steps = steps
scheduler.set_timesteps(num_inference_steps, device=device) scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = scheduler.timesteps
t_start = int(round(denoising_start * num_inference_steps)) # apply denoising_start
timesteps = scheduler.timesteps[t_start * scheduler.order :] t_start_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_start)))
num_inference_steps = num_inference_steps - t_start t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, timesteps)))
timesteps = timesteps[t_start_idx:]
if scheduler.order == 2:
# TODO: research for second order schedulers timesteps
timesteps = timesteps[1:]
# save start timestep to apply noise
init_timestep = timesteps[:1] init_timestep = timesteps[:1]
# apply denoising_end # apply denoising_end
num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0) t_end_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_end)))
t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, timesteps)))
timesteps = timesteps[:t_end_idx]
skipped_final_steps = int(round((1 - denoising_end) * steps)) # calculate step count based on scheduler order
num_inference_steps = num_inference_steps - skipped_final_steps num_inference_steps = len(timesteps)
timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps] if scheduler.order == 2:
num_inference_steps += (num_inference_steps % 2)
num_inference_steps = num_inference_steps // 2
return num_inference_steps, timesteps, init_timestep return num_inference_steps, timesteps, init_timestep