Provide a generator to every scheduler's step() function to make both ancestral and SDE schedulers reproducible

Sergey Borisov 2023-06-19 00:34:01 +03:00
parent f3d9797ebe
commit 9b32407744


@@ -222,15 +222,6 @@ class TextToLatentsInvocation(BaseInvocation):
         c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
         uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
 
-        custom_args = dict(
-            eta=0.0, #ddim_eta
-        )
-
-        if type(scheduler) is DPMSolverMultistepScheduler and scheduler.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
-            custom_args.update(
-                generator=torch.Generator(device=uc.device).manual_seed(0),
-            )
-
         conditioning_data = ConditioningData(
             unconditioned_embeddings=uc,
             text_embeddings=c,
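
The removed block above seeded a generator only when the scheduler was a DPMSolverMultistepScheduler running in one of its SDE modes; every other stochastic scheduler kept drawing unseeded noise inside step(). The second hunk instead passes eta and the seeded generator unconditionally and relies on add_scheduler_args_if_applicable to keep only the kwargs a given scheduler's step() actually accepts. A minimal sketch of such signature-based filtering, assuming a dataclass-style ConditioningData (the class name and field below are illustrative, not the repository's exact code):

import dataclasses
import inspect

@dataclasses.dataclass(frozen=True)
class ConditioningDataSketch:
    # kwargs that will later be forwarded to scheduler.step()
    scheduler_args: dict = dataclasses.field(default_factory=dict)

    def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
        accepted = dict(self.scheduler_args)
        step_params = inspect.signature(scheduler.step).parameters
        for name, value in kwargs.items():
            # keep only kwargs that step() declares, so eta reaches
            # DDIM-style signatures and generator reaches schedulers
            # that draw noise inside step()
            if name in step_params:
                accepted[name] = value
        return dataclasses.replace(self, scheduler_args=accepted)

With filtering like this, the call site can stay unconditional: unsupported kwargs are simply dropped instead of raising a TypeError inside the scheduler.
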
@@ -242,7 +233,17 @@ class TextToLatentsInvocation(BaseInvocation):
                 h_symmetry_time_pct=None,#h_symmetry_time_pct,
                 v_symmetry_time_pct=None#v_symmetry_time_pct,
             ),
-        ).add_scheduler_args_if_applicable(scheduler, **custom_args)
+        )
+
+        conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
+            scheduler,
+            # for ddim scheduler
+            eta=0.0, #ddim_eta
+            # for ancestral and sde schedulers
+            generator=torch.Generator(device=uc.device).manual_seed(0),
+        )
+
         return conditioning_data
 
     def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
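
Why the generator matters: ancestral and SDE schedulers sample fresh noise inside each step() call, so two otherwise identical invocations diverge unless that noise source is seeded. A standalone sketch of the behavior against diffusers' EulerAncestralDiscreteScheduler (scheduler choice, tensor shapes, and step count are arbitrary here; this is a demonstration, not code from the commit):

import torch
from diffusers import EulerAncestralDiscreteScheduler

def one_ancestral_step(seed: int) -> torch.Tensor:
    # fresh scheduler per call so no internal state leaks between runs
    scheduler = EulerAncestralDiscreteScheduler()
    scheduler.set_timesteps(10)
    generator = torch.Generator().manual_seed(seed)
    # fixed inputs make the generator the only source of randomness
    sample = torch.zeros(1, 4, 64, 64)
    model_output = torch.zeros(1, 4, 64, 64)
    return scheduler.step(
        model_output, scheduler.timesteps[0], sample, generator=generator
    ).prev_sample

assert torch.equal(one_ancestral_step(0), one_ancestral_step(0))      # same seed -> identical step
assert not torch.equal(one_ancestral_step(0), one_ancestral_step(1))  # new seed -> different noise

Note that the commit seeds this generator with a fixed value (0), so the in-step noise is always deterministic; run-to-run variety presumably still comes from the seed used for the initial noise latents.
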