mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
cleanup: use dict update to actually update the scheduler keyword args
This commit is contained in:
@ -826,7 +826,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
denoising_start: float,
|
denoising_start: float,
|
||||||
denoising_end: float,
|
denoising_end: float,
|
||||||
seed: int,
|
seed: int,
|
||||||
) -> Tuple[int, List[int], int, Dict[str, Union[torch.Generator, float]]]:
|
) -> Tuple[int, List[int], int, Dict[str, Any]]:
|
||||||
assert isinstance(scheduler, ConfigMixin)
|
assert isinstance(scheduler, ConfigMixin)
|
||||||
if scheduler.config.get("cpu_only", False):
|
if scheduler.config.get("cpu_only", False):
|
||||||
scheduler.set_timesteps(steps, device="cpu")
|
scheduler.set_timesteps(steps, device="cpu")
|
||||||
@ -854,16 +854,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
|
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
|
||||||
num_inference_steps = len(timesteps) // scheduler.order
|
num_inference_steps = len(timesteps) // scheduler.order
|
||||||
|
|
||||||
scheduler_step_kwargs: Dict[str, Union[torch.Generator, float]] = {}
|
scheduler_step_kwargs: Dict[str, Any] = {}
|
||||||
scheduler_step_signature = inspect.signature(scheduler.step)
|
scheduler_step_signature = inspect.signature(scheduler.step)
|
||||||
print(scheduler_step_signature.parameters)
|
|
||||||
if "generator" in scheduler_step_signature.parameters:
|
if "generator" in scheduler_step_signature.parameters:
|
||||||
# At some point, someone decided that schedulers that accept a generator should use the original seed with
|
# At some point, someone decided that schedulers that accept a generator should use the original seed with
|
||||||
# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
|
# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
|
||||||
# reproducibility.
|
# reproducibility.
|
||||||
scheduler_step_kwargs = {"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)}
|
scheduler_step_kwargs.update({"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)})
|
||||||
if "eta" in scheduler_step_signature.parameters:
|
if "eta" in scheduler_step_signature.parameters:
|
||||||
scheduler_step_kwargs = {"eta": 1.0}
|
scheduler_step_kwargs.update({"eta": 1.0})
|
||||||
|
|
||||||
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
|
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user