feat(ui): draft of graph helper for regional prompts

This commit is contained in:
psychedelicious
2024-04-15 19:27:01 +10:00
committed by Kent Keirsey
parent 602a59066e
commit 05deeb68fa
9 changed files with 254 additions and 35 deletions

View File

@ -0,0 +1,151 @@
import { getStore } from 'app/store/nanostores/store';
import type { RootState } from 'app/store/store';
import {
NEGATIVE_CONDITIONING_COLLECT,
POSITIVE_CONDITIONING,
POSITIVE_CONDITIONING_COLLECT,
PROMPT_REGION_COND_PREFIX,
PROMPT_REGION_MASK_PREFIX,
} from 'features/nodes/util/graph/constants';
import { getRegionalPromptLayerBlobs } from 'features/regionalPrompts/util/getLayerBlobs';
import { size } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import type { CollectInvocation, Edge, NonNullableGraph, S } from 'services/api/types';
import { assert } from 'tsafe';
export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNullableGraph, denoiseNodeId: string) => {
const { dispatch } = getStore();
const isSDXL = state.generation.model?.base === 'sdxl';
const layers = state.regionalPrompts.layers
.filter((l) => l.kind === 'promptRegionLayer') // We only want the prompt region layers
.filter((l) => l.isVisible); // Only visible layers are rendered on the canvas
const layerIds = layers.map((l) => l.id); // We only need the IDs
const blobs = await getRegionalPromptLayerBlobs(layerIds);
console.log('blobs', blobs, 'layerIds', layerIds);
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
// Set up the conditioning collectors
const posCondCollectNode: CollectInvocation = {
id: POSITIVE_CONDITIONING_COLLECT,
type: 'collect',
};
const negCondCollectNode: CollectInvocation = {
id: NEGATIVE_CONDITIONING_COLLECT,
type: 'collect',
};
graph.nodes[POSITIVE_CONDITIONING_COLLECT] = posCondCollectNode;
graph.nodes[NEGATIVE_CONDITIONING_COLLECT] = negCondCollectNode;
// Re-route the denoise node's OG conditioning inputs to the collect nodes
const newEdges: Edge[] = [];
for (const edge of graph.edges) {
if (edge.destination.node_id === denoiseNodeId && edge.destination.field === 'positive_conditioning') {
newEdges.push({
source: edge.source,
destination: {
node_id: POSITIVE_CONDITIONING_COLLECT,
field: 'item',
},
});
} else if (edge.destination.node_id === denoiseNodeId && edge.destination.field === 'negative_conditioning') {
newEdges.push({
source: edge.source,
destination: {
node_id: NEGATIVE_CONDITIONING_COLLECT,
field: 'item',
},
});
} else {
newEdges.push(edge);
}
}
graph.edges = newEdges;
// Connect collectors to the denoise nodes - must happen _after_ rerouting else you get cycles
graph.edges.push({
source: {
node_id: POSITIVE_CONDITIONING_COLLECT,
field: 'collection',
},
destination: {
node_id: denoiseNodeId,
field: 'positive_conditioning',
},
});
graph.edges.push({
source: {
node_id: NEGATIVE_CONDITIONING_COLLECT,
field: 'collection',
},
destination: {
node_id: denoiseNodeId,
field: 'negative_conditioning',
},
});
// Remove the global prompt
(graph.nodes[POSITIVE_CONDITIONING] as S['SDXLCompelPromptInvocation'] | S['CompelInvocation']).prompt = '';
// Upload the blobs to the backend, add each to graph
for (const [layerId, blob] of Object.entries(blobs)) {
const layer = layers.find((l) => l.id === layerId);
assert(layer, `Layer ${layerId} not found`);
const id = `${PROMPT_REGION_MASK_PREFIX}_${layerId}`;
const file = new File([blob], `${id}.png`, { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
);
req.reset();
// TODO: this will raise an error
const { image_name } = await req.unwrap();
const alphaMaskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
id,
type: 'alpha_mask_to_tensor',
image: {
image_name,
},
};
graph.nodes[id] = alphaMaskToTensorNode;
// Create the conditioning nodes for each region - different handling for SDXL
// TODO: negative prompt
const regionalCondNodeId = `${PROMPT_REGION_COND_PREFIX}_${layerId}`;
if (isSDXL) {
graph.nodes[regionalCondNodeId] = {
type: 'sdxl_compel_prompt',
id: regionalCondNodeId,
prompt: layer.prompt,
};
} else {
graph.nodes[regionalCondNodeId] = {
type: 'compel',
id: regionalCondNodeId,
prompt: layer.prompt,
};
}
graph.edges.push({
source: { node_id: id, field: 'mask' },
destination: { node_id: regionalCondNodeId, field: 'mask' },
});
graph.edges.push({
source: { node_id: regionalCondNodeId, field: 'conditioning' },
destination: { node_id: posCondCollectNode.id, field: 'item' },
});
for (const edge of graph.edges) {
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
graph.edges.push({
source: edge.source,
destination: { node_id: regionalCondNodeId, field: edge.destination.field },
});
}
}
}
};

View File

@ -1,11 +1,11 @@
import { NUMPY_RAND_MAX } from 'app/constants';
import type { RootState } from 'app/store/store';
import { generateSeeds } from 'common/util/generateSeeds';
import { range } from 'lodash-es';
import { range, some } from 'lodash-es';
import type { components } from 'services/api/schema';
import type { Batch, BatchConfig, NonNullableGraph } from 'services/api/types';
import { CANVAS_COHERENCE_NOISE, METADATA, NOISE, POSITIVE_CONDITIONING } from './constants';
import { CANVAS_COHERENCE_NOISE, METADATA, NOISE, POSITIVE_CONDITIONING, PROMPT_REGION_MASK_PREFIX } from './constants';
import { getHasMetadata, removeMetadata } from './metadata';
export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph, prepend: boolean): BatchConfig => {
@ -86,23 +86,27 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
const extendedPrompts = seedBehaviour === 'PER_PROMPT' ? range(iterations).flatMap(() => prompts) : prompts;
// zipped batch of prompts
if (graph.nodes[POSITIVE_CONDITIONING]) {
firstBatchDatumList.push({
node_path: POSITIVE_CONDITIONING,
field_name: 'prompt',
items: extendedPrompts,
});
}
const hasRegionalPrompts = some(graph.nodes, (n) => n.id.startsWith(PROMPT_REGION_MASK_PREFIX));
// add to metadata
if (getHasMetadata(graph)) {
removeMetadata(graph, 'positive_prompt');
firstBatchDatumList.push({
node_path: METADATA,
field_name: 'positive_prompt',
items: extendedPrompts,
});
if (!hasRegionalPrompts) {
// zipped batch of prompts
if (graph.nodes[POSITIVE_CONDITIONING]) {
firstBatchDatumList.push({
node_path: POSITIVE_CONDITIONING,
field_name: 'prompt',
items: extendedPrompts,
});
}
// add to metadata
if (getHasMetadata(graph)) {
removeMetadata(graph, 'positive_prompt');
firstBatchDatumList.push({
node_path: METADATA,
field_name: 'positive_prompt',
items: extendedPrompts,
});
}
}
if (shouldConcatSDXLStylePrompt && model?.base === 'sdxl') {

View File

@ -1,6 +1,7 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addRegionalPromptsToGraph } from 'features/nodes/util/graph/addRegionalPromptsToGraph';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
@ -273,6 +274,8 @@ export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addRegionalPromptsToGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!

View File

@ -46,6 +46,10 @@ export const SDXL_REFINER_DENOISE_LATENTS = 'sdxl_refiner_denoise_latents';
export const SDXL_REFINER_INPAINT_CREATE_MASK = 'refiner_inpaint_create_mask';
export const SEAMLESS = 'seamless';
export const SDXL_REFINER_SEAMLESS = 'refiner_seamless';
export const PROMPT_REGION_MASK_PREFIX = 'prompt_region_mask';
export const PROMPT_REGION_COND_PREFIX = 'prompt_region_cond';
export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect';
export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect';
// friendly graph ids
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';