fix(ui): batch building after removing canvas files

This commit is contained in:
psychedelicious 2024-06-28 18:28:27 +10:00
parent 0c26d28278
commit ec6361e5cb
5 changed files with 44 additions and 71 deletions

View File

@ -19,7 +19,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
const model = state.canvasV2.params.model; const model = state.canvasV2.params.model;
const { prepend } = action.payload; const { prepend } = action.payload;
let graph; let g;
const manager = getNodeManager(); const manager = getNodeManager();
assert(model, 'No model found in state'); assert(model, 'No model found in state');
@ -29,14 +29,14 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
manager.getImageSourceImage({ bbox: state.canvasV2.bbox, preview: true }); manager.getImageSourceImage({ bbox: state.canvasV2.bbox, preview: true });
if (base === 'sdxl') { if (base === 'sdxl') {
graph = await buildSDXLGraph(state, manager); g = await buildSDXLGraph(state, manager);
} else if (base === 'sd-1' || base === 'sd-2') { } else if (base === 'sd-1' || base === 'sd-2') {
graph = await buildSD1Graph(state, manager); g = await buildSD1Graph(state, manager);
} else { } else {
assert(false, `No graph builders for base ${base}`); assert(false, `No graph builders for base ${base}`);
} }
const batchConfig = prepareLinearUIBatch(state, graph, prepend); const batchConfig = prepareLinearUIBatch(state, g, prepend);
const req = dispatch( const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, { queueApi.endpoints.enqueueBatch.initiate(batchConfig, {

View File

@ -1,14 +1,13 @@
import { NUMPY_RAND_MAX } from 'app/constants';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { generateSeeds } from 'common/util/generateSeeds'; import { generateSeeds } from 'common/util/generateSeeds';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { range } from 'lodash-es'; import { range } from 'lodash-es';
import type { components } from 'services/api/schema'; import type { components } from 'services/api/schema';
import type { Batch, BatchConfig, NonNullableGraph } from 'services/api/types'; import type { Batch, BatchConfig } from 'services/api/types';
import { getHasMetadata, removeMetadata } from './canvas/metadata'; import { NOISE, POSITIVE_CONDITIONING } from './constants';
import { CANVAS_COHERENCE_NOISE, METADATA, NOISE, POSITIVE_CONDITIONING } from './constants';
export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph, prepend: boolean): BatchConfig => { export const prepareLinearUIBatch = (state: RootState, g: Graph, prepend: boolean): BatchConfig => {
const { iterations, model, shouldRandomizeSeed, seed, shouldConcatPrompts } = state.canvasV2.params; const { iterations, model, shouldRandomizeSeed, seed, shouldConcatPrompts } = state.canvasV2.params;
const { prompts, seedBehaviour } = state.dynamicPrompts; const { prompts, seedBehaviour } = state.dynamicPrompts;
@ -23,7 +22,7 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
start: shouldRandomizeSeed ? undefined : seed, start: shouldRandomizeSeed ? undefined : seed,
}); });
if (graph.nodes[NOISE]) { if (g.hasNode(NOISE)) {
firstBatchDatumList.push({ firstBatchDatumList.push({
node_path: NOISE, node_path: NOISE,
field_name: 'seed', field_name: 'seed',
@ -32,22 +31,12 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
} }
// add to metadata // add to metadata
if (getHasMetadata(graph)) { g.removeMetadata(['seed']);
removeMetadata(graph, 'seed');
firstBatchDatumList.push({ firstBatchDatumList.push({
node_path: METADATA, node_path: g.getMetadataNode().id,
field_name: 'seed', field_name: 'seed',
items: seeds, items: seeds,
}); });
}
if (graph.nodes[CANVAS_COHERENCE_NOISE]) {
firstBatchDatumList.push({
node_path: CANVAS_COHERENCE_NOISE,
field_name: 'seed',
items: seeds.map((seed) => (seed + 1) % NUMPY_RAND_MAX),
});
}
} else { } else {
// seedBehaviour = SeedBehaviour.PerRun // seedBehaviour = SeedBehaviour.PerRun
const seeds = generateSeeds({ const seeds = generateSeeds({
@ -55,7 +44,7 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
start: shouldRandomizeSeed ? undefined : seed, start: shouldRandomizeSeed ? undefined : seed,
}); });
if (graph.nodes[NOISE]) { if (g.hasNode(NOISE)) {
secondBatchDatumList.push({ secondBatchDatumList.push({
node_path: NOISE, node_path: NOISE,
field_name: 'seed', field_name: 'seed',
@ -64,29 +53,19 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
} }
// add to metadata // add to metadata
if (getHasMetadata(graph)) { g.removeMetadata(['seed']);
removeMetadata(graph, 'seed');
secondBatchDatumList.push({ secondBatchDatumList.push({
node_path: METADATA, node_path: g.getMetadataNode().id,
field_name: 'seed', field_name: 'seed',
items: seeds, items: seeds,
}); });
}
if (graph.nodes[CANVAS_COHERENCE_NOISE]) {
secondBatchDatumList.push({
node_path: CANVAS_COHERENCE_NOISE,
field_name: 'seed',
items: seeds.map((seed) => (seed + 1) % NUMPY_RAND_MAX),
});
}
data.push(secondBatchDatumList); data.push(secondBatchDatumList);
} }
const extendedPrompts = seedBehaviour === 'PER_PROMPT' ? range(iterations).flatMap(() => prompts) : prompts; const extendedPrompts = seedBehaviour === 'PER_PROMPT' ? range(iterations).flatMap(() => prompts) : prompts;
// zipped batch of prompts // zipped batch of prompts
if (graph.nodes[POSITIVE_CONDITIONING]) { if (g.hasNode(POSITIVE_CONDITIONING)) {
firstBatchDatumList.push({ firstBatchDatumList.push({
node_path: POSITIVE_CONDITIONING, node_path: POSITIVE_CONDITIONING,
field_name: 'prompt', field_name: 'prompt',
@ -95,17 +74,15 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
} }
// add to metadata // add to metadata
if (getHasMetadata(graph)) { g.removeMetadata(['positive_prompt']);
removeMetadata(graph, 'positive_prompt');
firstBatchDatumList.push({ firstBatchDatumList.push({
node_path: METADATA, node_path: g.getMetadataNode().id,
field_name: 'positive_prompt', field_name: 'positive_prompt',
items: extendedPrompts, items: extendedPrompts,
}); });
}
if (shouldConcatPrompts && model?.base === 'sdxl') { if (shouldConcatPrompts && model?.base === 'sdxl') {
if (graph.nodes[POSITIVE_CONDITIONING]) { if (g.hasNode(POSITIVE_CONDITIONING)) {
firstBatchDatumList.push({ firstBatchDatumList.push({
node_path: POSITIVE_CONDITIONING, node_path: POSITIVE_CONDITIONING,
field_name: 'style', field_name: 'style',
@ -114,22 +91,20 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
} }
// add to metadata // add to metadata
if (getHasMetadata(graph)) { g.removeMetadata(['positive_style_prompt']);
removeMetadata(graph, 'positive_style_prompt');
firstBatchDatumList.push({ firstBatchDatumList.push({
node_path: METADATA, node_path: g.getMetadataNode().id,
field_name: 'positive_style_prompt', field_name: 'positive_style_prompt',
items: extendedPrompts, items: extendedPrompts,
}); });
} }
}
data.push(firstBatchDatumList); data.push(firstBatchDatumList);
const enqueueBatchArg: BatchConfig = { const enqueueBatchArg: BatchConfig = {
prepend, prepend,
batch: { batch: {
graph, graph: g.getGraph(),
runs: 1, runs: 1,
data, data,
}, },

View File

@ -1,6 +1,5 @@
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { getModelMetadataField } from 'features/nodes/util/graph/canvas/metadata';
import { import {
SDXL_REFINER_DENOISE_LATENTS, SDXL_REFINER_DENOISE_LATENTS,
SDXL_REFINER_MODEL_LOADER, SDXL_REFINER_MODEL_LOADER,
@ -8,7 +7,7 @@ import {
SDXL_REFINER_POSITIVE_CONDITIONING, SDXL_REFINER_POSITIVE_CONDITIONING,
SDXL_REFINER_SEAMLESS, SDXL_REFINER_SEAMLESS,
} from 'features/nodes/util/graph/constants'; } from 'features/nodes/util/graph/constants';
import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { Invocation } from 'services/api/types'; import type { Invocation } from 'services/api/types';
import { isRefinerMainModelModelConfig } from 'services/api/types'; import { isRefinerMainModelModelConfig } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
@ -89,7 +88,7 @@ export const addSDXLRefiner = async (
g.addEdge(refinerDenoise, 'latents', l2i, 'latents'); g.addEdge(refinerDenoise, 'latents', l2i, 'latents');
g.upsertMetadata({ g.upsertMetadata({
refiner_model: getModelMetadataField(modelConfig), refiner_model: Graph.getModelMetadataField(modelConfig),
refiner_positive_aesthetic_score: refinerPositiveAestheticScore, refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
refiner_negative_aesthetic_score: refinerNegativeAestheticScore, refiner_negative_aesthetic_score: refinerNegativeAestheticScore,
refiner_cfg_scale: refinerCFGScale, refiner_cfg_scale: refinerCFGScale,

View File

@ -26,7 +26,6 @@ import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless'; import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage'; import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker'; import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import type { GraphType } from 'features/nodes/util/graph/generation/Graph';
import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { getBoardField, getSizes } from 'features/nodes/util/graph/graphBuilderUtils'; import { getBoardField, getSizes } from 'features/nodes/util/graph/graphBuilderUtils';
import type { Invocation } from 'services/api/types'; import type { Invocation } from 'services/api/types';
@ -35,7 +34,7 @@ import { assert } from 'tsafe';
import { addRegions } from './addRegions'; import { addRegions } from './addRegions';
export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager): Promise<GraphType> => { export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager): Promise<Graph> => {
const generationMode = manager.getGenerationMode(); const generationMode = manager.getGenerationMode();
const { bbox, params } = state.canvasV2; const { bbox, params } = state.canvasV2;
@ -248,5 +247,5 @@ export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager)
}); });
g.setMetadataReceivingNode(canvasOutput); g.setMetadataReceivingNode(canvasOutput);
return g.getGraph(); return g;
}; };

View File

@ -27,13 +27,13 @@ import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToIm
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker'; import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { getBoardField, getSDXLStylePrompts, getSizes } from 'features/nodes/util/graph/graphBuilderUtils'; import { getBoardField, getSDXLStylePrompts, getSizes } from 'features/nodes/util/graph/graphBuilderUtils';
import type { Invocation, NonNullableGraph } from 'services/api/types'; import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types'; import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
import { addRegions } from './addRegions'; import { addRegions } from './addRegions';
export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager): Promise<NonNullableGraph> => { export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager): Promise<Graph> => {
const generationMode = manager.getGenerationMode(); const generationMode = manager.getGenerationMode();
const { bbox, params } = state.canvasV2; const { bbox, params } = state.canvasV2;
@ -246,5 +246,5 @@ export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager
}); });
g.setMetadataReceivingNode(canvasOutput); g.setMetadataReceivingNode(canvasOutput);
return g.getGraph(); return g;
}; };