Compare commits

...

1 Commits

Author SHA1 Message Date
527c806f7b feat(nodes): extract denoise function 2023-10-20 16:31:11 +11:00

View File

@ -2,7 +2,7 @@
from contextlib import ExitStack
from functools import singledispatchmethod
from typing import List, Literal, Optional, Union
from typing import Callable, List, Literal, Optional, Union
import einops
import numpy as np
@ -651,8 +651,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
return 1 - mask, masked_latents
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)
return self.denoise(context, step_callback)
@torch.no_grad()
def denoise(
self, context: InvocationContext, step_callback: Callable[[PipelineIntermediateState], None]
) -> LatentsOutput:
with SilenceWarnings(): # this quenches NSFW nag from diffusers
seed = None
noise = None
@ -687,13 +699,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
do_classifier_free_guidance=True,
)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(