feat(ui): support negative regional prompt

This commit is contained in:
psychedelicious
2024-04-17 19:56:07 +10:00
committed by Kent Keirsey
parent aa6bfc8645
commit a5bfe2dccb
8 changed files with 151 additions and 32 deletions

View File

@ -1,11 +1,13 @@
import { getStore } from 'app/store/nanostores/store';
import type { RootState } from 'app/store/store';
import {
NEGATIVE_CONDITIONING,
NEGATIVE_CONDITIONING_COLLECT,
POSITIVE_CONDITIONING,
POSITIVE_CONDITIONING_COLLECT,
PROMPT_REGION_COND_PREFIX,
PROMPT_REGION_MASK_PREFIX,
PROMPT_REGION_NEGATIVE_COND_PREFIX,
PROMPT_REGION_POSITIVE_COND_PREFIX,
} from 'features/nodes/util/graph/constants';
import { getRegionalPromptLayerBlobs } from 'features/regionalPrompts/util/getLayerBlobs';
import { size } from 'lodash-es';
@ -116,34 +118,55 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
// Create the conditioning nodes for each region - different handling for SDXL
// TODO: negative prompt
const regionalCondNodeId = `${PROMPT_REGION_COND_PREFIX}_${layerId}`;
const regionalPositiveCondNodeId = `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layerId}`;
const regionalNegativeCondNodeId = `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layerId}`;
if (isSDXL) {
graph.nodes[regionalCondNodeId] = {
graph.nodes[regionalPositiveCondNodeId] = {
type: 'sdxl_compel_prompt',
id: regionalCondNodeId,
prompt: layer.prompt,
id: regionalPositiveCondNodeId,
prompt: layer.positivePrompt,
};
graph.nodes[regionalNegativeCondNodeId] = {
type: 'sdxl_compel_prompt',
id: regionalNegativeCondNodeId,
prompt: layer.negativePrompt,
};
} else {
graph.nodes[regionalCondNodeId] = {
type: 'compel',
id: regionalCondNodeId,
prompt: layer.prompt,
};
// TODO: non sdxl
// graph.nodes[regionalCondNodeId] = {
// type: 'compel',
// id: regionalCondNodeId,
// prompt: layer.prompt,
// };
}
graph.edges.push({
source: { node_id: id, field: 'mask' },
destination: { node_id: regionalCondNodeId, field: 'mask' },
destination: { node_id: regionalPositiveCondNodeId, field: 'mask' },
});
graph.edges.push({
source: { node_id: regionalCondNodeId, field: 'conditioning' },
source: { node_id: id, field: 'mask' },
destination: { node_id: regionalNegativeCondNodeId, field: 'mask' },
});
graph.edges.push({
source: { node_id: regionalPositiveCondNodeId, field: 'conditioning' },
destination: { node_id: posCondCollectNode.id, field: 'item' },
});
graph.edges.push({
source: { node_id: regionalNegativeCondNodeId, field: 'conditioning' },
destination: { node_id: negCondCollectNode.id, field: 'item' },
});
for (const edge of graph.edges) {
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
graph.edges.push({
source: edge.source,
destination: { node_id: regionalCondNodeId, field: edge.destination.field },
destination: { node_id: regionalPositiveCondNodeId, field: edge.destination.field },
});
}
if (edge.destination.node_id === NEGATIVE_CONDITIONING && edge.destination.field !== 'prompt') {
graph.edges.push({
source: edge.source,
destination: { node_id: regionalNegativeCondNodeId, field: edge.destination.field },
});
}
}

View File

@ -47,7 +47,8 @@ export const SDXL_REFINER_INPAINT_CREATE_MASK = 'refiner_inpaint_create_mask';
export const SEAMLESS = 'seamless';
export const SDXL_REFINER_SEAMLESS = 'refiner_seamless';
export const PROMPT_REGION_MASK_PREFIX = 'prompt_region_mask';
export const PROMPT_REGION_COND_PREFIX = 'prompt_region_cond';
export const PROMPT_REGION_POSITIVE_COND_PREFIX = 'prompt_region_positive_cond';
export const PROMPT_REGION_NEGATIVE_COND_PREFIX = 'prompt_region_negative_cond';
export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect';
export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect';