Compare commits

1 Commit

SHA1        Message                                  Date
527c806f7b  feat(nodes): extract denoise function    2023-10-20 16:31:11 +11:00

@@ -2,7 +2,7 @@
 from contextlib import ExitStack
 from functools import singledispatchmethod
-from typing import List, Literal, Optional, Union
+from typing import Callable, List, Literal, Optional, Union
 import einops
 import numpy as np
@@ -651,8 +651,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
         return 1 - mask, masked_latents
-    @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
+        # Get the source node id (we are invoking the prepared node)
+        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
+        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
+
+        def step_callback(state: PipelineIntermediateState):
+            self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)
+
+        return self.denoise(context, step_callback)
+
+    @torch.no_grad()
+    def denoise(
+        self, context: InvocationContext, step_callback: Callable[[PipelineIntermediateState], None]
+    ) -> LatentsOutput:
         with SilenceWarnings():  # this quenches NSFW nag from diffusers
             seed = None
             noise = None
@@ -687,13 +699,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 do_classifier_free_guidance=True,
             )
-            # Get the source node id (we are invoking the prepared node)
-            graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
-            source_node_id = graph_execution_state.prepared_source_mapping[self.id]
-
-            def step_callback(state: PipelineIntermediateState):
-                self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)
-
             def _lora_loader():
                 for lora in self.unet.loras:
                     lora_info = context.services.model_manager.get_model(
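
The change is an extract-method refactor: the progress-callback wiring (graph lookup plus step_callback) moves up into invoke, and the work previously done under @torch.no_grad() is pulled out into a new denoise method that takes the callback as Callable[[PipelineIntermediateState], None]. Below is a minimal sketch of that pattern, assuming simplified stand-in names (IntermediateState, DenoiseNode) rather than the real InvokeAI types, so the split of responsibilities is easy to see in isolation:

from typing import Callable


class IntermediateState:
    """Stand-in for PipelineIntermediateState: per-step sampler progress."""

    def __init__(self, step: int, total_steps: int):
        self.step = step
        self.total_steps = total_steps


class DenoiseNode:
    """Stand-in for DenoiseLatentsInvocation, stripped of graph/session details."""

    def invoke(self) -> str:
        # Thin entry point: build the progress callback here (where the graph
        # lookups would live), then delegate the actual work.
        def step_callback(state: IntermediateState) -> None:
            print(f"progress: step {state.step}/{state.total_steps}")

        return self.denoise(step_callback)

    def denoise(self, step_callback: Callable[[IntermediateState], None]) -> str:
        # Extracted worker: knows nothing about the graph, just reports each step.
        total = 3
        for i in range(1, total + 1):
            step_callback(IntermediateState(i, total))
        return "latents"


print(DenoiseNode().invoke())  # three progress lines, then "latents"

With this split, the denoising path no longer needs a prepared graph node to run; any caller that can supply a step callback can reuse denoise directly.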