From a5bfe2dccb06d3d55e8f17ed05292db6dca4ce47 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 17 Apr 2024 19:56:07 +1000 Subject: [PATCH] feat(ui): support negative regional prompt --- .../util/graph/addRegionalPromptsToGraph.ts | 49 +++++++++---- .../features/nodes/util/graph/constants.ts | 3 +- .../components/LayerListItem.tsx | 6 +- .../RegionalPromptsNegativePrompt.tsx | 69 +++++++++++++++++++ ....tsx => RegionalPromptsPositivePrompt.tsx} | 12 ++-- .../regionalPrompts/hooks/layerStateHooks.ts | 18 ++++- .../store/regionalPromptsSlice.ts | 24 +++++-- .../regionalPrompts/util/getLayerBlobs.ts | 2 +- 8 files changed, 151 insertions(+), 32 deletions(-) create mode 100644 invokeai/frontend/web/src/features/regionalPrompts/components/RegionalPromptsNegativePrompt.tsx rename invokeai/frontend/web/src/features/regionalPrompts/components/{RegionalPromptsPrompt.tsx => RegionalPromptsPositivePrompt.tsx} (80%) diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addRegionalPromptsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addRegionalPromptsToGraph.ts index ebe2e74da5..aede682866 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addRegionalPromptsToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addRegionalPromptsToGraph.ts @@ -1,11 +1,13 @@ import { getStore } from 'app/store/nanostores/store'; import type { RootState } from 'app/store/store'; import { + NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING_COLLECT, POSITIVE_CONDITIONING, POSITIVE_CONDITIONING_COLLECT, - PROMPT_REGION_COND_PREFIX, PROMPT_REGION_MASK_PREFIX, + PROMPT_REGION_NEGATIVE_COND_PREFIX, + PROMPT_REGION_POSITIVE_COND_PREFIX, } from 'features/nodes/util/graph/constants'; import { getRegionalPromptLayerBlobs } from 'features/regionalPrompts/util/getLayerBlobs'; import { size } from 'lodash-es'; @@ -116,34 +118,55 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull // Create the conditioning nodes for each region - different handling for SDXL // TODO: negative prompt - const regionalCondNodeId = `${PROMPT_REGION_COND_PREFIX}_${layerId}`; + const regionalPositiveCondNodeId = `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layerId}`; + const regionalNegativeCondNodeId = `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layerId}`; if (isSDXL) { - graph.nodes[regionalCondNodeId] = { + graph.nodes[regionalPositiveCondNodeId] = { type: 'sdxl_compel_prompt', - id: regionalCondNodeId, - prompt: layer.prompt, + id: regionalPositiveCondNodeId, + prompt: layer.positivePrompt, + }; + graph.nodes[regionalNegativeCondNodeId] = { + type: 'sdxl_compel_prompt', + id: regionalNegativeCondNodeId, + prompt: layer.negativePrompt, }; } else { - graph.nodes[regionalCondNodeId] = { - type: 'compel', - id: regionalCondNodeId, - prompt: layer.prompt, - }; + // TODO: non sdxl + // graph.nodes[regionalCondNodeId] = { + // type: 'compel', + // id: regionalCondNodeId, + // prompt: layer.prompt, + // }; } graph.edges.push({ source: { node_id: id, field: 'mask' }, - destination: { node_id: regionalCondNodeId, field: 'mask' }, + destination: { node_id: regionalPositiveCondNodeId, field: 'mask' }, }); graph.edges.push({ - source: { node_id: regionalCondNodeId, field: 'conditioning' }, + source: { node_id: id, field: 'mask' }, + destination: { node_id: regionalNegativeCondNodeId, field: 'mask' }, + }); + graph.edges.push({ + source: { node_id: regionalPositiveCondNodeId, field: 'conditioning' }, destination: { node_id: posCondCollectNode.id, field: 'item' }, }); + graph.edges.push({ + source: { node_id: regionalNegativeCondNodeId, field: 'conditioning' }, + destination: { node_id: negCondCollectNode.id, field: 'item' }, + }); for (const edge of graph.edges) { if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') { graph.edges.push({ source: edge.source, - destination: { node_id: regionalCondNodeId, field: edge.destination.field }, + destination: { node_id: regionalPositiveCondNodeId, field: edge.destination.field }, + }); + } + if (edge.destination.node_id === NEGATIVE_CONDITIONING && edge.destination.field !== 'prompt') { + graph.edges.push({ + source: edge.source, + destination: { node_id: regionalNegativeCondNodeId, field: edge.destination.field }, }); } } diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts b/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts index adde745b4a..81952658c8 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts @@ -47,7 +47,8 @@ export const SDXL_REFINER_INPAINT_CREATE_MASK = 'refiner_inpaint_create_mask'; export const SEAMLESS = 'seamless'; export const SDXL_REFINER_SEAMLESS = 'refiner_seamless'; export const PROMPT_REGION_MASK_PREFIX = 'prompt_region_mask'; -export const PROMPT_REGION_COND_PREFIX = 'prompt_region_cond'; +export const PROMPT_REGION_POSITIVE_COND_PREFIX = 'prompt_region_positive_cond'; +export const PROMPT_REGION_NEGATIVE_COND_PREFIX = 'prompt_region_negative_cond'; export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect'; export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect'; diff --git a/invokeai/frontend/web/src/features/regionalPrompts/components/LayerListItem.tsx b/invokeai/frontend/web/src/features/regionalPrompts/components/LayerListItem.tsx index 39a7efd758..b41787ed9f 100644 --- a/invokeai/frontend/web/src/features/regionalPrompts/components/LayerListItem.tsx +++ b/invokeai/frontend/web/src/features/regionalPrompts/components/LayerListItem.tsx @@ -4,7 +4,8 @@ import { rgbColorToString } from 'features/canvas/util/colorToString'; import { LayerColorPicker } from 'features/regionalPrompts/components/LayerColorPicker'; import { LayerMenu } from 'features/regionalPrompts/components/LayerMenu'; import { LayerVisibilityToggle } from 'features/regionalPrompts/components/LayerVisibilityToggle'; -import { RegionalPromptsPrompt } from 'features/regionalPrompts/components/RegionalPromptsPrompt'; +import { RegionalPromptsNegativePrompt } from 'features/regionalPrompts/components/RegionalPromptsNegativePrompt'; +import { RegionalPromptsPositivePrompt } from 'features/regionalPrompts/components/RegionalPromptsPositivePrompt'; import { layerSelected } from 'features/regionalPrompts/store/regionalPromptsSlice'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; @@ -42,7 +43,8 @@ export const LayerListItem = memo(({ id }: Props) => { )} - + + ); diff --git a/invokeai/frontend/web/src/features/regionalPrompts/components/RegionalPromptsNegativePrompt.tsx b/invokeai/frontend/web/src/features/regionalPrompts/components/RegionalPromptsNegativePrompt.tsx new file mode 100644 index 0000000000..8f5f9f484b --- /dev/null +++ b/invokeai/frontend/web/src/features/regionalPrompts/components/RegionalPromptsNegativePrompt.tsx @@ -0,0 +1,69 @@ +import { Box, Textarea } from '@invoke-ai/ui-library'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper'; +import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton'; +import { PromptPopover } from 'features/prompt/PromptPopover'; +import { usePrompt } from 'features/prompt/usePrompt'; +import { useLayerNegativePrompt } from 'features/regionalPrompts/hooks/layerStateHooks'; +import { negativePromptChanged } from 'features/regionalPrompts/store/regionalPromptsSlice'; +import { memo, useCallback, useRef } from 'react'; +import type { HotkeyCallback } from 'react-hotkeys-hook'; +import { useHotkeys } from 'react-hotkeys-hook'; +import { useTranslation } from 'react-i18next'; + +type Props = { + layerId: string; +}; + +export const RegionalPromptsNegativePrompt = memo((props: Props) => { + const prompt = useLayerNegativePrompt(props.layerId); + const dispatch = useAppDispatch(); + const textareaRef = useRef(null); + const { t } = useTranslation(); + const _onChange = useCallback( + (v: string) => { + dispatch(negativePromptChanged({ layerId: props.layerId, prompt: v })); + }, + [dispatch, props.layerId] + ); + const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({ + prompt, + textareaRef, + onChange: _onChange, + }); + const focus: HotkeyCallback = useCallback( + (e) => { + onFocus(); + e.preventDefault(); + }, + [onFocus] + ); + + useHotkeys('alt+a', focus, []); + + return ( + + +