mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(ui): tidy regional prompts graph helper, add comments
This commit is contained in:
parent
9e1a4a4a07
commit
170763899a
@ -34,16 +34,19 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
|
|||||||
console.log('blobs', blobs, 'layerIds', layerIds);
|
console.log('blobs', blobs, 'layerIds', layerIds);
|
||||||
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
|
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
|
||||||
|
|
||||||
// Set up the conditioning collectors
|
// TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
|
||||||
|
// the existing conditioning nodes.
|
||||||
|
|
||||||
|
// With regional prompts we have multiple conditioning nodes which much be routed into collectors. Set those up
|
||||||
const posCondCollectNode: CollectInvocation = {
|
const posCondCollectNode: CollectInvocation = {
|
||||||
id: POSITIVE_CONDITIONING_COLLECT,
|
id: POSITIVE_CONDITIONING_COLLECT,
|
||||||
type: 'collect',
|
type: 'collect',
|
||||||
};
|
};
|
||||||
|
graph.nodes[POSITIVE_CONDITIONING_COLLECT] = posCondCollectNode;
|
||||||
const negCondCollectNode: CollectInvocation = {
|
const negCondCollectNode: CollectInvocation = {
|
||||||
id: NEGATIVE_CONDITIONING_COLLECT,
|
id: NEGATIVE_CONDITIONING_COLLECT,
|
||||||
type: 'collect',
|
type: 'collect',
|
||||||
};
|
};
|
||||||
graph.nodes[POSITIVE_CONDITIONING_COLLECT] = posCondCollectNode;
|
|
||||||
graph.nodes[NEGATIVE_CONDITIONING_COLLECT] = negCondCollectNode;
|
graph.nodes[NEGATIVE_CONDITIONING_COLLECT] = negCondCollectNode;
|
||||||
|
|
||||||
// Re-route the denoise node's OG conditioning inputs to the collect nodes
|
// Re-route the denoise node's OG conditioning inputs to the collect nodes
|
||||||
@ -94,10 +97,13 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
|
|||||||
});
|
});
|
||||||
|
|
||||||
// Remove the global prompt
|
// Remove the global prompt
|
||||||
// TODO: Append regional prompts to CLIP2's prompt?
|
// TODO: Append regional prompts to CLIP2's prompt? Dunno...
|
||||||
(graph.nodes[POSITIVE_CONDITIONING] as S['SDXLCompelPromptInvocation'] | S['CompelInvocation']).prompt = '';
|
(graph.nodes[POSITIVE_CONDITIONING] as S['SDXLCompelPromptInvocation'] | S['CompelInvocation']).prompt = '';
|
||||||
|
|
||||||
// Upload the blobs to the backend, add each to graph
|
// Upload the blobs to the backend, add each to graph
|
||||||
|
// TODO: Store the uploaded image names in redux to reuse them, so long as the layer hasn't otherwise changed. This
|
||||||
|
// would be a great perf win - not only would we skip re-uploading the same image, but we'd be able to use the node
|
||||||
|
// cache (currently, when we re-use the same mask data, since it is a different image, the node cache is not used).
|
||||||
for (const [layerId, blob] of Object.entries(blobs)) {
|
for (const [layerId, blob] of Object.entries(blobs)) {
|
||||||
const layer = layers.find((l) => l.id === layerId);
|
const layer = layers.find((l) => l.id === layerId);
|
||||||
assert(layer, `Layer ${layerId} not found`);
|
assert(layer, `Layer ${layerId} not found`);
|
||||||
@ -108,9 +114,10 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
|
|||||||
);
|
);
|
||||||
req.reset();
|
req.reset();
|
||||||
|
|
||||||
// TODO: this will raise an error
|
// TODO: This will raise on network error
|
||||||
const { image_name } = await req.unwrap();
|
const { image_name } = await req.unwrap();
|
||||||
|
|
||||||
|
// This mask (image primitive) will be fed into at least once mask-to-tensor node - two if we use the "invert" mode
|
||||||
const maskImageNode: S['ImageInvocation'] = {
|
const maskImageNode: S['ImageInvocation'] = {
|
||||||
id: `${PROMPT_REGION_MASK_IMAGE_PRIMITIVE_PREFIX}_${layerId}`,
|
id: `${PROMPT_REGION_MASK_IMAGE_PRIMITIVE_PREFIX}_${layerId}`,
|
||||||
type: 'image',
|
type: 'image',
|
||||||
@ -120,12 +127,14 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
|
|||||||
};
|
};
|
||||||
graph.nodes[maskImageNode.id] = maskImageNode;
|
graph.nodes[maskImageNode.id] = maskImageNode;
|
||||||
|
|
||||||
|
// The main mask-to-tensor node
|
||||||
const maskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
|
const maskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
|
||||||
id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${layerId}`,
|
id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${layerId}`,
|
||||||
type: 'alpha_mask_to_tensor',
|
type: 'alpha_mask_to_tensor',
|
||||||
};
|
};
|
||||||
graph.nodes[maskToTensorNode.id] = maskToTensorNode;
|
graph.nodes[maskToTensorNode.id] = maskToTensorNode;
|
||||||
|
|
||||||
|
// Connect the mask image to the mask-to-tensor node
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: {
|
source: {
|
||||||
node_id: maskImageNode.id,
|
node_id: maskImageNode.id,
|
||||||
@ -137,38 +146,49 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
// Create the conditioning nodes for each region - different handling for SDXL
|
// The main positive conditioning node
|
||||||
|
|
||||||
const regionalPositiveCondNode: S['SDXLCompelPromptInvocation'] = {
|
const regionalPositiveCondNode: S['SDXLCompelPromptInvocation'] = {
|
||||||
type: 'sdxl_compel_prompt',
|
type: 'sdxl_compel_prompt',
|
||||||
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layerId}`,
|
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layerId}`,
|
||||||
prompt: layer.positivePrompt,
|
prompt: layer.positivePrompt,
|
||||||
style: layer.positivePrompt,
|
style: layer.positivePrompt, // TODO: Should we put the positive prompt in both fields?
|
||||||
};
|
};
|
||||||
|
graph.nodes[regionalPositiveCondNode.id] = regionalPositiveCondNode;
|
||||||
|
|
||||||
|
// Connect the mask to the conditioning
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: maskToTensorNode.id, field: 'mask' },
|
||||||
|
destination: { node_id: regionalPositiveCondNode.id, field: 'mask' },
|
||||||
|
});
|
||||||
|
|
||||||
|
// Connect the conditioning to the collector
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: regionalPositiveCondNode.id, field: 'conditioning' },
|
||||||
|
destination: { node_id: posCondCollectNode.id, field: 'item' },
|
||||||
|
});
|
||||||
|
|
||||||
|
// The main negative conditioning node
|
||||||
const regionalNegativeCondNode: S['SDXLCompelPromptInvocation'] = {
|
const regionalNegativeCondNode: S['SDXLCompelPromptInvocation'] = {
|
||||||
type: 'sdxl_compel_prompt',
|
type: 'sdxl_compel_prompt',
|
||||||
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layerId}`,
|
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layerId}`,
|
||||||
prompt: layer.negativePrompt,
|
prompt: layer.negativePrompt,
|
||||||
style: layer.negativePrompt,
|
style: layer.negativePrompt,
|
||||||
};
|
};
|
||||||
graph.nodes[regionalPositiveCondNode.id] = regionalPositiveCondNode;
|
|
||||||
graph.nodes[regionalNegativeCondNode.id] = regionalNegativeCondNode;
|
graph.nodes[regionalNegativeCondNode.id] = regionalNegativeCondNode;
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: maskToTensorNode.id, field: 'mask' },
|
// Connect the mask to the conditioning
|
||||||
destination: { node_id: regionalPositiveCondNode.id, field: 'mask' },
|
|
||||||
});
|
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: { node_id: maskToTensorNode.id, field: 'mask' },
|
source: { node_id: maskToTensorNode.id, field: 'mask' },
|
||||||
destination: { node_id: regionalNegativeCondNode.id, field: 'mask' },
|
destination: { node_id: regionalNegativeCondNode.id, field: 'mask' },
|
||||||
});
|
});
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: regionalPositiveCondNode.id, field: 'conditioning' },
|
// Connect the conditioning to the collector
|
||||||
destination: { node_id: posCondCollectNode.id, field: 'item' },
|
|
||||||
});
|
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: { node_id: regionalNegativeCondNode.id, field: 'conditioning' },
|
source: { node_id: regionalNegativeCondNode.id, field: 'conditioning' },
|
||||||
destination: { node_id: negCondCollectNode.id, field: 'item' },
|
destination: { node_id: negCondCollectNode.id, field: 'item' },
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Copy the connections to the "global" positive and negative conditioning nodes to our regional nodes
|
||||||
for (const edge of graph.edges) {
|
for (const edge of graph.edges) {
|
||||||
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
|
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
@ -184,14 +204,18 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node
|
||||||
if (autoNegative === 'invert') {
|
if (autoNegative === 'invert') {
|
||||||
// Add an additional negative conditioning node with the positive prompt & inverted region mask
|
// We re-use the mask image, but invert it when converting to tensor
|
||||||
|
// TODO: Probably faster to invert the tensor from the earlier mask rather than read the mask image and convert...
|
||||||
const invertedMaskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
|
const invertedMaskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
|
||||||
id: `${PROMPT_REGION_MASK_TO_TENSOR_INVERTED_PREFIX}_${layerId}`,
|
id: `${PROMPT_REGION_MASK_TO_TENSOR_INVERTED_PREFIX}_${layerId}`,
|
||||||
type: 'alpha_mask_to_tensor',
|
type: 'alpha_mask_to_tensor',
|
||||||
invert: true,
|
invert: true,
|
||||||
};
|
};
|
||||||
graph.nodes[invertedMaskToTensorNode.id] = invertedMaskToTensorNode;
|
graph.nodes[invertedMaskToTensorNode.id] = invertedMaskToTensorNode;
|
||||||
|
|
||||||
|
// Connect the OG mask image to the inverted mask-to-tensor node
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: {
|
source: {
|
||||||
node_id: maskImageNode.id,
|
node_id: maskImageNode.id,
|
||||||
@ -203,6 +227,8 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the
|
||||||
|
// positive prompt
|
||||||
const regionalPositiveCondInvertedNode: S['SDXLCompelPromptInvocation'] = {
|
const regionalPositiveCondInvertedNode: S['SDXLCompelPromptInvocation'] = {
|
||||||
type: 'sdxl_compel_prompt',
|
type: 'sdxl_compel_prompt',
|
||||||
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layerId}`,
|
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layerId}`,
|
||||||
@ -210,14 +236,17 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
|
|||||||
style: layer.positivePrompt,
|
style: layer.positivePrompt,
|
||||||
};
|
};
|
||||||
graph.nodes[regionalPositiveCondInvertedNode.id] = regionalPositiveCondInvertedNode;
|
graph.nodes[regionalPositiveCondInvertedNode.id] = regionalPositiveCondInvertedNode;
|
||||||
|
// Connect the inverted mask to the conditioning
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: { node_id: invertedMaskToTensorNode.id, field: 'mask' },
|
source: { node_id: invertedMaskToTensorNode.id, field: 'mask' },
|
||||||
destination: { node_id: regionalPositiveCondInvertedNode.id, field: 'mask' },
|
destination: { node_id: regionalPositiveCondInvertedNode.id, field: 'mask' },
|
||||||
});
|
});
|
||||||
|
// Connect the conditioning to the negative collector
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: { node_id: regionalPositiveCondInvertedNode.id, field: 'conditioning' },
|
source: { node_id: regionalPositiveCondInvertedNode.id, field: 'conditioning' },
|
||||||
destination: { node_id: negCondCollectNode.id, field: 'item' },
|
destination: { node_id: negCondCollectNode.id, field: 'item' },
|
||||||
});
|
});
|
||||||
|
// Copy the connections to the "global" positive conditioning node to our regional node
|
||||||
for (const edge of graph.edges) {
|
for (const edge of graph.edges) {
|
||||||
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
|
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
|
Loading…
x
Reference in New Issue
Block a user