feat(ui): draft of graph helper for regional prompts

This commit is contained in:
psychedelicious 2024-04-15 19:27:01 +10:00 committed by Kent Keirsey
parent 602a59066e
commit 05deeb68fa
9 changed files with 254 additions and 35 deletions

View File

@ -23,6 +23,7 @@ export type NodesState = {
nodeOpacity: number;
shouldSnapToGrid: boolean;
shouldColorEdges: boolean;
shouldShowEdgeLabels: boolean;
selectedNodes: string[];
selectedEdges: string[];
nodeExecutionStates: Record<string, NodeExecutionState>;

View File

@ -0,0 +1,151 @@
import { getStore } from 'app/store/nanostores/store';
import type { RootState } from 'app/store/store';
import {
NEGATIVE_CONDITIONING_COLLECT,
POSITIVE_CONDITIONING,
POSITIVE_CONDITIONING_COLLECT,
PROMPT_REGION_COND_PREFIX,
PROMPT_REGION_MASK_PREFIX,
} from 'features/nodes/util/graph/constants';
import { getRegionalPromptLayerBlobs } from 'features/regionalPrompts/util/getLayerBlobs';
import { size } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import type { CollectInvocation, Edge, NonNullableGraph, S } from 'services/api/types';
import { assert } from 'tsafe';
/**
 * Adds regional-prompt conditioning to a generation graph.
 *
 * For each visible `promptRegionLayer`:
 *  - rasterizes the layer to a PNG mask and uploads it as an intermediate image,
 *  - adds an `alpha_mask_to_tensor` node for the mask,
 *  - adds a (SDXL or SD1.x) compel node carrying the layer's prompt,
 *  - wires mask -> conditioning -> the positive conditioning collector.
 *
 * The denoise node's original positive/negative conditioning inputs are re-routed
 * through `collect` nodes so the regional conditionings can be appended alongside
 * the global one. The global positive prompt is blanked out.
 *
 * @param state The Redux root state (layers, model base).
 * @param graph The graph to mutate in place.
 * @param denoiseNodeId ID of the denoise node whose conditioning inputs are re-routed.
 */
export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNullableGraph, denoiseNodeId: string) => {
  const { dispatch } = getStore();
  const isSDXL = state.generation.model?.base === 'sdxl';
  const layers = state.regionalPrompts.layers
    .filter((l) => l.kind === 'promptRegionLayer') // We only want the prompt region layers
    .filter((l) => l.isVisible); // Only visible layers are rendered on the canvas
  const layerIds = layers.map((l) => l.id); // We only need the IDs
  const blobs = await getRegionalPromptLayerBlobs(layerIds);
  console.log('blobs', blobs, 'layerIds', layerIds);
  assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');

  // Set up the conditioning collectors
  const posCondCollectNode: CollectInvocation = {
    id: POSITIVE_CONDITIONING_COLLECT,
    type: 'collect',
  };
  const negCondCollectNode: CollectInvocation = {
    id: NEGATIVE_CONDITIONING_COLLECT,
    type: 'collect',
  };
  graph.nodes[POSITIVE_CONDITIONING_COLLECT] = posCondCollectNode;
  graph.nodes[NEGATIVE_CONDITIONING_COLLECT] = negCondCollectNode;

  // Re-route the denoise node's OG conditioning inputs to the collect nodes
  const newEdges: Edge[] = [];
  for (const edge of graph.edges) {
    if (edge.destination.node_id === denoiseNodeId && edge.destination.field === 'positive_conditioning') {
      newEdges.push({
        source: edge.source,
        destination: {
          node_id: POSITIVE_CONDITIONING_COLLECT,
          field: 'item',
        },
      });
    } else if (edge.destination.node_id === denoiseNodeId && edge.destination.field === 'negative_conditioning') {
      newEdges.push({
        source: edge.source,
        destination: {
          node_id: NEGATIVE_CONDITIONING_COLLECT,
          field: 'item',
        },
      });
    } else {
      newEdges.push(edge);
    }
  }
  graph.edges = newEdges;

  // Connect collectors to the denoise nodes - must happen _after_ rerouting else you get cycles
  graph.edges.push({
    source: {
      node_id: POSITIVE_CONDITIONING_COLLECT,
      field: 'collection',
    },
    destination: {
      node_id: denoiseNodeId,
      field: 'positive_conditioning',
    },
  });
  graph.edges.push({
    source: {
      node_id: NEGATIVE_CONDITIONING_COLLECT,
      field: 'collection',
    },
    destination: {
      node_id: denoiseNodeId,
      field: 'negative_conditioning',
    },
  });

  // Remove the global prompt
  (graph.nodes[POSITIVE_CONDITIONING] as S['SDXLCompelPromptInvocation'] | S['CompelInvocation']).prompt = '';

  // Upload the blobs to the backend, add each to graph
  for (const [layerId, blob] of Object.entries(blobs)) {
    const layer = layers.find((l) => l.id === layerId);
    assert(layer, `Layer ${layerId} not found`);
    const id = `${PROMPT_REGION_MASK_PREFIX}_${layerId}`;
    const file = new File([blob], `${id}.png`, { type: 'image/png' });
    const req = dispatch(
      imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
    );
    // Await the upload result before releasing the subscription. The previous
    // order (reset() before unwrap()) carried a TODO noting it would raise.
    const { image_name } = await req.unwrap();
    req.reset();
    const alphaMaskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
      id,
      type: 'alpha_mask_to_tensor',
      image: {
        image_name,
      },
    };
    graph.nodes[id] = alphaMaskToTensorNode;

    // Create the conditioning nodes for each region - different handling for SDXL
    // TODO: negative prompt
    const regionalCondNodeId = `${PROMPT_REGION_COND_PREFIX}_${layerId}`;
    if (isSDXL) {
      graph.nodes[regionalCondNodeId] = {
        type: 'sdxl_compel_prompt',
        id: regionalCondNodeId,
        prompt: layer.prompt,
      };
    } else {
      graph.nodes[regionalCondNodeId] = {
        type: 'compel',
        id: regionalCondNodeId,
        prompt: layer.prompt,
      };
    }
    graph.edges.push({
      source: { node_id: id, field: 'mask' },
      destination: { node_id: regionalCondNodeId, field: 'mask' },
    });
    graph.edges.push({
      source: { node_id: regionalCondNodeId, field: 'conditioning' },
      destination: { node_id: posCondCollectNode.id, field: 'item' },
    });
    // Mirror every non-prompt input of the global positive conditioning node onto
    // this regional conditioning node (presumably CLIP et al. — TODO confirm).
    // Iterate a snapshot: we push into graph.edges inside the loop, and a live
    // for..of over the same array would also visit the freshly-added edges.
    for (const edge of [...graph.edges]) {
      if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
        graph.edges.push({
          source: edge.source,
          destination: { node_id: regionalCondNodeId, field: edge.destination.field },
        });
      }
    }
  }
};

View File

@ -1,11 +1,11 @@
import { NUMPY_RAND_MAX } from 'app/constants';
import type { RootState } from 'app/store/store';
import { generateSeeds } from 'common/util/generateSeeds';
import { range } from 'lodash-es';
import { range, some } from 'lodash-es';
import type { components } from 'services/api/schema';
import type { Batch, BatchConfig, NonNullableGraph } from 'services/api/types';
import { CANVAS_COHERENCE_NOISE, METADATA, NOISE, POSITIVE_CONDITIONING } from './constants';
import { CANVAS_COHERENCE_NOISE, METADATA, NOISE, POSITIVE_CONDITIONING, PROMPT_REGION_MASK_PREFIX } from './constants';
import { getHasMetadata, removeMetadata } from './metadata';
export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph, prepend: boolean): BatchConfig => {
@ -86,23 +86,27 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
const extendedPrompts = seedBehaviour === 'PER_PROMPT' ? range(iterations).flatMap(() => prompts) : prompts;
// zipped batch of prompts
if (graph.nodes[POSITIVE_CONDITIONING]) {
firstBatchDatumList.push({
node_path: POSITIVE_CONDITIONING,
field_name: 'prompt',
items: extendedPrompts,
});
}
const hasRegionalPrompts = some(graph.nodes, (n) => n.id.startsWith(PROMPT_REGION_MASK_PREFIX));
// add to metadata
if (getHasMetadata(graph)) {
removeMetadata(graph, 'positive_prompt');
firstBatchDatumList.push({
node_path: METADATA,
field_name: 'positive_prompt',
items: extendedPrompts,
});
if (!hasRegionalPrompts) {
// zipped batch of prompts
if (graph.nodes[POSITIVE_CONDITIONING]) {
firstBatchDatumList.push({
node_path: POSITIVE_CONDITIONING,
field_name: 'prompt',
items: extendedPrompts,
});
}
// add to metadata
if (getHasMetadata(graph)) {
removeMetadata(graph, 'positive_prompt');
firstBatchDatumList.push({
node_path: METADATA,
field_name: 'positive_prompt',
items: extendedPrompts,
});
}
}
if (shouldConcatSDXLStylePrompt && model?.base === 'sdxl') {

View File

@ -1,6 +1,7 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addRegionalPromptsToGraph } from 'features/nodes/util/graph/addRegionalPromptsToGraph';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
@ -273,6 +274,8 @@ export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addRegionalPromptsToGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!

View File

@ -46,6 +46,10 @@ export const SDXL_REFINER_DENOISE_LATENTS = 'sdxl_refiner_denoise_latents';
export const SDXL_REFINER_INPAINT_CREATE_MASK = 'refiner_inpaint_create_mask';
export const SEAMLESS = 'seamless';
export const SDXL_REFINER_SEAMLESS = 'refiner_seamless';
// Regional prompts: per-layer node IDs are built as `${PREFIX}_${layerId}`
export const PROMPT_REGION_MASK_PREFIX = 'prompt_region_mask';
export const PROMPT_REGION_COND_PREFIX = 'prompt_region_cond';
// Collector nodes that gather global + regional conditioning for the denoise node
export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect';
export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect';
// friendly graph ids
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';

View File

@ -18,13 +18,13 @@ const selectLayerIdsReversed = createMemoizedSelector(selectRegionalPromptsSlice
);
const debugBlobs = () => {
getRegionalPromptLayerBlobs(true);
getRegionalPromptLayerBlobs(undefined, true);
};
export const RegionalPromptsEditor = memo(() => {
const layerIdsReversed = useAppSelector(selectLayerIdsReversed);
return (
<Flex gap={4}>
<Flex gap={4} w="full" h="full">
<Flex flexDir="column" gap={4} flexShrink={0}>
<Button onClick={debugBlobs}>DEBUG LAYERS</Button>
<AddLayerButton />

View File

@ -1,4 +1,3 @@
import { getStore } from 'app/store/nanostores/store';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
import { selectPromptLayerObjectGroup } from 'features/regionalPrompts/components/LayerComponent';
@ -7,16 +6,24 @@ import Konva from 'konva';
import { assert } from 'tsafe';
/**
* Get the blobs of all regional prompt layers.
* Get the blobs of all regional prompt layers. Only visible layers are returned.
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
* @param preview Whether to open a new tab displaying each layer.
* @returns A map of layer IDs to blobs.
*/
export const getRegionalPromptLayerBlobs = async (preview: boolean = false): Promise<Record<string, Blob>> => {
const state = getStore().getState();
export const getRegionalPromptLayerBlobs = async (
layerIds?: string[],
preview: boolean = false
): Promise<Record<string, Blob>> => {
const stage = getStage();
// This automatically omits layers that are not rendered. Rendering is controlled by the layer's `isVisible` flag in redux.
const regionalPromptLayers = stage.getLayers().filter((l) => l.name() === REGIONAL_PROMPT_LAYER_NAME);
const regionalPromptLayers = stage.getLayers().filter((l) => {
console.log(l.name(), l.id())
const isRegionalPromptLayer = l.name() === REGIONAL_PROMPT_LAYER_NAME;
const isRequestedLayerId = layerIds ? layerIds.includes(l.id()) : true;
return isRegionalPromptLayer && isRequestedLayerId;
});
// We need to reconstruct each layer to only output the desired data. This logic mirrors the logic in
// `getKonvaLayerBbox()` in `invokeai/frontend/web/src/features/regionalPrompts/util/bbox.ts`
@ -48,14 +55,13 @@ export const getRegionalPromptLayerBlobs = async (preview: boolean = false): Pro
},
});
});
blobs[layer.id()] = blob;
if (preview) {
const base64 = await blobToDataURL(blob);
const prompt = state.regionalPrompts.layers.find((l) => l.id === layer.id())?.prompt;
openBase64ImageInTab([{ base64, caption: prompt ?? '' }]);
openBase64ImageInTab([{ base64, caption: layer.id() }]);
}
layerClone.destroy();
blobs[layer.id()] = blob;
}
return blobs;

View File

@ -1,13 +1,26 @@
import { Box, Flex } from '@invoke-ai/ui-library';
import { Box, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import CurrentImageDisplay from 'features/gallery/components/CurrentImage/CurrentImageDisplay';
import { RegionalPromptsEditor } from 'features/regionalPrompts/components/RegionalPromptsEditor';
import { memo } from 'react';
// Text-to-image workspace: tabbed layout switching between the generated-image
// viewer and the regional prompts editor. `isLazy` defers mounting each panel
// until its tab is first selected.
const TextToImageTab = () => {
  return (
    <Box position="relative" w="full" h="full" p={2} borderRadius="base">
      <Tabs variant="line" isLazy={true} display="flex" flexDir="column" w="full" h="full">
        <TabList>
          <Tab>Viewer</Tab>
          <Tab>Regional Prompts</Tab>
        </TabList>
        <TabPanels w="full" h="full">
          <TabPanel>
            <CurrentImageDisplay />
          </TabPanel>
          <TabPanel>
            <RegionalPromptsEditor />
          </TabPanel>
        </TabPanels>
      </Tabs>
    </Box>
  );
};

File diff suppressed because one or more lines are too long