mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Provide generator to all schedulers step function to make both ancestral and sde schedulers reproducible
This commit is contained in:
parent
f3d9797ebe
commit
9b32407744
@ -222,15 +222,6 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||||
|
|
||||||
custom_args = dict(
|
|
||||||
eta=0.0, #ddim_eta
|
|
||||||
)
|
|
||||||
|
|
||||||
if type(scheduler) is DPMSolverMultistepScheduler and scheduler.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
|
|
||||||
custom_args.update(
|
|
||||||
generator=torch.Generator(device=uc.device).manual_seed(0),
|
|
||||||
)
|
|
||||||
|
|
||||||
conditioning_data = ConditioningData(
|
conditioning_data = ConditioningData(
|
||||||
unconditioned_embeddings=uc,
|
unconditioned_embeddings=uc,
|
||||||
text_embeddings=c,
|
text_embeddings=c,
|
||||||
@ -242,7 +233,17 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
h_symmetry_time_pct=None,#h_symmetry_time_pct,
|
h_symmetry_time_pct=None,#h_symmetry_time_pct,
|
||||||
v_symmetry_time_pct=None#v_symmetry_time_pct,
|
v_symmetry_time_pct=None#v_symmetry_time_pct,
|
||||||
),
|
),
|
||||||
).add_scheduler_args_if_applicable(scheduler, **custom_args)
|
)
|
||||||
|
|
||||||
|
conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
|
||||||
|
scheduler,
|
||||||
|
|
||||||
|
# for ddim scheduler
|
||||||
|
eta=0.0, #ddim_eta
|
||||||
|
|
||||||
|
# for ancestral and sde schedulers
|
||||||
|
generator=torch.Generator(device=uc.device).manual_seed(0),
|
||||||
|
)
|
||||||
return conditioning_data
|
return conditioning_data
|
||||||
|
|
||||||
def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
|
def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user