mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): draft of graph helper for regional prompts
This commit is contained in:
parent
602a59066e
commit
05deeb68fa
@ -23,6 +23,7 @@ export type NodesState = {
|
||||
nodeOpacity: number;
|
||||
shouldSnapToGrid: boolean;
|
||||
shouldColorEdges: boolean;
|
||||
shouldShowEdgeLabels: boolean;
|
||||
selectedNodes: string[];
|
||||
selectedEdges: string[];
|
||||
nodeExecutionStates: Record<string, NodeExecutionState>;
|
||||
|
@ -0,0 +1,151 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import {
|
||||
NEGATIVE_CONDITIONING_COLLECT,
|
||||
POSITIVE_CONDITIONING,
|
||||
POSITIVE_CONDITIONING_COLLECT,
|
||||
PROMPT_REGION_COND_PREFIX,
|
||||
PROMPT_REGION_MASK_PREFIX,
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import { getRegionalPromptLayerBlobs } from 'features/regionalPrompts/util/getLayerBlobs';
|
||||
import { size } from 'lodash-es';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { CollectInvocation, Edge, NonNullableGraph, S } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNullableGraph, denoiseNodeId: string) => {
|
||||
const { dispatch } = getStore();
|
||||
const isSDXL = state.generation.model?.base === 'sdxl';
|
||||
const layers = state.regionalPrompts.layers
|
||||
.filter((l) => l.kind === 'promptRegionLayer') // We only want the prompt region layers
|
||||
.filter((l) => l.isVisible); // Only visible layers are rendered on the canvas
|
||||
|
||||
const layerIds = layers.map((l) => l.id); // We only need the IDs
|
||||
|
||||
const blobs = await getRegionalPromptLayerBlobs(layerIds);
|
||||
|
||||
console.log('blobs', blobs, 'layerIds', layerIds);
|
||||
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
|
||||
|
||||
// Set up the conditioning collectors
|
||||
const posCondCollectNode: CollectInvocation = {
|
||||
id: POSITIVE_CONDITIONING_COLLECT,
|
||||
type: 'collect',
|
||||
};
|
||||
const negCondCollectNode: CollectInvocation = {
|
||||
id: NEGATIVE_CONDITIONING_COLLECT,
|
||||
type: 'collect',
|
||||
};
|
||||
graph.nodes[POSITIVE_CONDITIONING_COLLECT] = posCondCollectNode;
|
||||
graph.nodes[NEGATIVE_CONDITIONING_COLLECT] = negCondCollectNode;
|
||||
|
||||
// Re-route the denoise node's OG conditioning inputs to the collect nodes
|
||||
const newEdges: Edge[] = [];
|
||||
for (const edge of graph.edges) {
|
||||
if (edge.destination.node_id === denoiseNodeId && edge.destination.field === 'positive_conditioning') {
|
||||
newEdges.push({
|
||||
source: edge.source,
|
||||
destination: {
|
||||
node_id: POSITIVE_CONDITIONING_COLLECT,
|
||||
field: 'item',
|
||||
},
|
||||
});
|
||||
} else if (edge.destination.node_id === denoiseNodeId && edge.destination.field === 'negative_conditioning') {
|
||||
newEdges.push({
|
||||
source: edge.source,
|
||||
destination: {
|
||||
node_id: NEGATIVE_CONDITIONING_COLLECT,
|
||||
field: 'item',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
newEdges.push(edge);
|
||||
}
|
||||
}
|
||||
graph.edges = newEdges;
|
||||
|
||||
// Connect collectors to the denoise nodes - must happen _after_ rerouting else you get cycles
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: POSITIVE_CONDITIONING_COLLECT,
|
||||
field: 'collection',
|
||||
},
|
||||
destination: {
|
||||
node_id: denoiseNodeId,
|
||||
field: 'positive_conditioning',
|
||||
},
|
||||
});
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: NEGATIVE_CONDITIONING_COLLECT,
|
||||
field: 'collection',
|
||||
},
|
||||
destination: {
|
||||
node_id: denoiseNodeId,
|
||||
field: 'negative_conditioning',
|
||||
},
|
||||
});
|
||||
|
||||
// Remove the global prompt
|
||||
(graph.nodes[POSITIVE_CONDITIONING] as S['SDXLCompelPromptInvocation'] | S['CompelInvocation']).prompt = '';
|
||||
|
||||
// Upload the blobs to the backend, add each to graph
|
||||
for (const [layerId, blob] of Object.entries(blobs)) {
|
||||
const layer = layers.find((l) => l.id === layerId);
|
||||
assert(layer, `Layer ${layerId} not found`);
|
||||
|
||||
const id = `${PROMPT_REGION_MASK_PREFIX}_${layerId}`;
|
||||
const file = new File([blob], `${id}.png`, { type: 'image/png' });
|
||||
const req = dispatch(
|
||||
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
|
||||
);
|
||||
req.reset();
|
||||
|
||||
// TODO: this will raise an error
|
||||
const { image_name } = await req.unwrap();
|
||||
|
||||
const alphaMaskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
|
||||
id,
|
||||
type: 'alpha_mask_to_tensor',
|
||||
image: {
|
||||
image_name,
|
||||
},
|
||||
};
|
||||
graph.nodes[id] = alphaMaskToTensorNode;
|
||||
|
||||
// Create the conditioning nodes for each region - different handling for SDXL
|
||||
|
||||
// TODO: negative prompt
|
||||
const regionalCondNodeId = `${PROMPT_REGION_COND_PREFIX}_${layerId}`;
|
||||
|
||||
if (isSDXL) {
|
||||
graph.nodes[regionalCondNodeId] = {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: regionalCondNodeId,
|
||||
prompt: layer.prompt,
|
||||
};
|
||||
} else {
|
||||
graph.nodes[regionalCondNodeId] = {
|
||||
type: 'compel',
|
||||
id: regionalCondNodeId,
|
||||
prompt: layer.prompt,
|
||||
};
|
||||
}
|
||||
graph.edges.push({
|
||||
source: { node_id: id, field: 'mask' },
|
||||
destination: { node_id: regionalCondNodeId, field: 'mask' },
|
||||
});
|
||||
graph.edges.push({
|
||||
source: { node_id: regionalCondNodeId, field: 'conditioning' },
|
||||
destination: { node_id: posCondCollectNode.id, field: 'item' },
|
||||
});
|
||||
for (const edge of graph.edges) {
|
||||
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
|
||||
graph.edges.push({
|
||||
source: edge.source,
|
||||
destination: { node_id: regionalCondNodeId, field: edge.destination.field },
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
@ -1,11 +1,11 @@
|
||||
import { NUMPY_RAND_MAX } from 'app/constants';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { generateSeeds } from 'common/util/generateSeeds';
|
||||
import { range } from 'lodash-es';
|
||||
import { range, some } from 'lodash-es';
|
||||
import type { components } from 'services/api/schema';
|
||||
import type { Batch, BatchConfig, NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { CANVAS_COHERENCE_NOISE, METADATA, NOISE, POSITIVE_CONDITIONING } from './constants';
|
||||
import { CANVAS_COHERENCE_NOISE, METADATA, NOISE, POSITIVE_CONDITIONING, PROMPT_REGION_MASK_PREFIX } from './constants';
|
||||
import { getHasMetadata, removeMetadata } from './metadata';
|
||||
|
||||
export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph, prepend: boolean): BatchConfig => {
|
||||
@ -86,23 +86,27 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
|
||||
|
||||
const extendedPrompts = seedBehaviour === 'PER_PROMPT' ? range(iterations).flatMap(() => prompts) : prompts;
|
||||
|
||||
// zipped batch of prompts
|
||||
if (graph.nodes[POSITIVE_CONDITIONING]) {
|
||||
firstBatchDatumList.push({
|
||||
node_path: POSITIVE_CONDITIONING,
|
||||
field_name: 'prompt',
|
||||
items: extendedPrompts,
|
||||
});
|
||||
}
|
||||
const hasRegionalPrompts = some(graph.nodes, (n) => n.id.startsWith(PROMPT_REGION_MASK_PREFIX));
|
||||
|
||||
// add to metadata
|
||||
if (getHasMetadata(graph)) {
|
||||
removeMetadata(graph, 'positive_prompt');
|
||||
firstBatchDatumList.push({
|
||||
node_path: METADATA,
|
||||
field_name: 'positive_prompt',
|
||||
items: extendedPrompts,
|
||||
});
|
||||
if (!hasRegionalPrompts) {
|
||||
// zipped batch of prompts
|
||||
if (graph.nodes[POSITIVE_CONDITIONING]) {
|
||||
firstBatchDatumList.push({
|
||||
node_path: POSITIVE_CONDITIONING,
|
||||
field_name: 'prompt',
|
||||
items: extendedPrompts,
|
||||
});
|
||||
}
|
||||
|
||||
// add to metadata
|
||||
if (getHasMetadata(graph)) {
|
||||
removeMetadata(graph, 'positive_prompt');
|
||||
firstBatchDatumList.push({
|
||||
node_path: METADATA,
|
||||
field_name: 'positive_prompt',
|
||||
items: extendedPrompts,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (shouldConcatSDXLStylePrompt && model?.base === 'sdxl') {
|
||||
|
@ -1,6 +1,7 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { addRegionalPromptsToGraph } from 'features/nodes/util/graph/addRegionalPromptsToGraph';
|
||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
@ -273,6 +274,8 @@ export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise
|
||||
|
||||
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
await addRegionalPromptsToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
// must add before watermarker!
|
||||
|
@ -46,6 +46,10 @@ export const SDXL_REFINER_DENOISE_LATENTS = 'sdxl_refiner_denoise_latents';
|
||||
export const SDXL_REFINER_INPAINT_CREATE_MASK = 'refiner_inpaint_create_mask';
|
||||
export const SEAMLESS = 'seamless';
|
||||
export const SDXL_REFINER_SEAMLESS = 'refiner_seamless';
|
||||
export const PROMPT_REGION_MASK_PREFIX = 'prompt_region_mask';
|
||||
export const PROMPT_REGION_COND_PREFIX = 'prompt_region_cond';
|
||||
export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect';
|
||||
export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect';
|
||||
|
||||
// friendly graph ids
|
||||
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';
|
||||
|
@ -18,13 +18,13 @@ const selectLayerIdsReversed = createMemoizedSelector(selectRegionalPromptsSlice
|
||||
);
|
||||
|
||||
const debugBlobs = () => {
|
||||
getRegionalPromptLayerBlobs(true);
|
||||
getRegionalPromptLayerBlobs(undefined, true);
|
||||
};
|
||||
|
||||
export const RegionalPromptsEditor = memo(() => {
|
||||
const layerIdsReversed = useAppSelector(selectLayerIdsReversed);
|
||||
return (
|
||||
<Flex gap={4}>
|
||||
<Flex gap={4} w="full" h="full">
|
||||
<Flex flexDir="column" gap={4} flexShrink={0}>
|
||||
<Button onClick={debugBlobs}>DEBUG LAYERS</Button>
|
||||
<AddLayerButton />
|
||||
|
@ -1,4 +1,3 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
|
||||
import { selectPromptLayerObjectGroup } from 'features/regionalPrompts/components/LayerComponent';
|
||||
@ -7,16 +6,24 @@ import Konva from 'konva';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
/**
|
||||
* Get the blobs of all regional prompt layers.
|
||||
* Get the blobs of all regional prompt layers. Only visible layers are returned.
|
||||
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
|
||||
* @param preview Whether to open a new tab displaying each layer.
|
||||
* @returns A map of layer IDs to blobs.
|
||||
*/
|
||||
export const getRegionalPromptLayerBlobs = async (preview: boolean = false): Promise<Record<string, Blob>> => {
|
||||
const state = getStore().getState();
|
||||
export const getRegionalPromptLayerBlobs = async (
|
||||
layerIds?: string[],
|
||||
preview: boolean = false
|
||||
): Promise<Record<string, Blob>> => {
|
||||
const stage = getStage();
|
||||
|
||||
// This automatically omits layers that are not rendered. Rendering is controlled by the layer's `isVisible` flag in redux.
|
||||
const regionalPromptLayers = stage.getLayers().filter((l) => l.name() === REGIONAL_PROMPT_LAYER_NAME);
|
||||
const regionalPromptLayers = stage.getLayers().filter((l) => {
|
||||
console.log(l.name(), l.id())
|
||||
const isRegionalPromptLayer = l.name() === REGIONAL_PROMPT_LAYER_NAME;
|
||||
const isRequestedLayerId = layerIds ? layerIds.includes(l.id()) : true;
|
||||
return isRegionalPromptLayer && isRequestedLayerId;
|
||||
});
|
||||
|
||||
// We need to reconstruct each layer to only output the desired data. This logic mirrors the logic in
|
||||
// `getKonvaLayerBbox()` in `invokeai/frontend/web/src/features/regionalPrompts/util/bbox.ts`
|
||||
@ -48,14 +55,13 @@ export const getRegionalPromptLayerBlobs = async (preview: boolean = false): Pro
|
||||
},
|
||||
});
|
||||
});
|
||||
blobs[layer.id()] = blob;
|
||||
|
||||
if (preview) {
|
||||
const base64 = await blobToDataURL(blob);
|
||||
const prompt = state.regionalPrompts.layers.find((l) => l.id === layer.id())?.prompt;
|
||||
openBase64ImageInTab([{ base64, caption: prompt ?? '' }]);
|
||||
openBase64ImageInTab([{ base64, caption: layer.id() }]);
|
||||
}
|
||||
layerClone.destroy();
|
||||
blobs[layer.id()] = blob;
|
||||
}
|
||||
|
||||
return blobs;
|
||||
|
@ -1,13 +1,26 @@
|
||||
import { Box, Flex } from '@invoke-ai/ui-library';
|
||||
import { Box, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
||||
import CurrentImageDisplay from 'features/gallery/components/CurrentImage/CurrentImageDisplay';
|
||||
import { RegionalPromptsEditor } from 'features/regionalPrompts/components/RegionalPromptsEditor';
|
||||
import { memo } from 'react';
|
||||
|
||||
const TextToImageTab = () => {
|
||||
return (
|
||||
<Box layerStyle="first" position="relative" w="full" h="full" p={2} borderRadius="base">
|
||||
<Flex w="full" h="full">
|
||||
<CurrentImageDisplay />
|
||||
</Flex>
|
||||
<Box position="relative" w="full" h="full" p={2} borderRadius="base">
|
||||
<Tabs variant="line" isLazy={true} display="flex" flexDir="column" w="full" h="full">
|
||||
<TabList>
|
||||
<Tab>Viewer</Tab>
|
||||
<Tab>Regional Prompts</Tab>
|
||||
</TabList>
|
||||
|
||||
<TabPanels w="full" h="full">
|
||||
<TabPanel>
|
||||
<CurrentImageDisplay />
|
||||
</TabPanel>
|
||||
<TabPanel>
|
||||
<RegionalPromptsEditor />
|
||||
</TabPanel>
|
||||
</TabPanels>
|
||||
</Tabs>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user