mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): extract denoise function
This commit is contained in:
@ -2,7 +2,7 @@
|
||||
|
||||
from contextlib import ExitStack
|
||||
from functools import singledispatchmethod
|
||||
from typing import List, Literal, Optional, Union
|
||||
from typing import Callable, List, Literal, Optional, Union
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
@ -651,8 +651,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
return 1 - mask, masked_latents
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)
|
||||
|
||||
return self.denoise(context, step_callback)
|
||||
|
||||
@torch.no_grad()
|
||||
def denoise(
|
||||
self, context: InvocationContext, step_callback: Callable[[PipelineIntermediateState], None]
|
||||
) -> LatentsOutput:
|
||||
with SilenceWarnings(): # this quenches NSFW nag from diffusers
|
||||
seed = None
|
||||
noise = None
|
||||
@ -687,13 +699,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
do_classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
|
Reference in New Issue
Block a user