diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index a537972c0b..bc7069c75b 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -2,7 +2,7 @@
 
 from contextlib import ExitStack
 from functools import singledispatchmethod
-from typing import List, Literal, Optional, Union
+from typing import Callable, List, Literal, Optional, Union
 
 import einops
 import numpy as np
@@ -651,8 +651,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         return 1 - mask, masked_latents
 
-    @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
+        # Get the source node id (we are invoking the prepared node)
+        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
+        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
+
+        def step_callback(state: PipelineIntermediateState):
+            self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)
+
+        return self.denoise(context, step_callback)
+
+    @torch.no_grad()
+    def denoise(
+        self, context: InvocationContext, step_callback: Callable[[PipelineIntermediateState], None]
+    ) -> LatentsOutput:
         with SilenceWarnings():  # this quenches NSFW nag from diffusers
             seed = None
             noise = None
@@ -687,13 +699,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 do_classifier_free_guidance=True,
             )
 
-            # Get the source node id (we are invoking the prepared node)
-            graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
-            source_node_id = graph_execution_state.prepared_source_mapping[self.id]
-
-            def step_callback(state: PipelineIntermediateState):
-                self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)
-
             def _lora_loader():
                 for lora in self.unet.loras:
                     lora_info = context.services.model_manager.get_model(
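
Note (outside the patch): the point of the split is that `denoise()` no longer depends on the graph execution state — it accepts any `Callable[[PipelineIntermediateState], None]`, while `invoke()` stays a thin adapter that builds the executor's progress callback. A minimal sketch of a caller under that assumption; `build_node` and `make_context` below are hypothetical stand-ins for however the node and `InvocationContext` get constructed:

```python
from typing import List

# Only the callback protocol is taken from the patch; `build_node` and
# `make_context` are hypothetical stand-ins for node/context construction.
from invokeai.app.invocations.latent import DenoiseLatentsInvocation

seen_steps: List[int] = []

def recording_callback(state) -> None:
    # PipelineIntermediateState carries per-step data; here we just record
    # the step index instead of dispatching progress events to the UI.
    seen_steps.append(state.step)

# node: DenoiseLatentsInvocation = build_node(...)   # hypothetical setup
# output = node.denoise(make_context(), recording_callback)
# assert seen_steps, "callback should fire once per denoising step"
```

Passing the callback in, rather than constructing it inside `denoise()`, is what lets callers such as a test harness or a batch runner supply their own progress handling.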