From 7d8ece45bb40c646537132c2e2a8dbbf2d4417f1 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Fri, 28 Jun 2024 18:28:27 +1000
Subject: [PATCH] fix(ui): batch building after removing canvas files

---
 .../listeners/enqueueRequestedLinear.ts       |  8 +-
 .../util/graph/buildLinearBatchConfig.ts      | 91 +++++++------------
 .../util/graph/generation/addSDXLRefiner.ts   |  5 +-
 .../util/graph/generation/buildSD1Graph.ts    |  5 +-
 .../util/graph/generation/buildSDXLGraph.ts   |  6 +-
 5 files changed, 44 insertions(+), 71 deletions(-)

diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts
index 8ded74a06d..1a476f889f 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts
@@ -19,7 +19,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
       const model = state.canvasV2.params.model;
       const { prepend } = action.payload;
 
-      let graph;
+      let g;
 
       const manager = getNodeManager();
       assert(model, 'No model found in state');
@@ -29,14 +29,14 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
       manager.getImageSourceImage({ bbox: state.canvasV2.bbox, preview: true });
 
       if (base === 'sdxl') {
-        graph = await buildSDXLGraph(state, manager);
+        g = await buildSDXLGraph(state, manager);
       } else if (base === 'sd-1' || base === 'sd-2') {
-        graph = await buildSD1Graph(state, manager);
+        g = await buildSD1Graph(state, manager);
       } else {
         assert(false, `No graph builders for base ${base}`);
       }
 
-      const batchConfig = prepareLinearUIBatch(state, graph, prepend);
+      const batchConfig = prepareLinearUIBatch(state, g, prepend);
 
       const req = dispatch(
         queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts
index 7fd8ab1065..bb282863b9 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts
@@ -1,14 +1,13 @@
-import { NUMPY_RAND_MAX } from 'app/constants';
 import type { RootState } from 'app/store/store';
 import { generateSeeds } from 'common/util/generateSeeds';
+import type { Graph } from 'features/nodes/util/graph/generation/Graph';
 import { range } from 'lodash-es';
 import type { components } from 'services/api/schema';
-import type { Batch, BatchConfig, NonNullableGraph } from 'services/api/types';
+import type { Batch, BatchConfig } from 'services/api/types';
 
-import { getHasMetadata, removeMetadata } from './canvas/metadata';
-import { CANVAS_COHERENCE_NOISE, METADATA, NOISE, POSITIVE_CONDITIONING } from './constants';
+import { NOISE, POSITIVE_CONDITIONING } from './constants';
 
-export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph, prepend: boolean): BatchConfig => {
+export const prepareLinearUIBatch = (state: RootState, g: Graph, prepend: boolean): BatchConfig => {
   const { iterations, model, shouldRandomizeSeed, seed, shouldConcatPrompts } = state.canvasV2.params;
   const { prompts, seedBehaviour } = state.dynamicPrompts;
 
@@ -23,7 +22,7 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
       start: shouldRandomizeSeed ? undefined : seed,
     });
 
-    if (graph.nodes[NOISE]) {
+    if (g.hasNode(NOISE)) {
       firstBatchDatumList.push({
         node_path: NOISE,
         field_name: 'seed',
@@ -32,22 +31,12 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
     }
 
     // add to metadata
-    if (getHasMetadata(graph)) {
-      removeMetadata(graph, 'seed');
-      firstBatchDatumList.push({
-        node_path: METADATA,
-        field_name: 'seed',
-        items: seeds,
-      });
-    }
-
-    if (graph.nodes[CANVAS_COHERENCE_NOISE]) {
-      firstBatchDatumList.push({
-        node_path: CANVAS_COHERENCE_NOISE,
-        field_name: 'seed',
-        items: seeds.map((seed) => (seed + 1) % NUMPY_RAND_MAX),
-      });
-    }
+    g.removeMetadata(['seed']);
+    firstBatchDatumList.push({
+      node_path: g.getMetadataNode().id,
+      field_name: 'seed',
+      items: seeds,
+    });
   } else {
     // seedBehaviour = SeedBehaviour.PerRun
     const seeds = generateSeeds({
@@ -55,7 +44,7 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
       start: shouldRandomizeSeed ? undefined : seed,
     });
 
-    if (graph.nodes[NOISE]) {
+    if (g.hasNode(NOISE)) {
       secondBatchDatumList.push({
         node_path: NOISE,
         field_name: 'seed',
@@ -64,29 +53,19 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
     }
 
     // add to metadata
-    if (getHasMetadata(graph)) {
-      removeMetadata(graph, 'seed');
-      secondBatchDatumList.push({
-        node_path: METADATA,
-        field_name: 'seed',
-        items: seeds,
-      });
-    }
-
-    if (graph.nodes[CANVAS_COHERENCE_NOISE]) {
-      secondBatchDatumList.push({
-        node_path: CANVAS_COHERENCE_NOISE,
-        field_name: 'seed',
-        items: seeds.map((seed) => (seed + 1) % NUMPY_RAND_MAX),
-      });
-    }
+    g.removeMetadata(['seed']);
+    secondBatchDatumList.push({
+      node_path: g.getMetadataNode().id,
+      field_name: 'seed',
+      items: seeds,
+    });
     data.push(secondBatchDatumList);
   }
 
   const extendedPrompts = seedBehaviour === 'PER_PROMPT' ? range(iterations).flatMap(() => prompts) : prompts;
 
   // zipped batch of prompts
-  if (graph.nodes[POSITIVE_CONDITIONING]) {
+  if (g.hasNode(POSITIVE_CONDITIONING)) {
     firstBatchDatumList.push({
       node_path: POSITIVE_CONDITIONING,
       field_name: 'prompt',
@@ -95,17 +74,15 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
   }
 
   // add to metadata
-  if (getHasMetadata(graph)) {
-    removeMetadata(graph, 'positive_prompt');
-    firstBatchDatumList.push({
-      node_path: METADATA,
-      field_name: 'positive_prompt',
-      items: extendedPrompts,
-    });
-  }
+  g.removeMetadata(['positive_prompt']);
+  firstBatchDatumList.push({
+    node_path: g.getMetadataNode().id,
+    field_name: 'positive_prompt',
+    items: extendedPrompts,
+  });
 
   if (shouldConcatPrompts && model?.base === 'sdxl') {
-    if (graph.nodes[POSITIVE_CONDITIONING]) {
+    if (g.hasNode(POSITIVE_CONDITIONING)) {
       firstBatchDatumList.push({
         node_path: POSITIVE_CONDITIONING,
         field_name: 'style',
@@ -114,14 +91,12 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
     }
 
     // add to metadata
-    if (getHasMetadata(graph)) {
-      removeMetadata(graph, 'positive_style_prompt');
-      firstBatchDatumList.push({
-        node_path: METADATA,
-        field_name: 'positive_style_prompt',
-        items: extendedPrompts,
-      });
-    }
+    g.removeMetadata(['positive_style_prompt']);
+    firstBatchDatumList.push({
+      node_path: g.getMetadataNode().id,
+      field_name: 'positive_style_prompt',
+      items: extendedPrompts,
+    });
   }
 
   data.push(firstBatchDatumList);
@@ -129,7 +104,7 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
   const enqueueBatchArg: BatchConfig = {
     prepend,
     batch: {
-      graph,
+      graph: g.getGraph(),
       runs: 1,
       data,
     },
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLRefiner.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLRefiner.ts
index f92e3cf7f8..7e79ffe4ff 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLRefiner.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLRefiner.ts
@@ -1,6 +1,5 @@
 import type { RootState } from 'app/store/store';
 import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
-import { getModelMetadataField } from 'features/nodes/util/graph/canvas/metadata';
 import {
   SDXL_REFINER_DENOISE_LATENTS,
   SDXL_REFINER_MODEL_LOADER,
@@ -8,7 +7,7 @@ import {
   SDXL_REFINER_POSITIVE_CONDITIONING,
   SDXL_REFINER_SEAMLESS,
 } from 'features/nodes/util/graph/constants';
-import type { Graph } from 'features/nodes/util/graph/generation/Graph';
+import { Graph } from 'features/nodes/util/graph/generation/Graph';
 import type { Invocation } from 'services/api/types';
 import { isRefinerMainModelModelConfig } from 'services/api/types';
 import { assert } from 'tsafe';
@@ -89,7 +88,7 @@ export const addSDXLRefiner = async (
   g.addEdge(refinerDenoise, 'latents', l2i, 'latents');
 
   g.upsertMetadata({
-    refiner_model: getModelMetadataField(modelConfig),
+    refiner_model: Graph.getModelMetadataField(modelConfig),
     refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
     refiner_negative_aesthetic_score: refinerNegativeAestheticScore,
     refiner_cfg_scale: refinerCFGScale,
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts
index f55c8ad6b4..9b929cc9ce 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts
@@ -26,7 +26,6 @@ import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
 import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
 import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
 import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
-import type { GraphType } from 'features/nodes/util/graph/generation/Graph';
 import { Graph } from 'features/nodes/util/graph/generation/Graph';
 import { getBoardField, getSizes } from 'features/nodes/util/graph/graphBuilderUtils';
 import type { Invocation } from 'services/api/types';
@@ -35,7 +34,7 @@ import { assert } from 'tsafe';
 
 import { addRegions } from './addRegions';
 
-export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager): Promise<GraphType> => {
+export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager): Promise<Graph> => {
   const generationMode = manager.getGenerationMode();
 
   const { bbox, params } = state.canvasV2;
@@ -248,5 +247,5 @@ export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager)
   });
 
   g.setMetadataReceivingNode(canvasOutput);
-  return g.getGraph();
+  return g;
 };
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts
index d75044c736..04c0f0cf2f 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts
@@ -27,13 +27,13 @@ import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
 import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
 import { Graph } from 'features/nodes/util/graph/generation/Graph';
 import { getBoardField, getSDXLStylePrompts, getSizes } from 'features/nodes/util/graph/graphBuilderUtils';
-import type { Invocation, NonNullableGraph } from 'services/api/types';
+import type { Invocation } from 'services/api/types';
 import { isNonRefinerMainModelConfig } from 'services/api/types';
 import { assert } from 'tsafe';
 
 import { addRegions } from './addRegions';
 
-export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager): Promise<NonNullableGraph> => {
+export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager): Promise<Graph> => {
   const generationMode = manager.getGenerationMode();
 
   const { bbox, params } = state.canvasV2;
@@ -246,5 +246,5 @@ export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager
   });
 
   g.setMetadataReceivingNode(canvasOutput);
-  return g.getGraph();
+  return g;
 };