feat(ui): support negative regional prompt

This commit is contained in:
psychedelicious 2024-04-17 19:56:07 +10:00 committed by Kent Keirsey
parent aa6bfc8645
commit a5bfe2dccb
8 changed files with 151 additions and 32 deletions

View File

@ -1,11 +1,13 @@
import { getStore } from 'app/store/nanostores/store';
import type { RootState } from 'app/store/store';
import {
NEGATIVE_CONDITIONING,
NEGATIVE_CONDITIONING_COLLECT,
POSITIVE_CONDITIONING,
POSITIVE_CONDITIONING_COLLECT,
PROMPT_REGION_COND_PREFIX,
PROMPT_REGION_MASK_PREFIX,
PROMPT_REGION_NEGATIVE_COND_PREFIX,
PROMPT_REGION_POSITIVE_COND_PREFIX,
} from 'features/nodes/util/graph/constants';
import { getRegionalPromptLayerBlobs } from 'features/regionalPrompts/util/getLayerBlobs';
import { size } from 'lodash-es';
@ -116,34 +118,55 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
// Create the conditioning nodes for each region - different handling for SDXL
// TODO: negative prompt
const regionalCondNodeId = `${PROMPT_REGION_COND_PREFIX}_${layerId}`;
const regionalPositiveCondNodeId = `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layerId}`;
const regionalNegativeCondNodeId = `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layerId}`;
if (isSDXL) {
graph.nodes[regionalCondNodeId] = {
graph.nodes[regionalPositiveCondNodeId] = {
type: 'sdxl_compel_prompt',
id: regionalCondNodeId,
prompt: layer.prompt,
id: regionalPositiveCondNodeId,
prompt: layer.positivePrompt,
};
graph.nodes[regionalNegativeCondNodeId] = {
type: 'sdxl_compel_prompt',
id: regionalNegativeCondNodeId,
prompt: layer.negativePrompt,
};
} else {
graph.nodes[regionalCondNodeId] = {
type: 'compel',
id: regionalCondNodeId,
prompt: layer.prompt,
};
// TODO: non sdxl
// graph.nodes[regionalCondNodeId] = {
// type: 'compel',
// id: regionalCondNodeId,
// prompt: layer.prompt,
// };
}
graph.edges.push({
source: { node_id: id, field: 'mask' },
destination: { node_id: regionalCondNodeId, field: 'mask' },
destination: { node_id: regionalPositiveCondNodeId, field: 'mask' },
});
graph.edges.push({
source: { node_id: regionalCondNodeId, field: 'conditioning' },
source: { node_id: id, field: 'mask' },
destination: { node_id: regionalNegativeCondNodeId, field: 'mask' },
});
graph.edges.push({
source: { node_id: regionalPositiveCondNodeId, field: 'conditioning' },
destination: { node_id: posCondCollectNode.id, field: 'item' },
});
graph.edges.push({
source: { node_id: regionalNegativeCondNodeId, field: 'conditioning' },
destination: { node_id: negCondCollectNode.id, field: 'item' },
});
for (const edge of graph.edges) {
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
graph.edges.push({
source: edge.source,
destination: { node_id: regionalCondNodeId, field: edge.destination.field },
destination: { node_id: regionalPositiveCondNodeId, field: edge.destination.field },
});
}
if (edge.destination.node_id === NEGATIVE_CONDITIONING && edge.destination.field !== 'prompt') {
graph.edges.push({
source: edge.source,
destination: { node_id: regionalNegativeCondNodeId, field: edge.destination.field },
});
}
}

View File

@ -47,7 +47,8 @@ export const SDXL_REFINER_INPAINT_CREATE_MASK = 'refiner_inpaint_create_mask';
export const SEAMLESS = 'seamless';
export const SDXL_REFINER_SEAMLESS = 'refiner_seamless';
export const PROMPT_REGION_MASK_PREFIX = 'prompt_region_mask';
export const PROMPT_REGION_COND_PREFIX = 'prompt_region_cond';
export const PROMPT_REGION_POSITIVE_COND_PREFIX = 'prompt_region_positive_cond';
export const PROMPT_REGION_NEGATIVE_COND_PREFIX = 'prompt_region_negative_cond';
export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect';
export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect';

View File

@ -4,7 +4,8 @@ import { rgbColorToString } from 'features/canvas/util/colorToString';
import { LayerColorPicker } from 'features/regionalPrompts/components/LayerColorPicker';
import { LayerMenu } from 'features/regionalPrompts/components/LayerMenu';
import { LayerVisibilityToggle } from 'features/regionalPrompts/components/LayerVisibilityToggle';
import { RegionalPromptsPrompt } from 'features/regionalPrompts/components/RegionalPromptsPrompt';
import { RegionalPromptsNegativePrompt } from 'features/regionalPrompts/components/RegionalPromptsNegativePrompt';
import { RegionalPromptsPositivePrompt } from 'features/regionalPrompts/components/RegionalPromptsPositivePrompt';
import { layerSelected } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@ -42,7 +43,8 @@ export const LayerListItem = memo(({ id }: Props) => {
)}
<LayerMenu id={id} />
</Flex>
<RegionalPromptsPrompt layerId={id} />
<RegionalPromptsPositivePrompt layerId={id} />
<RegionalPromptsNegativePrompt layerId={id} />
</Flex>
</Flex>
);

View File

@ -0,0 +1,69 @@
import { Box, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover';
import { usePrompt } from 'features/prompt/usePrompt';
import { useLayerNegativePrompt } from 'features/regionalPrompts/hooks/layerStateHooks';
import { negativePromptChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { memo, useCallback, useRef } from 'react';
import type { HotkeyCallback } from 'react-hotkeys-hook';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
type Props = {
layerId: string;
};
export const RegionalPromptsNegativePrompt = memo((props: Props) => {
const prompt = useLayerNegativePrompt(props.layerId);
const dispatch = useAppDispatch();
const textareaRef = useRef<HTMLTextAreaElement>(null);
const { t } = useTranslation();
const _onChange = useCallback(
(v: string) => {
dispatch(negativePromptChanged({ layerId: props.layerId, prompt: v }));
},
[dispatch, props.layerId]
);
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({
prompt,
textareaRef,
onChange: _onChange,
});
const focus: HotkeyCallback = useCallback(
(e) => {
onFocus();
e.preventDefault();
},
[onFocus]
);
useHotkeys('alt+a', focus, []);
return (
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
<Box pos="relative" w="full">
<Textarea
id="prompt"
name="prompt"
ref={textareaRef}
value={prompt}
placeholder={t('parameters.negativePromptPlaceholder')}
onChange={onChange}
minH={28}
minW={64}
onKeyDown={onKeyDown}
variant="darkFilled"
paddingRight={30}
fontSize="sm"
/>
<PromptOverlayButtonWrapper>
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
</PromptOverlayButtonWrapper>
</Box>
</PromptPopover>
);
});
RegionalPromptsNegativePrompt.displayName = 'RegionalPromptsPrompt';

View File

@ -4,8 +4,8 @@ import { PromptOverlayButtonWrapper } from 'features/parameters/components/Promp
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover';
import { usePrompt } from 'features/prompt/usePrompt';
import { useLayerPrompt } from 'features/regionalPrompts/hooks/layerStateHooks';
import { promptChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { useLayerPositivePrompt } from 'features/regionalPrompts/hooks/layerStateHooks';
import { positivePromptChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { memo, useCallback, useRef } from 'react';
import type { HotkeyCallback } from 'react-hotkeys-hook';
import { useHotkeys } from 'react-hotkeys-hook';
@ -15,14 +15,14 @@ type Props = {
layerId: string;
};
export const RegionalPromptsPrompt = memo((props: Props) => {
const prompt = useLayerPrompt(props.layerId);
export const RegionalPromptsPositivePrompt = memo((props: Props) => {
const prompt = useLayerPositivePrompt(props.layerId);
const dispatch = useAppDispatch();
const textareaRef = useRef<HTMLTextAreaElement>(null);
const { t } = useTranslation();
const _onChange = useCallback(
(v: string) => {
dispatch(promptChanged({ layerId: props.layerId, prompt: v }));
dispatch(positivePromptChanged({ layerId: props.layerId, prompt: v }));
},
[dispatch, props.layerId]
);
@ -65,4 +65,4 @@ export const RegionalPromptsPrompt = memo((props: Props) => {
);
});
RegionalPromptsPrompt.displayName = 'RegionalPromptsPrompt';
RegionalPromptsPositivePrompt.displayName = 'RegionalPromptsPrompt';

View File

@ -17,12 +17,26 @@ export const useLayer = (layerId: string) => {
return layer;
};
export const useLayerPrompt = (layerId: string) => {
export const useLayerPositivePrompt = (layerId: string) => {
const selectLayer = useMemo(
() =>
createSelector(
selectRegionalPromptsSlice,
(regionalPrompts) => regionalPrompts.layers.find((l) => l.id === layerId)?.prompt
(regionalPrompts) => regionalPrompts.layers.find((l) => l.id === layerId)?.positivePrompt
),
[layerId]
);
const prompt = useAppSelector(selectLayer);
assert(prompt !== undefined, `Layer ${layerId} doesn't exist!`);
return prompt;
};
export const useLayerNegativePrompt = (layerId: string) => {
const selectLayer = useMemo(
() =>
createSelector(
selectRegionalPromptsSlice,
(regionalPrompts) => regionalPrompts.layers.find((l) => l.id === layerId)?.negativePrompt
),
[layerId]
);

View File

@ -52,7 +52,8 @@ type LayerBase = {
type PromptRegionLayer = LayerBase & {
kind: 'promptRegionLayer';
objects: LayerObject[];
prompt: string;
positivePrompt: string;
negativePrompt: string;
color: RgbColor;
};
@ -73,7 +74,7 @@ const initialRegionalPromptsState: RegionalPromptsState = {
selectedLayer: null,
brushSize: 40,
layers: [],
promptLayerOpacity: 0.5,
promptLayerOpacity: 0.5, // This currently doesn't work
};
const isLine = (obj: LayerObject): obj is LineObject => obj.kind === 'line';
@ -89,7 +90,8 @@ export const regionalPromptsSlice = createSlice({
isVisible: true,
bbox: null,
kind: action.payload,
prompt: '',
positivePrompt: '',
negativePrompt: '',
objects: [],
color: action.meta.color,
x: 0,
@ -118,7 +120,6 @@ export const regionalPromptsSlice = createSlice({
layer.objects = [];
layer.bbox = null;
layer.isVisible = true;
layer.prompt = '';
},
layerDeleted: (state, action: PayloadAction<string>) => {
state.layers = state.layers.filter((l) => l.id !== action.payload);
@ -163,13 +164,21 @@ export const regionalPromptsSlice = createSlice({
state.layers = [];
state.selectedLayer = null;
},
promptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string }>) => {
positivePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string }>) => {
const { layerId, prompt } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (!layer) {
return;
}
layer.prompt = prompt;
layer.positivePrompt = prompt;
},
negativePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string }>) => {
const { layerId, prompt } = action.payload;
const layer = state.layers.find((l) => l.id === layerId);
if (!layer) {
return;
}
layer.negativePrompt = prompt;
},
promptRegionLayerColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => {
const { layerId, color } = action.payload;
@ -254,7 +263,8 @@ export const {
layerReset,
layerDeleted,
layerIsVisibleToggled,
promptChanged,
positivePromptChanged,
negativePromptChanged,
lineAdded,
pointsAdded,
promptRegionLayerColorChanged,

View File

@ -48,7 +48,7 @@ export const getRegionalPromptLayerBlobs = async (
if (preview) {
const base64 = await blobToDataURL(blob);
openBase64ImageInTab([{ base64, caption: `${reduxLayer.id}: ${reduxLayer.prompt}` }]);
openBase64ImageInTab([{ base64, caption: `${reduxLayer.id}: ${reduxLayer.positivePrompt} / ${reduxLayer.negativePrompt}` }]);
}
layer.remove();
blobs[layer.id()] = blob;