mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): add invert negative mode
Adds an additional negative conditioning using the inverted mask of the positive conditioning and the positive prompt. May be useful for mutually exclusive regions.
This commit is contained in:
parent
e4fcb6627a
commit
085f7bdbee
@ -5,8 +5,11 @@ import {
|
||||
NEGATIVE_CONDITIONING_COLLECT,
|
||||
POSITIVE_CONDITIONING,
|
||||
POSITIVE_CONDITIONING_COLLECT,
|
||||
PROMPT_REGION_MASK_PREFIX,
|
||||
PROMPT_REGION_MASK_IMAGE_PRIMITIVE_PREFIX,
|
||||
PROMPT_REGION_MASK_TO_TENSOR_INVERTED_PREFIX,
|
||||
PROMPT_REGION_MASK_TO_TENSOR_PREFIX,
|
||||
PROMPT_REGION_NEGATIVE_COND_PREFIX,
|
||||
PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX,
|
||||
PROMPT_REGION_POSITIVE_COND_PREFIX,
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import { getRegionalPromptLayerBlobs } from 'features/regionalPrompts/util/getLayerBlobs';
|
||||
@ -17,7 +20,9 @@ import { assert } from 'tsafe';
|
||||
|
||||
export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNullableGraph, denoiseNodeId: string) => {
|
||||
const { dispatch } = getStore();
|
||||
const isSDXL = state.generation.model?.base === 'sdxl';
|
||||
// TODO: Handle non-SDXL
|
||||
// const isSDXL = state.generation.model?.base === 'sdxl';
|
||||
const { autoNegative } = state.regionalPrompts;
|
||||
const layers = state.regionalPrompts.layers
|
||||
.filter((l) => l.kind === 'promptRegionLayer') // We only want the prompt region layers
|
||||
.filter((l) => l.isVisible); // Only visible layers are rendered on the canvas
|
||||
@ -89,6 +94,7 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
|
||||
});
|
||||
|
||||
// Remove the global prompt
|
||||
// TODO: Append regional prompts to CLIP2's prompt?
|
||||
(graph.nodes[POSITIVE_CONDITIONING] as S['SDXLCompelPromptInvocation'] | S['CompelInvocation']).prompt = '';
|
||||
|
||||
// Upload the blobs to the backend, add each to graph
|
||||
@ -96,8 +102,7 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
|
||||
const layer = layers.find((l) => l.id === layerId);
|
||||
assert(layer, `Layer ${layerId} not found`);
|
||||
|
||||
const id = `${PROMPT_REGION_MASK_PREFIX}_${layerId}`;
|
||||
const file = new File([blob], `${id}.png`, { type: 'image/png' });
|
||||
const file = new File([blob], `${layerId}_mask.png`, { type: 'image/png' });
|
||||
const req = dispatch(
|
||||
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
|
||||
);
|
||||
@ -106,69 +111,121 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
|
||||
// TODO: this will raise an error
|
||||
const { image_name } = await req.unwrap();
|
||||
|
||||
const alphaMaskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
|
||||
id,
|
||||
type: 'alpha_mask_to_tensor',
|
||||
const maskImageNode: S['ImageInvocation'] = {
|
||||
id: `${PROMPT_REGION_MASK_IMAGE_PRIMITIVE_PREFIX}_${layerId}`,
|
||||
type: 'image',
|
||||
image: {
|
||||
image_name,
|
||||
},
|
||||
};
|
||||
graph.nodes[id] = alphaMaskToTensorNode;
|
||||
graph.nodes[maskImageNode.id] = maskImageNode;
|
||||
|
||||
const maskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
|
||||
id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${layerId}`,
|
||||
type: 'alpha_mask_to_tensor',
|
||||
};
|
||||
graph.nodes[maskToTensorNode.id] = maskToTensorNode;
|
||||
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: maskImageNode.id,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: maskToTensorNode.id,
|
||||
field: 'image',
|
||||
},
|
||||
});
|
||||
|
||||
// Create the conditioning nodes for each region - different handling for SDXL
|
||||
|
||||
// TODO: negative prompt
|
||||
const regionalPositiveCondNodeId = `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layerId}`;
|
||||
const regionalNegativeCondNodeId = `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layerId}`;
|
||||
|
||||
if (isSDXL) {
|
||||
graph.nodes[regionalPositiveCondNodeId] = {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: regionalPositiveCondNodeId,
|
||||
prompt: layer.positivePrompt,
|
||||
};
|
||||
graph.nodes[regionalNegativeCondNodeId] = {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: regionalNegativeCondNodeId,
|
||||
prompt: layer.negativePrompt,
|
||||
};
|
||||
} else {
|
||||
// TODO: non sdxl
|
||||
// graph.nodes[regionalCondNodeId] = {
|
||||
// type: 'compel',
|
||||
// id: regionalCondNodeId,
|
||||
// prompt: layer.prompt,
|
||||
// };
|
||||
}
|
||||
const regionalPositiveCondNode: S['SDXLCompelPromptInvocation'] = {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layerId}`,
|
||||
prompt: layer.positivePrompt,
|
||||
style: layer.positivePrompt,
|
||||
};
|
||||
const regionalNegativeCondNode: S['SDXLCompelPromptInvocation'] = {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layerId}`,
|
||||
prompt: layer.negativePrompt,
|
||||
style: layer.negativePrompt,
|
||||
};
|
||||
graph.nodes[regionalPositiveCondNode.id] = regionalPositiveCondNode;
|
||||
graph.nodes[regionalNegativeCondNode.id] = regionalNegativeCondNode;
|
||||
graph.edges.push({
|
||||
source: { node_id: id, field: 'mask' },
|
||||
destination: { node_id: regionalPositiveCondNodeId, field: 'mask' },
|
||||
source: { node_id: maskToTensorNode.id, field: 'mask' },
|
||||
destination: { node_id: regionalPositiveCondNode.id, field: 'mask' },
|
||||
});
|
||||
graph.edges.push({
|
||||
source: { node_id: id, field: 'mask' },
|
||||
destination: { node_id: regionalNegativeCondNodeId, field: 'mask' },
|
||||
source: { node_id: maskToTensorNode.id, field: 'mask' },
|
||||
destination: { node_id: regionalNegativeCondNode.id, field: 'mask' },
|
||||
});
|
||||
graph.edges.push({
|
||||
source: { node_id: regionalPositiveCondNodeId, field: 'conditioning' },
|
||||
source: { node_id: regionalPositiveCondNode.id, field: 'conditioning' },
|
||||
destination: { node_id: posCondCollectNode.id, field: 'item' },
|
||||
});
|
||||
graph.edges.push({
|
||||
source: { node_id: regionalNegativeCondNodeId, field: 'conditioning' },
|
||||
source: { node_id: regionalNegativeCondNode.id, field: 'conditioning' },
|
||||
destination: { node_id: negCondCollectNode.id, field: 'item' },
|
||||
});
|
||||
for (const edge of graph.edges) {
|
||||
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
|
||||
graph.edges.push({
|
||||
source: edge.source,
|
||||
destination: { node_id: regionalPositiveCondNodeId, field: edge.destination.field },
|
||||
destination: { node_id: regionalPositiveCondNode.id, field: edge.destination.field },
|
||||
});
|
||||
}
|
||||
if (edge.destination.node_id === NEGATIVE_CONDITIONING && edge.destination.field !== 'prompt') {
|
||||
graph.edges.push({
|
||||
source: edge.source,
|
||||
destination: { node_id: regionalNegativeCondNodeId, field: edge.destination.field },
|
||||
destination: { node_id: regionalNegativeCondNode.id, field: edge.destination.field },
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (autoNegative === 'invert') {
|
||||
// Add an additional negative conditioning node with the positive prompt & inverted region mask
|
||||
const invertedMaskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
|
||||
id: `${PROMPT_REGION_MASK_TO_TENSOR_INVERTED_PREFIX}_${layerId}`,
|
||||
type: 'alpha_mask_to_tensor',
|
||||
invert: true,
|
||||
};
|
||||
graph.nodes[invertedMaskToTensorNode.id] = invertedMaskToTensorNode;
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: maskImageNode.id,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: invertedMaskToTensorNode.id,
|
||||
field: 'image',
|
||||
},
|
||||
});
|
||||
|
||||
const regionalPositiveCondInvertedNode: S['SDXLCompelPromptInvocation'] = {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layerId}`,
|
||||
prompt: layer.positivePrompt,
|
||||
style: layer.positivePrompt,
|
||||
};
|
||||
graph.nodes[regionalPositiveCondInvertedNode.id] = regionalPositiveCondInvertedNode;
|
||||
graph.edges.push({
|
||||
source: { node_id: invertedMaskToTensorNode.id, field: 'mask' },
|
||||
destination: { node_id: regionalPositiveCondInvertedNode.id, field: 'mask' },
|
||||
});
|
||||
graph.edges.push({
|
||||
source: { node_id: regionalPositiveCondInvertedNode.id, field: 'conditioning' },
|
||||
destination: { node_id: negCondCollectNode.id, field: 'item' },
|
||||
});
|
||||
for (const edge of graph.edges) {
|
||||
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
|
||||
graph.edges.push({
|
||||
source: edge.source,
|
||||
destination: { node_id: regionalPositiveCondInvertedNode.id, field: edge.destination.field },
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -5,7 +5,13 @@ import { range, some } from 'lodash-es';
|
||||
import type { components } from 'services/api/schema';
|
||||
import type { Batch, BatchConfig, NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { CANVAS_COHERENCE_NOISE, METADATA, NOISE, POSITIVE_CONDITIONING, PROMPT_REGION_MASK_PREFIX } from './constants';
|
||||
import {
|
||||
CANVAS_COHERENCE_NOISE,
|
||||
METADATA,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
PROMPT_REGION_MASK_TO_TENSOR_PREFIX,
|
||||
} from './constants';
|
||||
import { getHasMetadata, removeMetadata } from './metadata';
|
||||
|
||||
export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph, prepend: boolean): BatchConfig => {
|
||||
@ -86,7 +92,7 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
|
||||
|
||||
const extendedPrompts = seedBehaviour === 'PER_PROMPT' ? range(iterations).flatMap(() => prompts) : prompts;
|
||||
|
||||
const hasRegionalPrompts = some(graph.nodes, (n) => n.id.startsWith(PROMPT_REGION_MASK_PREFIX));
|
||||
const hasRegionalPrompts = some(graph.nodes, (n) => n.id.startsWith(PROMPT_REGION_MASK_TO_TENSOR_PREFIX));
|
||||
|
||||
if (!hasRegionalPrompts) {
|
||||
// zipped batch of prompts
|
||||
|
@ -46,9 +46,12 @@ export const SDXL_REFINER_DENOISE_LATENTS = 'sdxl_refiner_denoise_latents';
|
||||
export const SDXL_REFINER_INPAINT_CREATE_MASK = 'refiner_inpaint_create_mask';
|
||||
export const SEAMLESS = 'seamless';
|
||||
export const SDXL_REFINER_SEAMLESS = 'refiner_seamless';
|
||||
export const PROMPT_REGION_MASK_PREFIX = 'prompt_region_mask';
|
||||
export const PROMPT_REGION_MASK_IMAGE_PRIMITIVE_PREFIX = 'prompt_region_mask_image_primitive';
|
||||
export const PROMPT_REGION_MASK_TO_TENSOR_PREFIX = 'prompt_region_mask_to_tensor';
|
||||
export const PROMPT_REGION_MASK_TO_TENSOR_INVERTED_PREFIX = 'prompt_region_mask_to_tensor_inverted';
|
||||
export const PROMPT_REGION_POSITIVE_COND_PREFIX = 'prompt_region_positive_cond';
|
||||
export const PROMPT_REGION_NEGATIVE_COND_PREFIX = 'prompt_region_negative_cond';
|
||||
export const PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX = 'prompt_region_positive_cond_inverted';
|
||||
export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect';
|
||||
export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect';
|
||||
|
||||
|
@ -196,3 +196,10 @@ const zLoRAWeight = z.number();
|
||||
type ParameterLoRAWeight = z.infer<typeof zLoRAWeight>;
|
||||
export const isParameterLoRAWeight = (val: unknown): val is ParameterLoRAWeight => zLoRAWeight.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region Regional Prompts AutoNegative
|
||||
const zAutoNegative = z.enum(['off', 'invert']);
|
||||
export type ParameterAutoNegative = z.infer<typeof zAutoNegative>;
|
||||
export const isParameterAutoNegative = (val: unknown): val is ParameterAutoNegative =>
|
||||
zAutoNegative.safeParse(val).success;
|
||||
// #endregion
|
||||
|
@ -0,0 +1,39 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { isParameterAutoNegative } from 'features/parameters/types/parameterSchemas';
|
||||
import { autoNegativeChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ label: 'Off', value: 'off' },
|
||||
{ label: 'Invert', value: 'invert' },
|
||||
];
|
||||
|
||||
const AutoNegativeCombobox = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const autoNegative = useAppSelector((s) => s.regionalPrompts.autoNegative);
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!isParameterAutoNegative(v?.value)) {
|
||||
return;
|
||||
}
|
||||
dispatch(autoNegativeChanged(v.value));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const value = useMemo(() => options.find((o) => o.value === autoNegative), [autoNegative]);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel>Negative Mode</FormLabel>
|
||||
<Combobox value={value} options={options} onChange={onChange} isSearchable={false} />
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(AutoNegativeCombobox);
|
@ -6,6 +6,7 @@ import { AddLayerButton } from 'features/regionalPrompts/components/AddLayerButt
|
||||
import { BrushSize } from 'features/regionalPrompts/components/BrushSize';
|
||||
import { DeleteAllLayersButton } from 'features/regionalPrompts/components/DeleteAllLayersButton';
|
||||
import { LayerListItem } from 'features/regionalPrompts/components/LayerListItem';
|
||||
import AutoNegativeCombobox from 'features/regionalPrompts/components/NegativeModeCombobox';
|
||||
import { PromptLayerOpacity } from 'features/regionalPrompts/components/PromptLayerOpacity';
|
||||
import { StageComponent } from 'features/regionalPrompts/components/StageComponent';
|
||||
import { ToolChooser } from 'features/regionalPrompts/components/ToolChooser';
|
||||
@ -36,6 +37,7 @@ export const RegionalPromptsEditor = memo(() => {
|
||||
</Flex>
|
||||
<BrushSize />
|
||||
<PromptLayerOpacity />
|
||||
<AutoNegativeCombobox />
|
||||
{layerIdsReversed.map((id) => (
|
||||
<LayerListItem key={id} id={id} />
|
||||
))}
|
||||
|
@ -2,6 +2,7 @@ import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils';
|
||||
import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas';
|
||||
import type { IRect, Vector2d } from 'konva/lib/types';
|
||||
import { atom } from 'nanostores';
|
||||
import type { RgbColor } from 'react-colorful';
|
||||
@ -65,6 +66,7 @@ type RegionalPromptsState = {
|
||||
layers: PromptRegionLayer[];
|
||||
brushSize: number;
|
||||
promptLayerOpacity: number;
|
||||
autoNegative: ParameterAutoNegative;
|
||||
};
|
||||
|
||||
const initialRegionalPromptsState: RegionalPromptsState = {
|
||||
@ -74,6 +76,7 @@ const initialRegionalPromptsState: RegionalPromptsState = {
|
||||
brushSize: 40,
|
||||
layers: [],
|
||||
promptLayerOpacity: 0.5, // This currently doesn't work
|
||||
autoNegative: 'off',
|
||||
};
|
||||
|
||||
const isLine = (obj: LayerObject): obj is LineObject => obj.kind === 'line';
|
||||
@ -229,6 +232,9 @@ export const regionalPromptsSlice = createSlice({
|
||||
promptLayerOpacityChanged: (state, action: PayloadAction<number>) => {
|
||||
state.promptLayerOpacity = action.payload;
|
||||
},
|
||||
autoNegativeChanged: (state, action: PayloadAction<ParameterAutoNegative>) => {
|
||||
state.autoNegative = action.payload;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
@ -277,6 +283,7 @@ export const {
|
||||
layerBboxChanged,
|
||||
promptLayerOpacityChanged,
|
||||
allLayersDeleted,
|
||||
autoNegativeChanged,
|
||||
} = regionalPromptsSlice.actions;
|
||||
|
||||
export const selectRegionalPromptsSlice = (state: RootState) => state.regionalPrompts;
|
||||
|
Loading…
Reference in New Issue
Block a user