From 610a1fd61115429711e797258add50bf0b3d3977 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 6 Jun 2024 15:10:04 -0400 Subject: [PATCH] Split out the prepare_noise_and_latents(...) logic in DenoiseLatentsInvocation so that it can be called from other invocations. --- invokeai/app/invocations/denoise_latents.py | 27 ++++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 3825f57fea..2bdf7a55c6 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -659,19 +659,21 @@ class DenoiseLatentsInvocation(BaseInvocation): return 1 - mask, masked_latents, self.denoise_mask.gradient - @torch.no_grad() - @SilenceWarnings() # This quenches the NSFW nag from diffusers. - def invoke(self, context: InvocationContext) -> LatentsOutput: + @staticmethod + def prepare_noise_and_latents( + context: InvocationContext, noise_field: LatentsField | None, latents_field: LatentsField | None + ) -> Tuple[float, torch.Tensor | None, torch.Tensor]: seed = None noise = None - if self.noise is not None: - noise = context.tensors.load(self.noise.latents_name) - seed = self.noise.seed - if self.latents is not None: - latents = context.tensors.load(self.latents.latents_name) + if noise_field is not None: + noise = context.tensors.load(noise_field.latents_name) + seed = noise_field.seed + + if latents_field is not None: + latents = context.tensors.load(latents_field.latents_name) if seed is None: - seed = self.latents.seed + seed = latents_field.seed if noise is not None and noise.shape[1:] != latents.shape[1:]: raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}") @@ -684,6 +686,13 @@ class DenoiseLatentsInvocation(BaseInvocation): if seed is None: seed = 0 + return seed, noise, latents + + @torch.no_grad() + @SilenceWarnings() # This quenches the NSFW nag from diffusers. + def invoke(self, context: InvocationContext) -> LatentsOutput: + seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents) + mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents) # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,