From 170763899a3788cee6beb24e0469e6bbad7bcde8 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 18 Apr 2024 18:49:56 +1000 Subject: [PATCH] tidy(ui): tidy regional prompts graph helper, add comments --- .../util/graph/addRegionalPromptsToGraph.ts | 63 ++++++++++++++----- 1 file changed, 46 insertions(+), 17 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addRegionalPromptsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addRegionalPromptsToGraph.ts index 3b73bf5f33..f07bd44917 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addRegionalPromptsToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addRegionalPromptsToGraph.ts @@ -34,16 +34,19 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull console.log('blobs', blobs, 'layerIds', layerIds); assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs'); - // Set up the conditioning collectors + // TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing + // the existing conditioning nodes. + + // With regional prompts we have multiple conditioning nodes which much be routed into collectors. Set those up const posCondCollectNode: CollectInvocation = { id: POSITIVE_CONDITIONING_COLLECT, type: 'collect', }; + graph.nodes[POSITIVE_CONDITIONING_COLLECT] = posCondCollectNode; const negCondCollectNode: CollectInvocation = { id: NEGATIVE_CONDITIONING_COLLECT, type: 'collect', }; - graph.nodes[POSITIVE_CONDITIONING_COLLECT] = posCondCollectNode; graph.nodes[NEGATIVE_CONDITIONING_COLLECT] = negCondCollectNode; // Re-route the denoise node's OG conditioning inputs to the collect nodes @@ -94,10 +97,13 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull }); // Remove the global prompt - // TODO: Append regional prompts to CLIP2's prompt? + // TODO: Append regional prompts to CLIP2's prompt? Dunno... (graph.nodes[POSITIVE_CONDITIONING] as S['SDXLCompelPromptInvocation'] | S['CompelInvocation']).prompt = ''; // Upload the blobs to the backend, add each to graph + // TODO: Store the uploaded image names in redux to reuse them, so long as the layer hasn't otherwise changed. This + // would be a great perf win - not only would we skip re-uploading the same image, but we'd be able to use the node + // cache (currently, when we re-use the same mask data, since it is a different image, the node cache is not used). for (const [layerId, blob] of Object.entries(blobs)) { const layer = layers.find((l) => l.id === layerId); assert(layer, `Layer ${layerId} not found`); @@ -108,9 +114,10 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull ); req.reset(); - // TODO: this will raise an error + // TODO: This will raise on network error const { image_name } = await req.unwrap(); + // This mask (image primitive) will be fed into at least once mask-to-tensor node - two if we use the "invert" mode const maskImageNode: S['ImageInvocation'] = { id: `${PROMPT_REGION_MASK_IMAGE_PRIMITIVE_PREFIX}_${layerId}`, type: 'image', @@ -120,12 +127,14 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull }; graph.nodes[maskImageNode.id] = maskImageNode; + // The main mask-to-tensor node const maskToTensorNode: S['AlphaMaskToTensorInvocation'] = { id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${layerId}`, type: 'alpha_mask_to_tensor', }; graph.nodes[maskToTensorNode.id] = maskToTensorNode; + // Connect the mask image to the mask-to-tensor node graph.edges.push({ source: { node_id: maskImageNode.id, @@ -137,38 +146,49 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull }, }); - // Create the conditioning nodes for each region - different handling for SDXL - + // The main positive conditioning node const regionalPositiveCondNode: S['SDXLCompelPromptInvocation'] = { type: 'sdxl_compel_prompt', id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layerId}`, prompt: layer.positivePrompt, - style: layer.positivePrompt, + style: layer.positivePrompt, // TODO: Should we put the positive prompt in both fields? }; + graph.nodes[regionalPositiveCondNode.id] = regionalPositiveCondNode; + + // Connect the mask to the conditioning + graph.edges.push({ + source: { node_id: maskToTensorNode.id, field: 'mask' }, + destination: { node_id: regionalPositiveCondNode.id, field: 'mask' }, + }); + + // Connect the conditioning to the collector + graph.edges.push({ + source: { node_id: regionalPositiveCondNode.id, field: 'conditioning' }, + destination: { node_id: posCondCollectNode.id, field: 'item' }, + }); + + // The main negative conditioning node const regionalNegativeCondNode: S['SDXLCompelPromptInvocation'] = { type: 'sdxl_compel_prompt', id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layerId}`, prompt: layer.negativePrompt, style: layer.negativePrompt, }; - graph.nodes[regionalPositiveCondNode.id] = regionalPositiveCondNode; graph.nodes[regionalNegativeCondNode.id] = regionalNegativeCondNode; - graph.edges.push({ - source: { node_id: maskToTensorNode.id, field: 'mask' }, - destination: { node_id: regionalPositiveCondNode.id, field: 'mask' }, - }); + + // Connect the mask to the conditioning graph.edges.push({ source: { node_id: maskToTensorNode.id, field: 'mask' }, destination: { node_id: regionalNegativeCondNode.id, field: 'mask' }, }); - graph.edges.push({ - source: { node_id: regionalPositiveCondNode.id, field: 'conditioning' }, - destination: { node_id: posCondCollectNode.id, field: 'item' }, - }); + + // Connect the conditioning to the collector graph.edges.push({ source: { node_id: regionalNegativeCondNode.id, field: 'conditioning' }, destination: { node_id: negCondCollectNode.id, field: 'item' }, }); + + // Copy the connections to the "global" positive and negative conditioning nodes to our regional nodes for (const edge of graph.edges) { if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') { graph.edges.push({ @@ -184,14 +204,18 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull } } + // If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node if (autoNegative === 'invert') { - // Add an additional negative conditioning node with the positive prompt & inverted region mask + // We re-use the mask image, but invert it when converting to tensor + // TODO: Probably faster to invert the tensor from the earlier mask rather than read the mask image and convert... const invertedMaskToTensorNode: S['AlphaMaskToTensorInvocation'] = { id: `${PROMPT_REGION_MASK_TO_TENSOR_INVERTED_PREFIX}_${layerId}`, type: 'alpha_mask_to_tensor', invert: true, }; graph.nodes[invertedMaskToTensorNode.id] = invertedMaskToTensorNode; + + // Connect the OG mask image to the inverted mask-to-tensor node graph.edges.push({ source: { node_id: maskImageNode.id, @@ -203,6 +227,8 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull }, }); + // Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the + // positive prompt const regionalPositiveCondInvertedNode: S['SDXLCompelPromptInvocation'] = { type: 'sdxl_compel_prompt', id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layerId}`, @@ -210,14 +236,17 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull style: layer.positivePrompt, }; graph.nodes[regionalPositiveCondInvertedNode.id] = regionalPositiveCondInvertedNode; + // Connect the inverted mask to the conditioning graph.edges.push({ source: { node_id: invertedMaskToTensorNode.id, field: 'mask' }, destination: { node_id: regionalPositiveCondInvertedNode.id, field: 'mask' }, }); + // Connect the conditioning to the negative collector graph.edges.push({ source: { node_id: regionalPositiveCondInvertedNode.id, field: 'conditioning' }, destination: { node_id: negCondCollectNode.id, field: 'item' }, }); + // Copy the connections to the "global" positive conditioning node to our regional node for (const edge of graph.edges) { if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') { graph.edges.push({