mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): support negative regional prompt
This commit is contained in:
parent
aa6bfc8645
commit
a5bfe2dccb
@ -1,11 +1,13 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import {
|
||||
NEGATIVE_CONDITIONING,
|
||||
NEGATIVE_CONDITIONING_COLLECT,
|
||||
POSITIVE_CONDITIONING,
|
||||
POSITIVE_CONDITIONING_COLLECT,
|
||||
PROMPT_REGION_COND_PREFIX,
|
||||
PROMPT_REGION_MASK_PREFIX,
|
||||
PROMPT_REGION_NEGATIVE_COND_PREFIX,
|
||||
PROMPT_REGION_POSITIVE_COND_PREFIX,
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import { getRegionalPromptLayerBlobs } from 'features/regionalPrompts/util/getLayerBlobs';
|
||||
import { size } from 'lodash-es';
|
||||
@ -116,34 +118,55 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
|
||||
// Create the conditioning nodes for each region - different handling for SDXL
|
||||
|
||||
// TODO: negative prompt
|
||||
const regionalCondNodeId = `${PROMPT_REGION_COND_PREFIX}_${layerId}`;
|
||||
const regionalPositiveCondNodeId = `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layerId}`;
|
||||
const regionalNegativeCondNodeId = `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layerId}`;
|
||||
|
||||
if (isSDXL) {
|
||||
graph.nodes[regionalCondNodeId] = {
|
||||
graph.nodes[regionalPositiveCondNodeId] = {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: regionalCondNodeId,
|
||||
prompt: layer.prompt,
|
||||
id: regionalPositiveCondNodeId,
|
||||
prompt: layer.positivePrompt,
|
||||
};
|
||||
graph.nodes[regionalNegativeCondNodeId] = {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: regionalNegativeCondNodeId,
|
||||
prompt: layer.negativePrompt,
|
||||
};
|
||||
} else {
|
||||
graph.nodes[regionalCondNodeId] = {
|
||||
type: 'compel',
|
||||
id: regionalCondNodeId,
|
||||
prompt: layer.prompt,
|
||||
};
|
||||
// TODO: non sdxl
|
||||
// graph.nodes[regionalCondNodeId] = {
|
||||
// type: 'compel',
|
||||
// id: regionalCondNodeId,
|
||||
// prompt: layer.prompt,
|
||||
// };
|
||||
}
|
||||
graph.edges.push({
|
||||
source: { node_id: id, field: 'mask' },
|
||||
destination: { node_id: regionalCondNodeId, field: 'mask' },
|
||||
destination: { node_id: regionalPositiveCondNodeId, field: 'mask' },
|
||||
});
|
||||
graph.edges.push({
|
||||
source: { node_id: regionalCondNodeId, field: 'conditioning' },
|
||||
source: { node_id: id, field: 'mask' },
|
||||
destination: { node_id: regionalNegativeCondNodeId, field: 'mask' },
|
||||
});
|
||||
graph.edges.push({
|
||||
source: { node_id: regionalPositiveCondNodeId, field: 'conditioning' },
|
||||
destination: { node_id: posCondCollectNode.id, field: 'item' },
|
||||
});
|
||||
graph.edges.push({
|
||||
source: { node_id: regionalNegativeCondNodeId, field: 'conditioning' },
|
||||
destination: { node_id: negCondCollectNode.id, field: 'item' },
|
||||
});
|
||||
for (const edge of graph.edges) {
|
||||
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
|
||||
graph.edges.push({
|
||||
source: edge.source,
|
||||
destination: { node_id: regionalCondNodeId, field: edge.destination.field },
|
||||
destination: { node_id: regionalPositiveCondNodeId, field: edge.destination.field },
|
||||
});
|
||||
}
|
||||
if (edge.destination.node_id === NEGATIVE_CONDITIONING && edge.destination.field !== 'prompt') {
|
||||
graph.edges.push({
|
||||
source: edge.source,
|
||||
destination: { node_id: regionalNegativeCondNodeId, field: edge.destination.field },
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -47,7 +47,8 @@ export const SDXL_REFINER_INPAINT_CREATE_MASK = 'refiner_inpaint_create_mask';
|
||||
export const SEAMLESS = 'seamless';
|
||||
export const SDXL_REFINER_SEAMLESS = 'refiner_seamless';
|
||||
export const PROMPT_REGION_MASK_PREFIX = 'prompt_region_mask';
|
||||
export const PROMPT_REGION_COND_PREFIX = 'prompt_region_cond';
|
||||
export const PROMPT_REGION_POSITIVE_COND_PREFIX = 'prompt_region_positive_cond';
|
||||
export const PROMPT_REGION_NEGATIVE_COND_PREFIX = 'prompt_region_negative_cond';
|
||||
export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect';
|
||||
export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect';
|
||||
|
||||
|
@ -4,7 +4,8 @@ import { rgbColorToString } from 'features/canvas/util/colorToString';
|
||||
import { LayerColorPicker } from 'features/regionalPrompts/components/LayerColorPicker';
|
||||
import { LayerMenu } from 'features/regionalPrompts/components/LayerMenu';
|
||||
import { LayerVisibilityToggle } from 'features/regionalPrompts/components/LayerVisibilityToggle';
|
||||
import { RegionalPromptsPrompt } from 'features/regionalPrompts/components/RegionalPromptsPrompt';
|
||||
import { RegionalPromptsNegativePrompt } from 'features/regionalPrompts/components/RegionalPromptsNegativePrompt';
|
||||
import { RegionalPromptsPositivePrompt } from 'features/regionalPrompts/components/RegionalPromptsPositivePrompt';
|
||||
import { layerSelected } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -42,7 +43,8 @@ export const LayerListItem = memo(({ id }: Props) => {
|
||||
)}
|
||||
<LayerMenu id={id} />
|
||||
</Flex>
|
||||
<RegionalPromptsPrompt layerId={id} />
|
||||
<RegionalPromptsPositivePrompt layerId={id} />
|
||||
<RegionalPromptsNegativePrompt layerId={id} />
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
|
@ -0,0 +1,69 @@
|
||||
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
|
||||
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||
import { usePrompt } from 'features/prompt/usePrompt';
|
||||
import { useLayerNegativePrompt } from 'features/regionalPrompts/hooks/layerStateHooks';
|
||||
import { negativePromptChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import type { HotkeyCallback } from 'react-hotkeys-hook';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type Props = {
|
||||
layerId: string;
|
||||
};
|
||||
|
||||
export const RegionalPromptsNegativePrompt = memo((props: Props) => {
|
||||
const prompt = useLayerNegativePrompt(props.layerId);
|
||||
const dispatch = useAppDispatch();
|
||||
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
||||
const { t } = useTranslation();
|
||||
const _onChange = useCallback(
|
||||
(v: string) => {
|
||||
dispatch(negativePromptChanged({ layerId: props.layerId, prompt: v }));
|
||||
},
|
||||
[dispatch, props.layerId]
|
||||
);
|
||||
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({
|
||||
prompt,
|
||||
textareaRef,
|
||||
onChange: _onChange,
|
||||
});
|
||||
const focus: HotkeyCallback = useCallback(
|
||||
(e) => {
|
||||
onFocus();
|
||||
e.preventDefault();
|
||||
},
|
||||
[onFocus]
|
||||
);
|
||||
|
||||
useHotkeys('alt+a', focus, []);
|
||||
|
||||
return (
|
||||
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
|
||||
<Box pos="relative" w="full">
|
||||
<Textarea
|
||||
id="prompt"
|
||||
name="prompt"
|
||||
ref={textareaRef}
|
||||
value={prompt}
|
||||
placeholder={t('parameters.negativePromptPlaceholder')}
|
||||
onChange={onChange}
|
||||
minH={28}
|
||||
minW={64}
|
||||
onKeyDown={onKeyDown}
|
||||
variant="darkFilled"
|
||||
paddingRight={30}
|
||||
fontSize="sm"
|
||||
/>
|
||||
<PromptOverlayButtonWrapper>
|
||||
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||
</PromptOverlayButtonWrapper>
|
||||
</Box>
|
||||
</PromptPopover>
|
||||
);
|
||||
});
|
||||
|
||||
RegionalPromptsNegativePrompt.displayName = 'RegionalPromptsPrompt';
|
@ -4,8 +4,8 @@ import { PromptOverlayButtonWrapper } from 'features/parameters/components/Promp
|
||||
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||
import { usePrompt } from 'features/prompt/usePrompt';
|
||||
import { useLayerPrompt } from 'features/regionalPrompts/hooks/layerStateHooks';
|
||||
import { promptChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { useLayerPositivePrompt } from 'features/regionalPrompts/hooks/layerStateHooks';
|
||||
import { positivePromptChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import type { HotkeyCallback } from 'react-hotkeys-hook';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
@ -15,14 +15,14 @@ type Props = {
|
||||
layerId: string;
|
||||
};
|
||||
|
||||
export const RegionalPromptsPrompt = memo((props: Props) => {
|
||||
const prompt = useLayerPrompt(props.layerId);
|
||||
export const RegionalPromptsPositivePrompt = memo((props: Props) => {
|
||||
const prompt = useLayerPositivePrompt(props.layerId);
|
||||
const dispatch = useAppDispatch();
|
||||
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
||||
const { t } = useTranslation();
|
||||
const _onChange = useCallback(
|
||||
(v: string) => {
|
||||
dispatch(promptChanged({ layerId: props.layerId, prompt: v }));
|
||||
dispatch(positivePromptChanged({ layerId: props.layerId, prompt: v }));
|
||||
},
|
||||
[dispatch, props.layerId]
|
||||
);
|
||||
@ -65,4 +65,4 @@ export const RegionalPromptsPrompt = memo((props: Props) => {
|
||||
);
|
||||
});
|
||||
|
||||
RegionalPromptsPrompt.displayName = 'RegionalPromptsPrompt';
|
||||
RegionalPromptsPositivePrompt.displayName = 'RegionalPromptsPrompt';
|
@ -17,12 +17,26 @@ export const useLayer = (layerId: string) => {
|
||||
return layer;
|
||||
};
|
||||
|
||||
export const useLayerPrompt = (layerId: string) => {
|
||||
export const useLayerPositivePrompt = (layerId: string) => {
|
||||
const selectLayer = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
selectRegionalPromptsSlice,
|
||||
(regionalPrompts) => regionalPrompts.layers.find((l) => l.id === layerId)?.prompt
|
||||
(regionalPrompts) => regionalPrompts.layers.find((l) => l.id === layerId)?.positivePrompt
|
||||
),
|
||||
[layerId]
|
||||
);
|
||||
const prompt = useAppSelector(selectLayer);
|
||||
assert(prompt !== undefined, `Layer ${layerId} doesn't exist!`);
|
||||
return prompt;
|
||||
};
|
||||
|
||||
export const useLayerNegativePrompt = (layerId: string) => {
|
||||
const selectLayer = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
selectRegionalPromptsSlice,
|
||||
(regionalPrompts) => regionalPrompts.layers.find((l) => l.id === layerId)?.negativePrompt
|
||||
),
|
||||
[layerId]
|
||||
);
|
||||
|
@ -52,7 +52,8 @@ type LayerBase = {
|
||||
type PromptRegionLayer = LayerBase & {
|
||||
kind: 'promptRegionLayer';
|
||||
objects: LayerObject[];
|
||||
prompt: string;
|
||||
positivePrompt: string;
|
||||
negativePrompt: string;
|
||||
color: RgbColor;
|
||||
};
|
||||
|
||||
@ -73,7 +74,7 @@ const initialRegionalPromptsState: RegionalPromptsState = {
|
||||
selectedLayer: null,
|
||||
brushSize: 40,
|
||||
layers: [],
|
||||
promptLayerOpacity: 0.5,
|
||||
promptLayerOpacity: 0.5, // This currently doesn't work
|
||||
};
|
||||
|
||||
const isLine = (obj: LayerObject): obj is LineObject => obj.kind === 'line';
|
||||
@ -89,7 +90,8 @@ export const regionalPromptsSlice = createSlice({
|
||||
isVisible: true,
|
||||
bbox: null,
|
||||
kind: action.payload,
|
||||
prompt: '',
|
||||
positivePrompt: '',
|
||||
negativePrompt: '',
|
||||
objects: [],
|
||||
color: action.meta.color,
|
||||
x: 0,
|
||||
@ -118,7 +120,6 @@ export const regionalPromptsSlice = createSlice({
|
||||
layer.objects = [];
|
||||
layer.bbox = null;
|
||||
layer.isVisible = true;
|
||||
layer.prompt = '';
|
||||
},
|
||||
layerDeleted: (state, action: PayloadAction<string>) => {
|
||||
state.layers = state.layers.filter((l) => l.id !== action.payload);
|
||||
@ -163,13 +164,21 @@ export const regionalPromptsSlice = createSlice({
|
||||
state.layers = [];
|
||||
state.selectedLayer = null;
|
||||
},
|
||||
promptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string }>) => {
|
||||
positivePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string }>) => {
|
||||
const { layerId, prompt } = action.payload;
|
||||
const layer = state.layers.find((l) => l.id === layerId);
|
||||
if (!layer) {
|
||||
return;
|
||||
}
|
||||
layer.prompt = prompt;
|
||||
layer.positivePrompt = prompt;
|
||||
},
|
||||
negativePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string }>) => {
|
||||
const { layerId, prompt } = action.payload;
|
||||
const layer = state.layers.find((l) => l.id === layerId);
|
||||
if (!layer) {
|
||||
return;
|
||||
}
|
||||
layer.negativePrompt = prompt;
|
||||
},
|
||||
promptRegionLayerColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => {
|
||||
const { layerId, color } = action.payload;
|
||||
@ -254,7 +263,8 @@ export const {
|
||||
layerReset,
|
||||
layerDeleted,
|
||||
layerIsVisibleToggled,
|
||||
promptChanged,
|
||||
positivePromptChanged,
|
||||
negativePromptChanged,
|
||||
lineAdded,
|
||||
pointsAdded,
|
||||
promptRegionLayerColorChanged,
|
||||
|
@ -48,7 +48,7 @@ export const getRegionalPromptLayerBlobs = async (
|
||||
|
||||
if (preview) {
|
||||
const base64 = await blobToDataURL(blob);
|
||||
openBase64ImageInTab([{ base64, caption: `${reduxLayer.id}: ${reduxLayer.prompt}` }]);
|
||||
openBase64ImageInTab([{ base64, caption: `${reduxLayer.id}: ${reduxLayer.positivePrompt} / ${reduxLayer.negativePrompt}` }]);
|
||||
}
|
||||
layer.remove();
|
||||
blobs[layer.id()] = blob;
|
||||
|
Loading…
Reference in New Issue
Block a user