mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fixes for second order scheduler timesteps
This commit is contained in:
parent
94636ddb03
commit
6e0beb1ed4
@ -317,6 +317,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
return control_data
|
return control_data
|
||||||
|
|
||||||
# original idea by https://github.com/AmericanPresidentJimmyCarter
|
# original idea by https://github.com/AmericanPresidentJimmyCarter
|
||||||
|
# TODO: research more for second order schedulers timesteps
|
||||||
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
|
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
|
||||||
if scheduler.config.get("cpu_only", False):
|
if scheduler.config.get("cpu_only", False):
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
@ -329,8 +330,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
t_start_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_start)))
|
t_start_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_start)))
|
||||||
t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, timesteps)))
|
t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, timesteps)))
|
||||||
timesteps = timesteps[t_start_idx:]
|
timesteps = timesteps[t_start_idx:]
|
||||||
if scheduler.order == 2:
|
if scheduler.order == 2 and t_start_idx > 0:
|
||||||
# TODO: research for second order schedulers timesteps
|
|
||||||
timesteps = timesteps[1:]
|
timesteps = timesteps[1:]
|
||||||
|
|
||||||
# save start timestep to apply noise
|
# save start timestep to apply noise
|
||||||
@ -339,6 +339,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
# apply denoising_end
|
# apply denoising_end
|
||||||
t_end_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_end)))
|
t_end_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_end)))
|
||||||
t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, timesteps)))
|
t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, timesteps)))
|
||||||
|
if scheduler.order == 2 and t_end_idx > 0:
|
||||||
|
t_end_idx += 1
|
||||||
timesteps = timesteps[:t_end_idx]
|
timesteps = timesteps[:t_end_idx]
|
||||||
|
|
||||||
# calculate step count based on scheduler order
|
# calculate step count based on scheduler order
|
||||||
|
Loading…
Reference in New Issue
Block a user