mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Split out the prepare_noise_and_latents(...) logic in DenoiseLatentsInvocation so that it can be called from other invocations.
This commit is contained in:
parent
43108eec13
commit
610a1fd611
@ -659,19 +659,21 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
return 1 - mask, masked_latents, self.denoise_mask.gradient
|
return 1 - mask, masked_latents, self.denoise_mask.gradient
|
||||||
|
|
||||||
@torch.no_grad()
|
@staticmethod
|
||||||
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
|
def prepare_noise_and_latents(
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
context: InvocationContext, noise_field: LatentsField | None, latents_field: LatentsField | None
|
||||||
|
) -> Tuple[float, torch.Tensor | None, torch.Tensor]:
|
||||||
seed = None
|
seed = None
|
||||||
noise = None
|
noise = None
|
||||||
if self.noise is not None:
|
|
||||||
noise = context.tensors.load(self.noise.latents_name)
|
|
||||||
seed = self.noise.seed
|
|
||||||
|
|
||||||
if self.latents is not None:
|
if noise_field is not None:
|
||||||
latents = context.tensors.load(self.latents.latents_name)
|
noise = context.tensors.load(noise_field.latents_name)
|
||||||
|
seed = noise_field.seed
|
||||||
|
|
||||||
|
if latents_field is not None:
|
||||||
|
latents = context.tensors.load(latents_field.latents_name)
|
||||||
if seed is None:
|
if seed is None:
|
||||||
seed = self.latents.seed
|
seed = latents_field.seed
|
||||||
|
|
||||||
if noise is not None and noise.shape[1:] != latents.shape[1:]:
|
if noise is not None and noise.shape[1:] != latents.shape[1:]:
|
||||||
raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
|
raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
|
||||||
@ -684,6 +686,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
if seed is None:
|
if seed is None:
|
||||||
seed = 0
|
seed = 0
|
||||||
|
|
||||||
|
return seed, noise, latents
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
|
||||||
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
|
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||||
|
|
||||||
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||||
|
|
||||||
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
||||||
|
Loading…
Reference in New Issue
Block a user