feat(ui): add invert negative mode

Adds an additional negative conditioning using the inverted mask of the positive conditioning and the positive prompt. May be useful for mutually exclusive regions.
This commit is contained in:
psychedelicious 2024-04-18 13:58:00 +10:00 committed by Kent Keirsey
parent e4fcb6627a
commit 085f7bdbee
7 changed files with 163 additions and 42 deletions

View File

@ -5,8 +5,11 @@ import {
NEGATIVE_CONDITIONING_COLLECT,
POSITIVE_CONDITIONING,
POSITIVE_CONDITIONING_COLLECT,
PROMPT_REGION_MASK_PREFIX,
PROMPT_REGION_MASK_IMAGE_PRIMITIVE_PREFIX,
PROMPT_REGION_MASK_TO_TENSOR_INVERTED_PREFIX,
PROMPT_REGION_MASK_TO_TENSOR_PREFIX,
PROMPT_REGION_NEGATIVE_COND_PREFIX,
PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX,
PROMPT_REGION_POSITIVE_COND_PREFIX,
} from 'features/nodes/util/graph/constants';
import { getRegionalPromptLayerBlobs } from 'features/regionalPrompts/util/getLayerBlobs';
@ -17,7 +20,9 @@ import { assert } from 'tsafe';
export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNullableGraph, denoiseNodeId: string) => {
const { dispatch } = getStore();
const isSDXL = state.generation.model?.base === 'sdxl';
// TODO: Handle non-SDXL
// const isSDXL = state.generation.model?.base === 'sdxl';
const { autoNegative } = state.regionalPrompts;
const layers = state.regionalPrompts.layers
.filter((l) => l.kind === 'promptRegionLayer') // We only want the prompt region layers
.filter((l) => l.isVisible); // Only visible layers are rendered on the canvas
@ -89,6 +94,7 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
});
// Remove the global prompt
// TODO: Append regional prompts to CLIP2's prompt?
(graph.nodes[POSITIVE_CONDITIONING] as S['SDXLCompelPromptInvocation'] | S['CompelInvocation']).prompt = '';
// Upload the blobs to the backend, add each to graph
@ -96,8 +102,7 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
const layer = layers.find((l) => l.id === layerId);
assert(layer, `Layer ${layerId} not found`);
const id = `${PROMPT_REGION_MASK_PREFIX}_${layerId}`;
const file = new File([blob], `${id}.png`, { type: 'image/png' });
const file = new File([blob], `${layerId}_mask.png`, { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
);
@ -106,69 +111,121 @@ export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNull
// TODO: this will raise an error
const { image_name } = await req.unwrap();
const alphaMaskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
id,
type: 'alpha_mask_to_tensor',
const maskImageNode: S['ImageInvocation'] = {
id: `${PROMPT_REGION_MASK_IMAGE_PRIMITIVE_PREFIX}_${layerId}`,
type: 'image',
image: {
image_name,
},
};
graph.nodes[id] = alphaMaskToTensorNode;
graph.nodes[maskImageNode.id] = maskImageNode;
const maskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${layerId}`,
type: 'alpha_mask_to_tensor',
};
graph.nodes[maskToTensorNode.id] = maskToTensorNode;
graph.edges.push({
source: {
node_id: maskImageNode.id,
field: 'image',
},
destination: {
node_id: maskToTensorNode.id,
field: 'image',
},
});
// Create the conditioning nodes for each region - different handling for SDXL
// TODO: negative prompt
const regionalPositiveCondNodeId = `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layerId}`;
const regionalNegativeCondNodeId = `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layerId}`;
if (isSDXL) {
graph.nodes[regionalPositiveCondNodeId] = {
type: 'sdxl_compel_prompt',
id: regionalPositiveCondNodeId,
prompt: layer.positivePrompt,
};
graph.nodes[regionalNegativeCondNodeId] = {
type: 'sdxl_compel_prompt',
id: regionalNegativeCondNodeId,
prompt: layer.negativePrompt,
};
} else {
// TODO: non sdxl
// graph.nodes[regionalCondNodeId] = {
// type: 'compel',
// id: regionalCondNodeId,
// prompt: layer.prompt,
// };
}
const regionalPositiveCondNode: S['SDXLCompelPromptInvocation'] = {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layerId}`,
prompt: layer.positivePrompt,
style: layer.positivePrompt,
};
const regionalNegativeCondNode: S['SDXLCompelPromptInvocation'] = {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layerId}`,
prompt: layer.negativePrompt,
style: layer.negativePrompt,
};
graph.nodes[regionalPositiveCondNode.id] = regionalPositiveCondNode;
graph.nodes[regionalNegativeCondNode.id] = regionalNegativeCondNode;
graph.edges.push({
source: { node_id: id, field: 'mask' },
destination: { node_id: regionalPositiveCondNodeId, field: 'mask' },
source: { node_id: maskToTensorNode.id, field: 'mask' },
destination: { node_id: regionalPositiveCondNode.id, field: 'mask' },
});
graph.edges.push({
source: { node_id: id, field: 'mask' },
destination: { node_id: regionalNegativeCondNodeId, field: 'mask' },
source: { node_id: maskToTensorNode.id, field: 'mask' },
destination: { node_id: regionalNegativeCondNode.id, field: 'mask' },
});
graph.edges.push({
source: { node_id: regionalPositiveCondNodeId, field: 'conditioning' },
source: { node_id: regionalPositiveCondNode.id, field: 'conditioning' },
destination: { node_id: posCondCollectNode.id, field: 'item' },
});
graph.edges.push({
source: { node_id: regionalNegativeCondNodeId, field: 'conditioning' },
source: { node_id: regionalNegativeCondNode.id, field: 'conditioning' },
destination: { node_id: negCondCollectNode.id, field: 'item' },
});
for (const edge of graph.edges) {
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
graph.edges.push({
source: edge.source,
destination: { node_id: regionalPositiveCondNodeId, field: edge.destination.field },
destination: { node_id: regionalPositiveCondNode.id, field: edge.destination.field },
});
}
if (edge.destination.node_id === NEGATIVE_CONDITIONING && edge.destination.field !== 'prompt') {
graph.edges.push({
source: edge.source,
destination: { node_id: regionalNegativeCondNodeId, field: edge.destination.field },
destination: { node_id: regionalNegativeCondNode.id, field: edge.destination.field },
});
}
}
if (autoNegative === 'invert') {
// Add an additional negative conditioning node with the positive prompt & inverted region mask
const invertedMaskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
id: `${PROMPT_REGION_MASK_TO_TENSOR_INVERTED_PREFIX}_${layerId}`,
type: 'alpha_mask_to_tensor',
invert: true,
};
graph.nodes[invertedMaskToTensorNode.id] = invertedMaskToTensorNode;
graph.edges.push({
source: {
node_id: maskImageNode.id,
field: 'image',
},
destination: {
node_id: invertedMaskToTensorNode.id,
field: 'image',
},
});
const regionalPositiveCondInvertedNode: S['SDXLCompelPromptInvocation'] = {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layerId}`,
prompt: layer.positivePrompt,
style: layer.positivePrompt,
};
graph.nodes[regionalPositiveCondInvertedNode.id] = regionalPositiveCondInvertedNode;
graph.edges.push({
source: { node_id: invertedMaskToTensorNode.id, field: 'mask' },
destination: { node_id: regionalPositiveCondInvertedNode.id, field: 'mask' },
});
graph.edges.push({
source: { node_id: regionalPositiveCondInvertedNode.id, field: 'conditioning' },
destination: { node_id: negCondCollectNode.id, field: 'item' },
});
for (const edge of graph.edges) {
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
graph.edges.push({
source: edge.source,
destination: { node_id: regionalPositiveCondInvertedNode.id, field: edge.destination.field },
});
}
}
}
}
};

View File

@ -5,7 +5,13 @@ import { range, some } from 'lodash-es';
import type { components } from 'services/api/schema';
import type { Batch, BatchConfig, NonNullableGraph } from 'services/api/types';
import { CANVAS_COHERENCE_NOISE, METADATA, NOISE, POSITIVE_CONDITIONING, PROMPT_REGION_MASK_PREFIX } from './constants';
import {
CANVAS_COHERENCE_NOISE,
METADATA,
NOISE,
POSITIVE_CONDITIONING,
PROMPT_REGION_MASK_TO_TENSOR_PREFIX,
} from './constants';
import { getHasMetadata, removeMetadata } from './metadata';
export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph, prepend: boolean): BatchConfig => {
@ -86,7 +92,7 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
const extendedPrompts = seedBehaviour === 'PER_PROMPT' ? range(iterations).flatMap(() => prompts) : prompts;
const hasRegionalPrompts = some(graph.nodes, (n) => n.id.startsWith(PROMPT_REGION_MASK_PREFIX));
const hasRegionalPrompts = some(graph.nodes, (n) => n.id.startsWith(PROMPT_REGION_MASK_TO_TENSOR_PREFIX));
if (!hasRegionalPrompts) {
// zipped batch of prompts

View File

@ -46,9 +46,12 @@ export const SDXL_REFINER_DENOISE_LATENTS = 'sdxl_refiner_denoise_latents';
export const SDXL_REFINER_INPAINT_CREATE_MASK = 'refiner_inpaint_create_mask';
export const SEAMLESS = 'seamless';
export const SDXL_REFINER_SEAMLESS = 'refiner_seamless';
export const PROMPT_REGION_MASK_PREFIX = 'prompt_region_mask';
export const PROMPT_REGION_MASK_IMAGE_PRIMITIVE_PREFIX = 'prompt_region_mask_image_primitive';
export const PROMPT_REGION_MASK_TO_TENSOR_PREFIX = 'prompt_region_mask_to_tensor';
export const PROMPT_REGION_MASK_TO_TENSOR_INVERTED_PREFIX = 'prompt_region_mask_to_tensor_inverted';
export const PROMPT_REGION_POSITIVE_COND_PREFIX = 'prompt_region_positive_cond';
export const PROMPT_REGION_NEGATIVE_COND_PREFIX = 'prompt_region_negative_cond';
export const PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX = 'prompt_region_positive_cond_inverted';
export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect';
export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect';

View File

@ -196,3 +196,10 @@ const zLoRAWeight = z.number();
type ParameterLoRAWeight = z.infer<typeof zLoRAWeight>;
export const isParameterLoRAWeight = (val: unknown): val is ParameterLoRAWeight => zLoRAWeight.safeParse(val).success;
// #endregion
// #region Regional Prompts AutoNegative
const zAutoNegative = z.enum(['off', 'invert']);
export type ParameterAutoNegative = z.infer<typeof zAutoNegative>;
export const isParameterAutoNegative = (val: unknown): val is ParameterAutoNegative =>
zAutoNegative.safeParse(val).success;
// #endregion

View File

@ -0,0 +1,39 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { isParameterAutoNegative } from 'features/parameters/types/parameterSchemas';
import { autoNegativeChanged } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
const options: ComboboxOption[] = [
{ label: 'Off', value: 'off' },
{ label: 'Invert', value: 'invert' },
];
const AutoNegativeCombobox = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const autoNegative = useAppSelector((s) => s.regionalPrompts.autoNegative);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
if (!isParameterAutoNegative(v?.value)) {
return;
}
dispatch(autoNegativeChanged(v.value));
},
[dispatch]
);
const value = useMemo(() => options.find((o) => o.value === autoNegative), [autoNegative]);
return (
<FormControl>
<FormLabel>Negative Mode</FormLabel>
<Combobox value={value} options={options} onChange={onChange} isSearchable={false} />
</FormControl>
);
};
export default memo(AutoNegativeCombobox);

View File

@ -6,6 +6,7 @@ import { AddLayerButton } from 'features/regionalPrompts/components/AddLayerButt
import { BrushSize } from 'features/regionalPrompts/components/BrushSize';
import { DeleteAllLayersButton } from 'features/regionalPrompts/components/DeleteAllLayersButton';
import { LayerListItem } from 'features/regionalPrompts/components/LayerListItem';
import AutoNegativeCombobox from 'features/regionalPrompts/components/NegativeModeCombobox';
import { PromptLayerOpacity } from 'features/regionalPrompts/components/PromptLayerOpacity';
import { StageComponent } from 'features/regionalPrompts/components/StageComponent';
import { ToolChooser } from 'features/regionalPrompts/components/ToolChooser';
@ -36,6 +37,7 @@ export const RegionalPromptsEditor = memo(() => {
</Flex>
<BrushSize />
<PromptLayerOpacity />
<AutoNegativeCombobox />
{layerIdsReversed.map((id) => (
<LayerListItem key={id} id={id} />
))}

View File

@ -2,6 +2,7 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils';
import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas';
import type { IRect, Vector2d } from 'konva/lib/types';
import { atom } from 'nanostores';
import type { RgbColor } from 'react-colorful';
@ -65,6 +66,7 @@ type RegionalPromptsState = {
layers: PromptRegionLayer[];
brushSize: number;
promptLayerOpacity: number;
autoNegative: ParameterAutoNegative;
};
const initialRegionalPromptsState: RegionalPromptsState = {
@ -74,6 +76,7 @@ const initialRegionalPromptsState: RegionalPromptsState = {
brushSize: 40,
layers: [],
promptLayerOpacity: 0.5, // This currently doesn't work
autoNegative: 'off',
};
const isLine = (obj: LayerObject): obj is LineObject => obj.kind === 'line';
@ -229,6 +232,9 @@ export const regionalPromptsSlice = createSlice({
promptLayerOpacityChanged: (state, action: PayloadAction<number>) => {
state.promptLayerOpacity = action.payload;
},
autoNegativeChanged: (state, action: PayloadAction<ParameterAutoNegative>) => {
state.autoNegative = action.payload;
},
},
});
@ -277,6 +283,7 @@ export const {
layerBboxChanged,
promptLayerOpacityChanged,
allLayersDeleted,
autoNegativeChanged,
} = regionalPromptsSlice.actions;
export const selectRegionalPromptsSlice = (state: RootState) => state.regionalPrompts;