feat(ui): support negative regional prompt

This commit is contained in:
psychedelicious 2024-04-17 19:56:07 +10:00 committed by Kent Keirsey
parent aa6bfc8645
commit a5bfe2dccb
8 changed files with 151 additions and 32 deletions

View File

@ -1,11 +1,13 @@
import { getStore } from 'app/store/nanostores/store'; import { getStore } from 'app/store/nanostores/store';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { import {
NEGATIVE_CONDITIONING,
NEGATIVE_CONDITIONING_COLLECT, NEGATIVE_CONDITIONING_COLLECT,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
POSITIVE_CONDITIONING_COLLECT, POSITIVE_CONDITIONING_COLLECT,
PROMPT_REGION_COND_PREFIX,
PROMPT_REGION_MASK_PREFIX, PROMPT_REGION_MASK_PREFIX,
PROMPT_REGION_NEGATIVE_COND_PREFIX,
PROMPT_REGION_POSITIVE_COND_PREFIX,
} from 'features/nodes/util/graph/constants'; } from 'features/nodes/util/graph/constants';
import { getRegionalPromptLayerBlobs } from 'features/regionalPrompts/util/getLayerBlobs'; import { getRegionalPromptLayerBlobs } from 'features/regionalPrompts/util/getLayerBlobs';
import { size } from 'lodash-es'; import { size } from 'lodash-es';
@ -116,34 +118,55 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
// Create the conditioning nodes for each region - different handling for SDXL // Create the conditioning nodes for each region - different handling for SDXL
// TODO: negative prompt // TODO: negative prompt
const regionalCondNodeId = `${PROMPT_REGION_COND_PREFIX}_${layerId}`; const regionalPositiveCondNodeId = `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layerId}`;
const regionalNegativeCondNodeId = `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layerId}`;
if (isSDXL) { if (isSDXL) {
graph.nodes[regionalCondNodeId] = { graph.nodes[regionalPositiveCondNodeId] = {
type: 'sdxl_compel_prompt', type: 'sdxl_compel_prompt',
id: regionalCondNodeId, id: regionalPositiveCondNodeId,
prompt: layer.prompt, prompt: layer.positivePrompt,
};
graph.nodes[regionalNegativeCondNodeId] = {
type: 'sdxl_compel_prompt',
id: regionalNegativeCondNodeId,
prompt: layer.negativePrompt,
}; };
} else { } else {
graph.nodes[regionalCondNodeId] = { // TODO: non sdxl
type: 'compel', // graph.nodes[regionalCondNodeId] = {
id: regionalCondNodeId, // type: 'compel',
prompt: layer.prompt, // id: regionalCondNodeId,
}; // prompt: layer.prompt,
// };
} }
graph.edges.push({ graph.edges.push({
source: { node_id: id, field: 'mask' }, source: { node_id: id, field: 'mask' },
destination: { node_id: regionalCondNodeId, field: 'mask' }, destination: { node_id: regionalPositiveCondNodeId, field: 'mask' },
}); });
graph.edges.push({ graph.edges.push({
source: { node_id: regionalCondNodeId, field: 'conditioning' }, source: { node_id: id, field: 'mask' },
destination: { node_id: regionalNegativeCondNodeId, field: 'mask' },
});
graph.edges.push({
source: { node_id: regionalPositiveCondNodeId, field: 'conditioning' },
destination: { node_id: posCondCollectNode.id, field: 'item' }, destination: { node_id: posCondCollectNode.id, field: 'item' },
}); });
graph.edges.push({
source: { node_id: regionalNegativeCondNodeId, field: 'conditioning' },
destination: { node_id: negCondCollectNode.id, field: 'item' },
});
for (const edge of graph.edges) { for (const edge of graph.edges) {
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') { if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
graph.edges.push({ graph.edges.push({
source: edge.source, source: edge.source,
destination: { node_id: regionalCondNodeId, field: edge.destination.field }, destination: { node_id: regionalPositiveCondNodeId, field: edge.destination.field },
});
}
if (edge.destination.node_id === NEGATIVE_CONDITIONING && edge.destination.field !== 'prompt') {
graph.edges.push({
source: edge.source,
destination: { node_id: regionalNegativeCondNodeId, field: edge.destination.field },
}); });
} }
} }

View File

@ -47,7 +47,8 @@ export const SDXL_REFINER_INPAINT_CREATE_MASK = 'refiner_inpaint_create_mask';
export const SEAMLESS = 'seamless'; export const SEAMLESS = 'seamless';
export const SDXL_REFINER_SEAMLESS = 'refiner_seamless'; export const SDXL_REFINER_SEAMLESS = 'refiner_seamless';
export const PROMPT_REGION_MASK_PREFIX = 'prompt_region_mask'; export const PROMPT_REGION_MASK_PREFIX = 'prompt_region_mask';
export const PROMPT_REGION_COND_PREFIX = 'prompt_region_cond'; export const PROMPT_REGION_POSITIVE_COND_PREFIX = 'prompt_region_positive_cond';
export const PROMPT_REGION_NEGATIVE_COND_PREFIX = 'prompt_region_negative_cond';
export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect'; export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect';
export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect'; export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect';

View File

@ -4,7 +4,8 @@ import { rgbColorToString } from 'features/canvas/util/colorToString';
import { LayerColorPicker } from 'features/regionalPrompts/components/LayerColorPicker'; import { LayerColorPicker } from 'features/regionalPrompts/components/LayerColorPicker';
import { LayerMenu } from 'features/regionalPrompts/components/LayerMenu'; import { LayerMenu } from 'features/regionalPrompts/components/LayerMenu';
import { LayerVisibilityToggle } from 'features/regionalPrompts/components/LayerVisibilityToggle'; import { LayerVisibilityToggle } from 'features/regionalPrompts/components/LayerVisibilityToggle';
import { RegionalPromptsPrompt } from 'features/regionalPrompts/components/RegionalPromptsPrompt'; import { RegionalPromptsNegativePrompt } from 'features/regionalPrompts/components/RegionalPromptsNegativePrompt';
import { RegionalPromptsPositivePrompt } from 'features/regionalPrompts/components/RegionalPromptsPositivePrompt';
import { layerSelected } from 'features/regionalPrompts/store/regionalPromptsSlice'; import { layerSelected } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -42,7 +43,8 @@ export const LayerListItem = memo(({ id }: Props) => {
)} )}
<LayerMenu id={id} /> <LayerMenu id={id} />
</Flex> </Flex>
<RegionalPromptsPrompt layerId={id} /> <RegionalPromptsPositivePrompt layerId={id} />
<RegionalPromptsNegativePrompt layerId={id} />
</Flex> </Flex>
</Flex> </Flex>
); );

View File

@ -0,0 +1,69 @@
import { Box, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover';
import { usePrompt } from 'features/prompt/usePrompt';
import { useLayerNegativePrompt } from 'features/regionalPrompts/hooks/layerStateHooks';
import { negativePromptChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { memo, useCallback, useRef } from 'react';
import type { HotkeyCallback } from 'react-hotkeys-hook';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
type Props = {
layerId: string;
};
export const RegionalPromptsNegativePrompt = memo((props: Props) => {
const prompt = useLayerNegativePrompt(props.layerId);
const dispatch = useAppDispatch();
const textareaRef = useRef<HTMLTextAreaElement>(null);
const { t } = useTranslation();
const _onChange = useCallback(
(v: string) => {
dispatch(negativePromptChanged({ layerId: props.layerId, prompt: v }));
},
[dispatch, props.layerId]
);
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({
prompt,
textareaRef,
onChange: _onChange,
});
const focus: HotkeyCallback = useCallback(
(e) => {
onFocus();
e.preventDefault();
},
[onFocus]
);
useHotkeys('alt+a', focus, []);
return (
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
<Box pos="relative" w="full">
<Textarea
id="prompt"
name="prompt"
ref={textareaRef}
value={prompt}
placeholder={t('parameters.negativePromptPlaceholder')}
onChange={onChange}
minH={28}
minW={64}
onKeyDown={onKeyDown}
variant="darkFilled"
paddingRight={30}
fontSize="sm"
/>
<PromptOverlayButtonWrapper>
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
</PromptOverlayButtonWrapper>
</Box>
</PromptPopover>
);
});
RegionalPromptsNegativePrompt.displayName = 'RegionalPromptsPrompt';

View File

@ -4,8 +4,8 @@ import { PromptOverlayButtonWrapper } from 'features/parameters/components/Promp
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton'; import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover'; import { PromptPopover } from 'features/prompt/PromptPopover';
import { usePrompt } from 'features/prompt/usePrompt'; import { usePrompt } from 'features/prompt/usePrompt';
import { useLayerPrompt } from 'features/regionalPrompts/hooks/layerStateHooks'; import { useLayerPositivePrompt } from 'features/regionalPrompts/hooks/layerStateHooks';
import { promptChanged } from 'features/regionalPrompts/store/regionalPromptsSlice'; import { positivePromptChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { memo, useCallback, useRef } from 'react'; import { memo, useCallback, useRef } from 'react';
import type { HotkeyCallback } from 'react-hotkeys-hook'; import type { HotkeyCallback } from 'react-hotkeys-hook';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
@ -15,14 +15,14 @@ type Props = {
layerId: string; layerId: string;
}; };
export const RegionalPromptsPrompt = memo((props: Props) => { export const RegionalPromptsPositivePrompt = memo((props: Props) => {
const prompt = useLayerPrompt(props.layerId); const prompt = useLayerPositivePrompt(props.layerId);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const textareaRef = useRef<HTMLTextAreaElement>(null); const textareaRef = useRef<HTMLTextAreaElement>(null);
const { t } = useTranslation(); const { t } = useTranslation();
const _onChange = useCallback( const _onChange = useCallback(
(v: string) => { (v: string) => {
dispatch(promptChanged({ layerId: props.layerId, prompt: v })); dispatch(positivePromptChanged({ layerId: props.layerId, prompt: v }));
}, },
[dispatch, props.layerId] [dispatch, props.layerId]
); );
@ -65,4 +65,4 @@ export const RegionalPromptsPrompt = memo((props: Props) => {
); );
}); });
RegionalPromptsPrompt.displayName = 'RegionalPromptsPrompt'; RegionalPromptsPositivePrompt.displayName = 'RegionalPromptsPrompt';

View File

@ -17,12 +17,26 @@ export const useLayer = (layerId: string) => {
return layer; return layer;
}; };
export const useLayerPrompt = (layerId: string) => { export const useLayerPositivePrompt = (layerId: string) => {
const selectLayer = useMemo( const selectLayer = useMemo(
() => () =>
createSelector( createSelector(
selectRegionalPromptsSlice, selectRegionalPromptsSlice,
(regionalPrompts) => regionalPrompts.layers.find((l) => l.id === layerId)?.prompt (regionalPrompts) => regionalPrompts.layers.find((l) => l.id === layerId)?.positivePrompt
),
[layerId]
);
const prompt = useAppSelector(selectLayer);
assert(prompt !== undefined, `Layer ${layerId} doesn't exist!`);
return prompt;
};
export const useLayerNegativePrompt = (layerId: string) => {
const selectLayer = useMemo(
() =>
createSelector(
selectRegionalPromptsSlice,
(regionalPrompts) => regionalPrompts.layers.find((l) => l.id === layerId)?.negativePrompt
), ),
[layerId] [layerId]
); );

View File

@ -52,7 +52,8 @@ type LayerBase = {
type PromptRegionLayer = LayerBase & { type PromptRegionLayer = LayerBase & {
kind: 'promptRegionLayer'; kind: 'promptRegionLayer';
objects: LayerObject[]; objects: LayerObject[];
prompt: string; positivePrompt: string;
negativePrompt: string;
color: RgbColor; color: RgbColor;
}; };
@ -73,7 +74,7 @@ const initialRegionalPromptsState: RegionalPromptsState = {
selectedLayer: null, selectedLayer: null,
brushSize: 40, brushSize: 40,
layers: [], layers: [],
promptLayerOpacity: 0.5, promptLayerOpacity: 0.5, // This currently doesn't work
}; };
const isLine = (obj: LayerObject): obj is LineObject => obj.kind === 'line'; const isLine = (obj: LayerObject): obj is LineObject => obj.kind === 'line';
@ -89,7 +90,8 @@ export const regionalPromptsSlice = createSlice({
isVisible: true, isVisible: true,
bbox: null, bbox: null,
kind: action.payload, kind: action.payload,
prompt: '', positivePrompt: '',
negativePrompt: '',
objects: [], objects: [],
color: action.meta.color, color: action.meta.color,
x: 0, x: 0,
@ -118,7 +120,6 @@ export const regionalPromptsSlice = createSlice({
layer.objects = []; layer.objects = [];
layer.bbox = null; layer.bbox = null;
layer.isVisible = true; layer.isVisible = true;
layer.prompt = '';
}, },
layerDeleted: (state, action: PayloadAction<string>) => { layerDeleted: (state, action: PayloadAction<string>) => {
state.layers = state.layers.filter((l) => l.id !== action.payload); state.layers = state.layers.filter((l) => l.id !== action.payload);
@ -163,13 +164,21 @@ export const regionalPromptsSlice = createSlice({
state.layers = []; state.layers = [];
state.selectedLayer = null; state.selectedLayer = null;
}, },
promptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string }>) => { positivePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string }>) => {
const { layerId, prompt } = action.payload; const { layerId, prompt } = action.payload;
const layer = state.layers.find((l) => l.id === layerId); const layer = state.layers.find((l) => l.id === layerId);
if (!layer) { if (!layer) {
return; return;
} }
layer.prompt = prompt; layer.positivePrompt = prompt;
},
negativePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string }>) => {
const { layerId, prompt } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (!layer) {
return;
}
layer.negativePrompt = prompt;
}, },
promptRegionLayerColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => { promptRegionLayerColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => {
const { layerId, color } = action.payload; const { layerId, color } = action.payload;
@ -254,7 +263,8 @@ export const {
layerReset, layerReset,
layerDeleted, layerDeleted,
layerIsVisibleToggled, layerIsVisibleToggled,
promptChanged, positivePromptChanged,
negativePromptChanged,
lineAdded, lineAdded,
pointsAdded, pointsAdded,
promptRegionLayerColorChanged, promptRegionLayerColorChanged,

View File

@ -48,7 +48,7 @@ export const getRegionalPromptLayerBlobs = async (
if (preview) { if (preview) {
const base64 = await blobToDataURL(blob); const base64 = await blobToDataURL(blob);
openBase64ImageInTab([{ base64, caption: `${reduxLayer.id}: ${reduxLayer.prompt}` }]); openBase64ImageInTab([{ base64, caption: `${reduxLayer.id}: ${reduxLayer.positivePrompt} / ${reduxLayer.negativePrompt}` }]);
} }
layer.remove(); layer.remove();
blobs[layer.id()] = blob; blobs[layer.id()] = blob;