From a5bfe2dccb06d3d55e8f17ed05292db6dca4ce47 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Wed, 17 Apr 2024 19:56:07 +1000
Subject: [PATCH] feat(ui): support negative regional prompt
---
.../util/graph/addRegionalPromptsToGraph.ts | 49 +++++++++----
.../features/nodes/util/graph/constants.ts | 3 +-
.../components/LayerListItem.tsx | 6 +-
.../RegionalPromptsNegativePrompt.tsx | 69 +++++++++++++++++++
....tsx => RegionalPromptsPositivePrompt.tsx} | 12 ++--
.../regionalPrompts/hooks/layerStateHooks.ts | 18 ++++-
.../store/regionalPromptsSlice.ts | 24 +++++--
.../regionalPrompts/util/getLayerBlobs.ts | 2 +-
8 files changed, 151 insertions(+), 32 deletions(-)
create mode 100644 invokeai/frontend/web/src/features/regionalPrompts/components/RegionalPromptsNegativePrompt.tsx
rename invokeai/frontend/web/src/features/regionalPrompts/components/{RegionalPromptsPrompt.tsx => RegionalPromptsPositivePrompt.tsx} (80%)
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addRegionalPromptsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addRegionalPromptsToGraph.ts
index ebe2e74da5..aede682866 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/addRegionalPromptsToGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/addRegionalPromptsToGraph.ts
@@ -1,11 +1,13 @@
import { getStore } from 'app/store/nanostores/store';
import type { RootState } from 'app/store/store';
import {
+ NEGATIVE_CONDITIONING,
NEGATIVE_CONDITIONING_COLLECT,
POSITIVE_CONDITIONING,
POSITIVE_CONDITIONING_COLLECT,
- PROMPT_REGION_COND_PREFIX,
PROMPT_REGION_MASK_PREFIX,
+ PROMPT_REGION_NEGATIVE_COND_PREFIX,
+ PROMPT_REGION_POSITIVE_COND_PREFIX,
} from 'features/nodes/util/graph/constants';
import { getRegionalPromptLayerBlobs } from 'features/regionalPrompts/util/getLayerBlobs';
import { size } from 'lodash-es';
@@ -116,34 +118,55 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
// Create the conditioning nodes for each region - different handling for SDXL
// TODO: negative prompt
- const regionalCondNodeId = `${PROMPT_REGION_COND_PREFIX}_${layerId}`;
+ const regionalPositiveCondNodeId = `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layerId}`;
+ const regionalNegativeCondNodeId = `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layerId}`;
if (isSDXL) {
- graph.nodes[regionalCondNodeId] = {
+ graph.nodes[regionalPositiveCondNodeId] = {
type: 'sdxl_compel_prompt',
- id: regionalCondNodeId,
- prompt: layer.prompt,
+ id: regionalPositiveCondNodeId,
+ prompt: layer.positivePrompt,
+ };
+ graph.nodes[regionalNegativeCondNodeId] = {
+ type: 'sdxl_compel_prompt',
+ id: regionalNegativeCondNodeId,
+ prompt: layer.negativePrompt,
};
} else {
- graph.nodes[regionalCondNodeId] = {
- type: 'compel',
- id: regionalCondNodeId,
- prompt: layer.prompt,
- };
+ // TODO: non sdxl
+ // graph.nodes[regionalCondNodeId] = {
+ // type: 'compel',
+ // id: regionalCondNodeId,
+ // prompt: layer.prompt,
+ // };
}
graph.edges.push({
source: { node_id: id, field: 'mask' },
- destination: { node_id: regionalCondNodeId, field: 'mask' },
+ destination: { node_id: regionalPositiveCondNodeId, field: 'mask' },
});
graph.edges.push({
- source: { node_id: regionalCondNodeId, field: 'conditioning' },
+ source: { node_id: id, field: 'mask' },
+ destination: { node_id: regionalNegativeCondNodeId, field: 'mask' },
+ });
+ graph.edges.push({
+ source: { node_id: regionalPositiveCondNodeId, field: 'conditioning' },
destination: { node_id: posCondCollectNode.id, field: 'item' },
});
+ graph.edges.push({
+ source: { node_id: regionalNegativeCondNodeId, field: 'conditioning' },
+ destination: { node_id: negCondCollectNode.id, field: 'item' },
+ });
for (const edge of graph.edges) {
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
graph.edges.push({
source: edge.source,
- destination: { node_id: regionalCondNodeId, field: edge.destination.field },
+ destination: { node_id: regionalPositiveCondNodeId, field: edge.destination.field },
+ });
+ }
+ if (edge.destination.node_id === NEGATIVE_CONDITIONING && edge.destination.field !== 'prompt') {
+ graph.edges.push({
+ source: edge.source,
+ destination: { node_id: regionalNegativeCondNodeId, field: edge.destination.field },
});
}
}
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts b/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts
index adde745b4a..81952658c8 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts
@@ -47,7 +47,8 @@ export const SDXL_REFINER_INPAINT_CREATE_MASK = 'refiner_inpaint_create_mask';
export const SEAMLESS = 'seamless';
export const SDXL_REFINER_SEAMLESS = 'refiner_seamless';
export const PROMPT_REGION_MASK_PREFIX = 'prompt_region_mask';
-export const PROMPT_REGION_COND_PREFIX = 'prompt_region_cond';
+export const PROMPT_REGION_POSITIVE_COND_PREFIX = 'prompt_region_positive_cond';
+export const PROMPT_REGION_NEGATIVE_COND_PREFIX = 'prompt_region_negative_cond';
export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect';
export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect';
diff --git a/invokeai/frontend/web/src/features/regionalPrompts/components/LayerListItem.tsx b/invokeai/frontend/web/src/features/regionalPrompts/components/LayerListItem.tsx
index 39a7efd758..b41787ed9f 100644
--- a/invokeai/frontend/web/src/features/regionalPrompts/components/LayerListItem.tsx
+++ b/invokeai/frontend/web/src/features/regionalPrompts/components/LayerListItem.tsx
@@ -4,7 +4,8 @@ import { rgbColorToString } from 'features/canvas/util/colorToString';
import { LayerColorPicker } from 'features/regionalPrompts/components/LayerColorPicker';
import { LayerMenu } from 'features/regionalPrompts/components/LayerMenu';
import { LayerVisibilityToggle } from 'features/regionalPrompts/components/LayerVisibilityToggle';
-import { RegionalPromptsPrompt } from 'features/regionalPrompts/components/RegionalPromptsPrompt';
+import { RegionalPromptsNegativePrompt } from 'features/regionalPrompts/components/RegionalPromptsNegativePrompt';
+import { RegionalPromptsPositivePrompt } from 'features/regionalPrompts/components/RegionalPromptsPositivePrompt';
import { layerSelected } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@@ -42,7 +43,8 @@ export const LayerListItem = memo(({ id }: Props) => {
)}
-
+
+
);
diff --git a/invokeai/frontend/web/src/features/regionalPrompts/components/RegionalPromptsNegativePrompt.tsx b/invokeai/frontend/web/src/features/regionalPrompts/components/RegionalPromptsNegativePrompt.tsx
new file mode 100644
index 0000000000..8f5f9f484b
--- /dev/null
+++ b/invokeai/frontend/web/src/features/regionalPrompts/components/RegionalPromptsNegativePrompt.tsx
@@ -0,0 +1,69 @@
+import { Box, Textarea } from '@invoke-ai/ui-library';
+import { useAppDispatch } from 'app/store/storeHooks';
+import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
+import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
+import { PromptPopover } from 'features/prompt/PromptPopover';
+import { usePrompt } from 'features/prompt/usePrompt';
+import { useLayerNegativePrompt } from 'features/regionalPrompts/hooks/layerStateHooks';
+import { negativePromptChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
+import { memo, useCallback, useRef } from 'react';
+import type { HotkeyCallback } from 'react-hotkeys-hook';
+import { useHotkeys } from 'react-hotkeys-hook';
+import { useTranslation } from 'react-i18next';
+
+type Props = {
+ layerId: string;
+};
+
+export const RegionalPromptsNegativePrompt = memo((props: Props) => {
+ const prompt = useLayerNegativePrompt(props.layerId);
+ const dispatch = useAppDispatch();
+ const textareaRef = useRef(null);
+ const { t } = useTranslation();
+ const _onChange = useCallback(
+ (v: string) => {
+ dispatch(negativePromptChanged({ layerId: props.layerId, prompt: v }));
+ },
+ [dispatch, props.layerId]
+ );
+ const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({
+ prompt,
+ textareaRef,
+ onChange: _onChange,
+ });
+ const focus: HotkeyCallback = useCallback(
+ (e) => {
+ onFocus();
+ e.preventDefault();
+ },
+ [onFocus]
+ );
+
+ useHotkeys('alt+a', focus, []);
+
+ return (
+
+
+
+
+
+
+
+
+ );
+});
+
+RegionalPromptsNegativePrompt.displayName = 'RegionalPromptsPrompt';
diff --git a/invokeai/frontend/web/src/features/regionalPrompts/components/RegionalPromptsPrompt.tsx b/invokeai/frontend/web/src/features/regionalPrompts/components/RegionalPromptsPositivePrompt.tsx
similarity index 80%
rename from invokeai/frontend/web/src/features/regionalPrompts/components/RegionalPromptsPrompt.tsx
rename to invokeai/frontend/web/src/features/regionalPrompts/components/RegionalPromptsPositivePrompt.tsx
index 54fa6df830..4f2e7ca306 100644
--- a/invokeai/frontend/web/src/features/regionalPrompts/components/RegionalPromptsPrompt.tsx
+++ b/invokeai/frontend/web/src/features/regionalPrompts/components/RegionalPromptsPositivePrompt.tsx
@@ -4,8 +4,8 @@ import { PromptOverlayButtonWrapper } from 'features/parameters/components/Promp
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover';
import { usePrompt } from 'features/prompt/usePrompt';
-import { useLayerPrompt } from 'features/regionalPrompts/hooks/layerStateHooks';
-import { promptChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
+import { useLayerPositivePrompt } from 'features/regionalPrompts/hooks/layerStateHooks';
+import { positivePromptChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { memo, useCallback, useRef } from 'react';
import type { HotkeyCallback } from 'react-hotkeys-hook';
import { useHotkeys } from 'react-hotkeys-hook';
@@ -15,14 +15,14 @@ type Props = {
layerId: string;
};
-export const RegionalPromptsPrompt = memo((props: Props) => {
- const prompt = useLayerPrompt(props.layerId);
+export const RegionalPromptsPositivePrompt = memo((props: Props) => {
+ const prompt = useLayerPositivePrompt(props.layerId);
const dispatch = useAppDispatch();
const textareaRef = useRef(null);
const { t } = useTranslation();
const _onChange = useCallback(
(v: string) => {
- dispatch(promptChanged({ layerId: props.layerId, prompt: v }));
+ dispatch(positivePromptChanged({ layerId: props.layerId, prompt: v }));
},
[dispatch, props.layerId]
);
@@ -65,4 +65,4 @@ export const RegionalPromptsPrompt = memo((props: Props) => {
);
});
-RegionalPromptsPrompt.displayName = 'RegionalPromptsPrompt';
+RegionalPromptsPositivePrompt.displayName = 'RegionalPromptsPrompt';
diff --git a/invokeai/frontend/web/src/features/regionalPrompts/hooks/layerStateHooks.ts b/invokeai/frontend/web/src/features/regionalPrompts/hooks/layerStateHooks.ts
index 067c0b71d6..0b4729dcbb 100644
--- a/invokeai/frontend/web/src/features/regionalPrompts/hooks/layerStateHooks.ts
+++ b/invokeai/frontend/web/src/features/regionalPrompts/hooks/layerStateHooks.ts
@@ -17,12 +17,26 @@ export const useLayer = (layerId: string) => {
return layer;
};
-export const useLayerPrompt = (layerId: string) => {
+export const useLayerPositivePrompt = (layerId: string) => {
const selectLayer = useMemo(
() =>
createSelector(
selectRegionalPromptsSlice,
- (regionalPrompts) => regionalPrompts.layers.find((l) => l.id === layerId)?.prompt
+ (regionalPrompts) => regionalPrompts.layers.find((l) => l.id === layerId)?.positivePrompt
+ ),
+ [layerId]
+ );
+ const prompt = useAppSelector(selectLayer);
+ assert(prompt !== undefined, `Layer ${layerId} doesn't exist!`);
+ return prompt;
+};
+
+export const useLayerNegativePrompt = (layerId: string) => {
+ const selectLayer = useMemo(
+ () =>
+ createSelector(
+ selectRegionalPromptsSlice,
+ (regionalPrompts) => regionalPrompts.layers.find((l) => l.id === layerId)?.negativePrompt
),
[layerId]
);
diff --git a/invokeai/frontend/web/src/features/regionalPrompts/store/regionalPromptsSlice.ts b/invokeai/frontend/web/src/features/regionalPrompts/store/regionalPromptsSlice.ts
index f760ec9e91..28f941e642 100644
--- a/invokeai/frontend/web/src/features/regionalPrompts/store/regionalPromptsSlice.ts
+++ b/invokeai/frontend/web/src/features/regionalPrompts/store/regionalPromptsSlice.ts
@@ -52,7 +52,8 @@ type LayerBase = {
type PromptRegionLayer = LayerBase & {
kind: 'promptRegionLayer';
objects: LayerObject[];
- prompt: string;
+ positivePrompt: string;
+ negativePrompt: string;
color: RgbColor;
};
@@ -73,7 +74,7 @@ const initialRegionalPromptsState: RegionalPromptsState = {
selectedLayer: null,
brushSize: 40,
layers: [],
- promptLayerOpacity: 0.5,
+ promptLayerOpacity: 0.5, // This currently doesn't work
};
const isLine = (obj: LayerObject): obj is LineObject => obj.kind === 'line';
@@ -89,7 +90,8 @@ export const regionalPromptsSlice = createSlice({
isVisible: true,
bbox: null,
kind: action.payload,
- prompt: '',
+ positivePrompt: '',
+ negativePrompt: '',
objects: [],
color: action.meta.color,
x: 0,
@@ -118,7 +120,6 @@ export const regionalPromptsSlice = createSlice({
layer.objects = [];
layer.bbox = null;
layer.isVisible = true;
- layer.prompt = '';
},
layerDeleted: (state, action: PayloadAction) => {
state.layers = state.layers.filter((l) => l.id !== action.payload);
@@ -163,13 +164,21 @@ export const regionalPromptsSlice = createSlice({
state.layers = [];
state.selectedLayer = null;
},
- promptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string }>) => {
+ positivePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string }>) => {
const { layerId, prompt } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (!layer) {
return;
}
- layer.prompt = prompt;
+ layer.positivePrompt = prompt;
+ },
+ negativePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string }>) => {
+ const { layerId, prompt } = action.payload;
+ const layer = state.layers.find((l) => l.id === layerId);
+ if (!layer) {
+ return;
+ }
+ layer.negativePrompt = prompt;
},
promptRegionLayerColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => {
const { layerId, color } = action.payload;
@@ -254,7 +263,8 @@ export const {
layerReset,
layerDeleted,
layerIsVisibleToggled,
- promptChanged,
+ positivePromptChanged,
+ negativePromptChanged,
lineAdded,
pointsAdded,
promptRegionLayerColorChanged,
diff --git a/invokeai/frontend/web/src/features/regionalPrompts/util/getLayerBlobs.ts b/invokeai/frontend/web/src/features/regionalPrompts/util/getLayerBlobs.ts
index ce6f751479..52bca435cf 100644
--- a/invokeai/frontend/web/src/features/regionalPrompts/util/getLayerBlobs.ts
+++ b/invokeai/frontend/web/src/features/regionalPrompts/util/getLayerBlobs.ts
@@ -48,7 +48,7 @@ export const getRegionalPromptLayerBlobs = async (
if (preview) {
const base64 = await blobToDataURL(blob);
- openBase64ImageInTab([{ base64, caption: `${reduxLayer.id}: ${reduxLayer.prompt}` }]);
+ openBase64ImageInTab([{ base64, caption: `${reduxLayer.id}: ${reduxLayer.positivePrompt} / ${reduxLayer.negativePrompt}` }]);
}
layer.remove();
blobs[layer.id()] = blob;