fix: Manually update eta(gamma) to 1.0 for TCDScheduler

seems to work best with invoke at 4 steps
This commit is contained in:
blessedcoolant 2024-05-01 01:20:53 +05:30
parent 38880cde5c
commit 2ddb82200c

View File

@ -3,7 +3,7 @@ import inspect
import math
from contextlib import ExitStack
from functools import singledispatchmethod
from typing import Any, Iterator, List, Literal, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union
import einops
import numpy as np
@ -521,9 +521,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
if is_sdxl:
return SDXLConditioningInfo(
embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids
), regions
return (
SDXLConditioningInfo(embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids),
regions,
)
return BasicConditioningInfo(embeds=text_embedding), regions
def get_conditioning_data(
@ -825,7 +826,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
denoising_start: float,
denoising_end: float,
seed: int,
) -> Tuple[int, List[int], int]:
) -> Tuple[int, List[int], int, Dict[str, Union[torch.Generator, float]]]:
assert isinstance(scheduler, ConfigMixin)
if scheduler.config.get("cpu_only", False):
scheduler.set_timesteps(steps, device="cpu")
@ -853,13 +854,16 @@ class DenoiseLatentsInvocation(BaseInvocation):
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
num_inference_steps = len(timesteps) // scheduler.order
scheduler_step_kwargs = {}
scheduler_step_kwargs: Dict[str, Union[torch.Generator, float]] = {}
scheduler_step_signature = inspect.signature(scheduler.step)
print(scheduler_step_signature.parameters)
if "generator" in scheduler_step_signature.parameters:
# At some point, someone decided that schedulers that accept a generator should use the original seed with
# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
# reproducibility.
scheduler_step_kwargs = {"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)}
if "eta" in scheduler_step_signature.parameters:
scheduler_step_kwargs = {"eta": 1.0}
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs