mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): support negative regional prompt
This commit is contained in:
parent
aa6bfc8645
commit
a5bfe2dccb
@ -1,11 +1,13 @@
|
|||||||
import { getStore } from 'app/store/nanostores/store';
|
import { getStore } from 'app/store/nanostores/store';
|
||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import {
|
import {
|
||||||
|
NEGATIVE_CONDITIONING,
|
||||||
NEGATIVE_CONDITIONING_COLLECT,
|
NEGATIVE_CONDITIONING_COLLECT,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
POSITIVE_CONDITIONING_COLLECT,
|
POSITIVE_CONDITIONING_COLLECT,
|
||||||
PROMPT_REGION_COND_PREFIX,
|
|
||||||
PROMPT_REGION_MASK_PREFIX,
|
PROMPT_REGION_MASK_PREFIX,
|
||||||
|
PROMPT_REGION_NEGATIVE_COND_PREFIX,
|
||||||
|
PROMPT_REGION_POSITIVE_COND_PREFIX,
|
||||||
} from 'features/nodes/util/graph/constants';
|
} from 'features/nodes/util/graph/constants';
|
||||||
import { getRegionalPromptLayerBlobs } from 'features/regionalPrompts/util/getLayerBlobs';
|
import { getRegionalPromptLayerBlobs } from 'features/regionalPrompts/util/getLayerBlobs';
|
||||||
import { size } from 'lodash-es';
|
import { size } from 'lodash-es';
|
||||||
@ -116,34 +118,55 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
|
|||||||
// Create the conditioning nodes for each region - different handling for SDXL
|
// Create the conditioning nodes for each region - different handling for SDXL
|
||||||
|
|
||||||
// TODO: negative prompt
|
// TODO: negative prompt
|
||||||
const regionalCondNodeId = `${PROMPT_REGION_COND_PREFIX}_${layerId}`;
|
const regionalPositiveCondNodeId = `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layerId}`;
|
||||||
|
const regionalNegativeCondNodeId = `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layerId}`;
|
||||||
|
|
||||||
if (isSDXL) {
|
if (isSDXL) {
|
||||||
graph.nodes[regionalCondNodeId] = {
|
graph.nodes[regionalPositiveCondNodeId] = {
|
||||||
type: 'sdxl_compel_prompt',
|
type: 'sdxl_compel_prompt',
|
||||||
id: regionalCondNodeId,
|
id: regionalPositiveCondNodeId,
|
||||||
prompt: layer.prompt,
|
prompt: layer.positivePrompt,
|
||||||
|
};
|
||||||
|
graph.nodes[regionalNegativeCondNodeId] = {
|
||||||
|
type: 'sdxl_compel_prompt',
|
||||||
|
id: regionalNegativeCondNodeId,
|
||||||
|
prompt: layer.negativePrompt,
|
||||||
};
|
};
|
||||||
} else {
|
} else {
|
||||||
graph.nodes[regionalCondNodeId] = {
|
// TODO: non sdxl
|
||||||
type: 'compel',
|
// graph.nodes[regionalCondNodeId] = {
|
||||||
id: regionalCondNodeId,
|
// type: 'compel',
|
||||||
prompt: layer.prompt,
|
// id: regionalCondNodeId,
|
||||||
};
|
// prompt: layer.prompt,
|
||||||
|
// };
|
||||||
}
|
}
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: { node_id: id, field: 'mask' },
|
source: { node_id: id, field: 'mask' },
|
||||||
destination: { node_id: regionalCondNodeId, field: 'mask' },
|
destination: { node_id: regionalPositiveCondNodeId, field: 'mask' },
|
||||||
});
|
});
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: { node_id: regionalCondNodeId, field: 'conditioning' },
|
source: { node_id: id, field: 'mask' },
|
||||||
|
destination: { node_id: regionalNegativeCondNodeId, field: 'mask' },
|
||||||
|
});
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: regionalPositiveCondNodeId, field: 'conditioning' },
|
||||||
destination: { node_id: posCondCollectNode.id, field: 'item' },
|
destination: { node_id: posCondCollectNode.id, field: 'item' },
|
||||||
});
|
});
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: regionalNegativeCondNodeId, field: 'conditioning' },
|
||||||
|
destination: { node_id: negCondCollectNode.id, field: 'item' },
|
||||||
|
});
|
||||||
for (const edge of graph.edges) {
|
for (const edge of graph.edges) {
|
||||||
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
|
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: edge.source,
|
source: edge.source,
|
||||||
destination: { node_id: regionalCondNodeId, field: edge.destination.field },
|
destination: { node_id: regionalPositiveCondNodeId, field: edge.destination.field },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if (edge.destination.node_id === NEGATIVE_CONDITIONING && edge.destination.field !== 'prompt') {
|
||||||
|
graph.edges.push({
|
||||||
|
source: edge.source,
|
||||||
|
destination: { node_id: regionalNegativeCondNodeId, field: edge.destination.field },
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -47,7 +47,8 @@ export const SDXL_REFINER_INPAINT_CREATE_MASK = 'refiner_inpaint_create_mask';
|
|||||||
export const SEAMLESS = 'seamless';
|
export const SEAMLESS = 'seamless';
|
||||||
export const SDXL_REFINER_SEAMLESS = 'refiner_seamless';
|
export const SDXL_REFINER_SEAMLESS = 'refiner_seamless';
|
||||||
export const PROMPT_REGION_MASK_PREFIX = 'prompt_region_mask';
|
export const PROMPT_REGION_MASK_PREFIX = 'prompt_region_mask';
|
||||||
export const PROMPT_REGION_COND_PREFIX = 'prompt_region_cond';
|
export const PROMPT_REGION_POSITIVE_COND_PREFIX = 'prompt_region_positive_cond';
|
||||||
|
export const PROMPT_REGION_NEGATIVE_COND_PREFIX = 'prompt_region_negative_cond';
|
||||||
export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect';
|
export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect';
|
||||||
export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect';
|
export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect';
|
||||||
|
|
||||||
|
@ -4,7 +4,8 @@ import { rgbColorToString } from 'features/canvas/util/colorToString';
|
|||||||
import { LayerColorPicker } from 'features/regionalPrompts/components/LayerColorPicker';
|
import { LayerColorPicker } from 'features/regionalPrompts/components/LayerColorPicker';
|
||||||
import { LayerMenu } from 'features/regionalPrompts/components/LayerMenu';
|
import { LayerMenu } from 'features/regionalPrompts/components/LayerMenu';
|
||||||
import { LayerVisibilityToggle } from 'features/regionalPrompts/components/LayerVisibilityToggle';
|
import { LayerVisibilityToggle } from 'features/regionalPrompts/components/LayerVisibilityToggle';
|
||||||
import { RegionalPromptsPrompt } from 'features/regionalPrompts/components/RegionalPromptsPrompt';
|
import { RegionalPromptsNegativePrompt } from 'features/regionalPrompts/components/RegionalPromptsNegativePrompt';
|
||||||
|
import { RegionalPromptsPositivePrompt } from 'features/regionalPrompts/components/RegionalPromptsPositivePrompt';
|
||||||
import { layerSelected } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
import { layerSelected } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -42,7 +43,8 @@ export const LayerListItem = memo(({ id }: Props) => {
|
|||||||
)}
|
)}
|
||||||
<LayerMenu id={id} />
|
<LayerMenu id={id} />
|
||||||
</Flex>
|
</Flex>
|
||||||
<RegionalPromptsPrompt layerId={id} />
|
<RegionalPromptsPositivePrompt layerId={id} />
|
||||||
|
<RegionalPromptsNegativePrompt layerId={id} />
|
||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
|
@ -0,0 +1,69 @@
|
|||||||
|
import { Box, Textarea } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
|
||||||
|
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||||
|
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||||
|
import { usePrompt } from 'features/prompt/usePrompt';
|
||||||
|
import { useLayerNegativePrompt } from 'features/regionalPrompts/hooks/layerStateHooks';
|
||||||
|
import { negativePromptChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
|
import { memo, useCallback, useRef } from 'react';
|
||||||
|
import type { HotkeyCallback } from 'react-hotkeys-hook';
|
||||||
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
layerId: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const RegionalPromptsNegativePrompt = memo((props: Props) => {
|
||||||
|
const prompt = useLayerNegativePrompt(props.layerId);
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const _onChange = useCallback(
|
||||||
|
(v: string) => {
|
||||||
|
dispatch(negativePromptChanged({ layerId: props.layerId, prompt: v }));
|
||||||
|
},
|
||||||
|
[dispatch, props.layerId]
|
||||||
|
);
|
||||||
|
const { onChange, isOpen, onClose, onOpen, onSelect, onKeyDown, onFocus } = usePrompt({
|
||||||
|
prompt,
|
||||||
|
textareaRef,
|
||||||
|
onChange: _onChange,
|
||||||
|
});
|
||||||
|
const focus: HotkeyCallback = useCallback(
|
||||||
|
(e) => {
|
||||||
|
onFocus();
|
||||||
|
e.preventDefault();
|
||||||
|
},
|
||||||
|
[onFocus]
|
||||||
|
);
|
||||||
|
|
||||||
|
useHotkeys('alt+a', focus, []);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
|
||||||
|
<Box pos="relative" w="full">
|
||||||
|
<Textarea
|
||||||
|
id="prompt"
|
||||||
|
name="prompt"
|
||||||
|
ref={textareaRef}
|
||||||
|
value={prompt}
|
||||||
|
placeholder={t('parameters.negativePromptPlaceholder')}
|
||||||
|
onChange={onChange}
|
||||||
|
minH={28}
|
||||||
|
minW={64}
|
||||||
|
onKeyDown={onKeyDown}
|
||||||
|
variant="darkFilled"
|
||||||
|
paddingRight={30}
|
||||||
|
fontSize="sm"
|
||||||
|
/>
|
||||||
|
<PromptOverlayButtonWrapper>
|
||||||
|
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
|
||||||
|
</PromptOverlayButtonWrapper>
|
||||||
|
</Box>
|
||||||
|
</PromptPopover>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
RegionalPromptsNegativePrompt.displayName = 'RegionalPromptsPrompt';
|
@ -4,8 +4,8 @@ import { PromptOverlayButtonWrapper } from 'features/parameters/components/Promp
|
|||||||
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
|
||||||
import { PromptPopover } from 'features/prompt/PromptPopover';
|
import { PromptPopover } from 'features/prompt/PromptPopover';
|
||||||
import { usePrompt } from 'features/prompt/usePrompt';
|
import { usePrompt } from 'features/prompt/usePrompt';
|
||||||
import { useLayerPrompt } from 'features/regionalPrompts/hooks/layerStateHooks';
|
import { useLayerPositivePrompt } from 'features/regionalPrompts/hooks/layerStateHooks';
|
||||||
import { promptChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
import { positivePromptChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||||
import { memo, useCallback, useRef } from 'react';
|
import { memo, useCallback, useRef } from 'react';
|
||||||
import type { HotkeyCallback } from 'react-hotkeys-hook';
|
import type { HotkeyCallback } from 'react-hotkeys-hook';
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
@ -15,14 +15,14 @@ type Props = {
|
|||||||
layerId: string;
|
layerId: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const RegionalPromptsPrompt = memo((props: Props) => {
|
export const RegionalPromptsPositivePrompt = memo((props: Props) => {
|
||||||
const prompt = useLayerPrompt(props.layerId);
|
const prompt = useLayerPositivePrompt(props.layerId);
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(v: string) => {
|
(v: string) => {
|
||||||
dispatch(promptChanged({ layerId: props.layerId, prompt: v }));
|
dispatch(positivePromptChanged({ layerId: props.layerId, prompt: v }));
|
||||||
},
|
},
|
||||||
[dispatch, props.layerId]
|
[dispatch, props.layerId]
|
||||||
);
|
);
|
||||||
@ -65,4 +65,4 @@ export const RegionalPromptsPrompt = memo((props: Props) => {
|
|||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
RegionalPromptsPrompt.displayName = 'RegionalPromptsPrompt';
|
RegionalPromptsPositivePrompt.displayName = 'RegionalPromptsPrompt';
|
@ -17,12 +17,26 @@ export const useLayer = (layerId: string) => {
|
|||||||
return layer;
|
return layer;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const useLayerPrompt = (layerId: string) => {
|
export const useLayerPositivePrompt = (layerId: string) => {
|
||||||
const selectLayer = useMemo(
|
const selectLayer = useMemo(
|
||||||
() =>
|
() =>
|
||||||
createSelector(
|
createSelector(
|
||||||
selectRegionalPromptsSlice,
|
selectRegionalPromptsSlice,
|
||||||
(regionalPrompts) => regionalPrompts.layers.find((l) => l.id === layerId)?.prompt
|
(regionalPrompts) => regionalPrompts.layers.find((l) => l.id === layerId)?.positivePrompt
|
||||||
|
),
|
||||||
|
[layerId]
|
||||||
|
);
|
||||||
|
const prompt = useAppSelector(selectLayer);
|
||||||
|
assert(prompt !== undefined, `Layer ${layerId} doesn't exist!`);
|
||||||
|
return prompt;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const useLayerNegativePrompt = (layerId: string) => {
|
||||||
|
const selectLayer = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(
|
||||||
|
selectRegionalPromptsSlice,
|
||||||
|
(regionalPrompts) => regionalPrompts.layers.find((l) => l.id === layerId)?.negativePrompt
|
||||||
),
|
),
|
||||||
[layerId]
|
[layerId]
|
||||||
);
|
);
|
||||||
|
@ -52,7 +52,8 @@ type LayerBase = {
|
|||||||
type PromptRegionLayer = LayerBase & {
|
type PromptRegionLayer = LayerBase & {
|
||||||
kind: 'promptRegionLayer';
|
kind: 'promptRegionLayer';
|
||||||
objects: LayerObject[];
|
objects: LayerObject[];
|
||||||
prompt: string;
|
positivePrompt: string;
|
||||||
|
negativePrompt: string;
|
||||||
color: RgbColor;
|
color: RgbColor;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -73,7 +74,7 @@ const initialRegionalPromptsState: RegionalPromptsState = {
|
|||||||
selectedLayer: null,
|
selectedLayer: null,
|
||||||
brushSize: 40,
|
brushSize: 40,
|
||||||
layers: [],
|
layers: [],
|
||||||
promptLayerOpacity: 0.5,
|
promptLayerOpacity: 0.5, // This currently doesn't work
|
||||||
};
|
};
|
||||||
|
|
||||||
const isLine = (obj: LayerObject): obj is LineObject => obj.kind === 'line';
|
const isLine = (obj: LayerObject): obj is LineObject => obj.kind === 'line';
|
||||||
@ -89,7 +90,8 @@ export const regionalPromptsSlice = createSlice({
|
|||||||
isVisible: true,
|
isVisible: true,
|
||||||
bbox: null,
|
bbox: null,
|
||||||
kind: action.payload,
|
kind: action.payload,
|
||||||
prompt: '',
|
positivePrompt: '',
|
||||||
|
negativePrompt: '',
|
||||||
objects: [],
|
objects: [],
|
||||||
color: action.meta.color,
|
color: action.meta.color,
|
||||||
x: 0,
|
x: 0,
|
||||||
@ -118,7 +120,6 @@ export const regionalPromptsSlice = createSlice({
|
|||||||
layer.objects = [];
|
layer.objects = [];
|
||||||
layer.bbox = null;
|
layer.bbox = null;
|
||||||
layer.isVisible = true;
|
layer.isVisible = true;
|
||||||
layer.prompt = '';
|
|
||||||
},
|
},
|
||||||
layerDeleted: (state, action: PayloadAction<string>) => {
|
layerDeleted: (state, action: PayloadAction<string>) => {
|
||||||
state.layers = state.layers.filter((l) => l.id !== action.payload);
|
state.layers = state.layers.filter((l) => l.id !== action.payload);
|
||||||
@ -163,13 +164,21 @@ export const regionalPromptsSlice = createSlice({
|
|||||||
state.layers = [];
|
state.layers = [];
|
||||||
state.selectedLayer = null;
|
state.selectedLayer = null;
|
||||||
},
|
},
|
||||||
promptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string }>) => {
|
positivePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string }>) => {
|
||||||
const { layerId, prompt } = action.payload;
|
const { layerId, prompt } = action.payload;
|
||||||
const layer = state.layers.find((l) => l.id === layerId);
|
const layer = state.layers.find((l) => l.id === layerId);
|
||||||
if (!layer) {
|
if (!layer) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
layer.prompt = prompt;
|
layer.positivePrompt = prompt;
|
||||||
|
},
|
||||||
|
negativePromptChanged: (state, action: PayloadAction<{ layerId: string; prompt: string }>) => {
|
||||||
|
const { layerId, prompt } = action.payload;
|
||||||
|
const layer = state.layers.find((l) => l.id === layerId);
|
||||||
|
if (!layer) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
layer.negativePrompt = prompt;
|
||||||
},
|
},
|
||||||
promptRegionLayerColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => {
|
promptRegionLayerColorChanged: (state, action: PayloadAction<{ layerId: string; color: RgbColor }>) => {
|
||||||
const { layerId, color } = action.payload;
|
const { layerId, color } = action.payload;
|
||||||
@ -254,7 +263,8 @@ export const {
|
|||||||
layerReset,
|
layerReset,
|
||||||
layerDeleted,
|
layerDeleted,
|
||||||
layerIsVisibleToggled,
|
layerIsVisibleToggled,
|
||||||
promptChanged,
|
positivePromptChanged,
|
||||||
|
negativePromptChanged,
|
||||||
lineAdded,
|
lineAdded,
|
||||||
pointsAdded,
|
pointsAdded,
|
||||||
promptRegionLayerColorChanged,
|
promptRegionLayerColorChanged,
|
||||||
|
@ -48,7 +48,7 @@ export const getRegionalPromptLayerBlobs = async (
|
|||||||
|
|
||||||
if (preview) {
|
if (preview) {
|
||||||
const base64 = await blobToDataURL(blob);
|
const base64 = await blobToDataURL(blob);
|
||||||
openBase64ImageInTab([{ base64, caption: `${reduxLayer.id}: ${reduxLayer.prompt}` }]);
|
openBase64ImageInTab([{ base64, caption: `${reduxLayer.id}: ${reduxLayer.positivePrompt} / ${reduxLayer.negativePrompt}` }]);
|
||||||
}
|
}
|
||||||
layer.remove();
|
layer.remove();
|
||||||
blobs[layer.id()] = blob;
|
blobs[layer.id()] = blob;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user