cleanup: use dict update to actually update the scheduler keyword args

This commit is contained in:
blessedcoolant 2024-05-01 12:08:39 +05:30
parent 88ac3bc7f0
commit 1bdcbe3284

View File

@ -826,7 +826,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
denoising_start: float,
denoising_end: float,
seed: int,
) -> Tuple[int, List[int], int, Dict[str, Union[torch.Generator, float]]]:
) -> Tuple[int, List[int], int, Dict[str, Any]]:
assert isinstance(scheduler, ConfigMixin)
if scheduler.config.get("cpu_only", False):
scheduler.set_timesteps(steps, device="cpu")
@ -854,16 +854,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
num_inference_steps = len(timesteps) // scheduler.order
scheduler_step_kwargs: Dict[str, Union[torch.Generator, float]] = {}
scheduler_step_kwargs: Dict[str, Any] = {}
scheduler_step_signature = inspect.signature(scheduler.step)
print(scheduler_step_signature.parameters)
if "generator" in scheduler_step_signature.parameters:
# At some point, someone decided that schedulers that accept a generator should use the original seed with
# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
# reproducibility.
scheduler_step_kwargs = {"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)}
scheduler_step_kwargs.update({"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)})
if "eta" in scheduler_step_signature.parameters:
scheduler_step_kwargs = {"eta": 1.0}
scheduler_step_kwargs.update({"eta": 1.0})
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs