fix: Manually update eta(gamma) to 1.0 for TCDScheduler

Seems to work best in Invoke at 4 steps.
blessedcoolant 2024-05-01 01:20:53 +05:30
parent 38880cde5c
commit 2ddb82200c

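For context, a rough out-of-repo sketch of what the commit message describes, using the diffusers API directly. The model id, LoRA id, prompt, and output path are illustrative only, and TCD normally requires the distilled TCD LoRA on top of the base checkpoint:

import torch
from diffusers import StableDiffusionXLPipeline, TCDScheduler

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights("h1t/TCD-SDXL-LoRA")  # TCD's distilled LoRA (illustrative repo id)
pipe.fuse_lora()

# 4 steps with gamma (exposed by diffusers as `eta`) pinned to 1.0, matching this commit.
image = pipe(
    "a photo of a red fox in the snow",
    num_inference_steps=4,
    guidance_scale=0.0,
    eta=1.0,
).images[0]
image.save("tcd_eta_1_4steps.png")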

@@ -3,7 +3,7 @@ import inspect
 import math
 from contextlib import ExitStack
 from functools import singledispatchmethod
-from typing import Any, Iterator, List, Literal, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union

 import einops
 import numpy as np
@@ -521,9 +521,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
         )

         if is_sdxl:
-            return SDXLConditioningInfo(
-                embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids
-            ), regions
+            return (
+                SDXLConditioningInfo(embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids),
+                regions,
+            )
         return BasicConditioningInfo(embeds=text_embedding), regions

     def get_conditioning_data(
@@ -825,7 +826,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         denoising_start: float,
         denoising_end: float,
         seed: int,
-    ) -> Tuple[int, List[int], int]:
+    ) -> Tuple[int, List[int], int, Dict[str, Union[torch.Generator, float]]]:
         assert isinstance(scheduler, ConfigMixin)
         if scheduler.config.get("cpu_only", False):
             scheduler.set_timesteps(steps, device="cpu")
@@ -853,13 +854,16 @@ class DenoiseLatentsInvocation(BaseInvocation):
         timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
         num_inference_steps = len(timesteps) // scheduler.order

-        scheduler_step_kwargs = {}
+        scheduler_step_kwargs: Dict[str, Union[torch.Generator, float]] = {}
         scheduler_step_signature = inspect.signature(scheduler.step)
+        print(scheduler_step_signature.parameters)
         if "generator" in scheduler_step_signature.parameters:
             # At some point, someone decided that schedulers that accept a generator should use the original seed with
             # all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
             # reproducibility.
             scheduler_step_kwargs = {"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)}
+        if "eta" in scheduler_step_signature.parameters:
+            scheduler_step_kwargs = {"eta": 1.0}

         return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
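For reference, a minimal sketch of the check the new branch relies on, assuming the current diffusers TCDScheduler API, whose step() exposes eta (the TCD paper's gamma):

import inspect

from diffusers import TCDScheduler

scheduler = TCDScheduler()
step_params = inspect.signature(scheduler.step).parameters

# The invocation code above builds scheduler_step_kwargs from checks like these.
print("eta" in step_params)        # expected True for TCDScheduler
print("generator" in step_params)  # expected True; drives the stochastic part of a TCD step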