fix(ui): batch building after removing canvas files

This commit is contained in:
psychedelicious 2024-06-28 18:28:27 +10:00
parent 8f5f9bd44e
commit 788bad61d0
5 changed files with 44 additions and 71 deletions

View File

@ -19,7 +19,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
const model = state.canvasV2.params.model;
const { prepend } = action.payload;
let graph;
let g;
const manager = getNodeManager();
assert(model, 'No model found in state');
@ -29,14 +29,14 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
manager.getImageSourceImage({ bbox: state.canvasV2.bbox, preview: true });
if (base === 'sdxl') {
graph = await buildSDXLGraph(state, manager);
g = await buildSDXLGraph(state, manager);
} else if (base === 'sd-1' || base === 'sd-2') {
graph = await buildSD1Graph(state, manager);
g = await buildSD1Graph(state, manager);
} else {
assert(false, `No graph builders for base ${base}`);
}
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
const batchConfig = prepareLinearUIBatch(state, g, prepend);
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {

View File

@ -1,14 +1,13 @@
import { NUMPY_RAND_MAX } from 'app/constants';
import type { RootState } from 'app/store/store';
import { generateSeeds } from 'common/util/generateSeeds';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { range } from 'lodash-es';
import type { components } from 'services/api/schema';
import type { Batch, BatchConfig, NonNullableGraph } from 'services/api/types';
import type { Batch, BatchConfig } from 'services/api/types';
import { getHasMetadata, removeMetadata } from './canvas/metadata';
import { CANVAS_COHERENCE_NOISE, METADATA, NOISE, POSITIVE_CONDITIONING } from './constants';
import { NOISE, POSITIVE_CONDITIONING } from './constants';
export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph, prepend: boolean): BatchConfig => {
export const prepareLinearUIBatch = (state: RootState, g: Graph, prepend: boolean): BatchConfig => {
const { iterations, model, shouldRandomizeSeed, seed, shouldConcatPrompts } = state.canvasV2.params;
const { prompts, seedBehaviour } = state.dynamicPrompts;
@ -23,7 +22,7 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
start: shouldRandomizeSeed ? undefined : seed,
});
if (graph.nodes[NOISE]) {
if (g.hasNode(NOISE)) {
firstBatchDatumList.push({
node_path: NOISE,
field_name: 'seed',
@ -32,22 +31,12 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
}
// add to metadata
if (getHasMetadata(graph)) {
removeMetadata(graph, 'seed');
g.removeMetadata(['seed']);
firstBatchDatumList.push({
node_path: METADATA,
node_path: g.getMetadataNode().id,
field_name: 'seed',
items: seeds,
});
}
if (graph.nodes[CANVAS_COHERENCE_NOISE]) {
firstBatchDatumList.push({
node_path: CANVAS_COHERENCE_NOISE,
field_name: 'seed',
items: seeds.map((seed) => (seed + 1) % NUMPY_RAND_MAX),
});
}
} else {
// seedBehaviour = SeedBehaviour.PerRun
const seeds = generateSeeds({
@ -55,7 +44,7 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
start: shouldRandomizeSeed ? undefined : seed,
});
if (graph.nodes[NOISE]) {
if (g.hasNode(NOISE)) {
secondBatchDatumList.push({
node_path: NOISE,
field_name: 'seed',
@ -64,29 +53,19 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
}
// add to metadata
if (getHasMetadata(graph)) {
removeMetadata(graph, 'seed');
g.removeMetadata(['seed']);
secondBatchDatumList.push({
node_path: METADATA,
node_path: g.getMetadataNode().id,
field_name: 'seed',
items: seeds,
});
}
if (graph.nodes[CANVAS_COHERENCE_NOISE]) {
secondBatchDatumList.push({
node_path: CANVAS_COHERENCE_NOISE,
field_name: 'seed',
items: seeds.map((seed) => (seed + 1) % NUMPY_RAND_MAX),
});
}
data.push(secondBatchDatumList);
}
const extendedPrompts = seedBehaviour === 'PER_PROMPT' ? range(iterations).flatMap(() => prompts) : prompts;
// zipped batch of prompts
if (graph.nodes[POSITIVE_CONDITIONING]) {
if (g.hasNode(POSITIVE_CONDITIONING)) {
firstBatchDatumList.push({
node_path: POSITIVE_CONDITIONING,
field_name: 'prompt',
@ -95,17 +74,15 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
}
// add to metadata
if (getHasMetadata(graph)) {
removeMetadata(graph, 'positive_prompt');
g.removeMetadata(['positive_prompt']);
firstBatchDatumList.push({
node_path: METADATA,
node_path: g.getMetadataNode().id,
field_name: 'positive_prompt',
items: extendedPrompts,
});
}
if (shouldConcatPrompts && model?.base === 'sdxl') {
if (graph.nodes[POSITIVE_CONDITIONING]) {
if (g.hasNode(POSITIVE_CONDITIONING)) {
firstBatchDatumList.push({
node_path: POSITIVE_CONDITIONING,
field_name: 'style',
@ -114,22 +91,20 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
}
// add to metadata
if (getHasMetadata(graph)) {
removeMetadata(graph, 'positive_style_prompt');
g.removeMetadata(['positive_style_prompt']);
firstBatchDatumList.push({
node_path: METADATA,
node_path: g.getMetadataNode().id,
field_name: 'positive_style_prompt',
items: extendedPrompts,
});
}
}
data.push(firstBatchDatumList);
const enqueueBatchArg: BatchConfig = {
prepend,
batch: {
graph,
graph: g.getGraph(),
runs: 1,
data,
},

View File

@ -1,6 +1,5 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { getModelMetadataField } from 'features/nodes/util/graph/canvas/metadata';
import {
SDXL_REFINER_DENOISE_LATENTS,
SDXL_REFINER_MODEL_LOADER,
@ -8,7 +7,7 @@ import {
SDXL_REFINER_POSITIVE_CONDITIONING,
SDXL_REFINER_SEAMLESS,
} from 'features/nodes/util/graph/constants';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { Invocation } from 'services/api/types';
import { isRefinerMainModelModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
@ -89,7 +88,7 @@ export const addSDXLRefiner = async (
g.addEdge(refinerDenoise, 'latents', l2i, 'latents');
g.upsertMetadata({
refiner_model: getModelMetadataField(modelConfig),
refiner_model: Graph.getModelMetadataField(modelConfig),
refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
refiner_negative_aesthetic_score: refinerNegativeAestheticScore,
refiner_cfg_scale: refinerCFGScale,

View File

@ -26,7 +26,6 @@ import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import type { GraphType } from 'features/nodes/util/graph/generation/Graph';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { getBoardField, getSizes } from 'features/nodes/util/graph/graphBuilderUtils';
import type { Invocation } from 'services/api/types';
@ -35,7 +34,7 @@ import { assert } from 'tsafe';
import { addRegions } from './addRegions';
export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager): Promise<GraphType> => {
export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager): Promise<Graph> => {
const generationMode = manager.getGenerationMode();
const { bbox, params } = state.canvasV2;
@ -248,5 +247,5 @@ export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager)
});
g.setMetadataReceivingNode(canvasOutput);
return g.getGraph();
return g;
};

View File

@ -27,13 +27,13 @@ import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToIm
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { getBoardField, getSDXLStylePrompts, getSizes } from 'features/nodes/util/graph/graphBuilderUtils';
import type { Invocation, NonNullableGraph } from 'services/api/types';
import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { addRegions } from './addRegions';
export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager): Promise<NonNullableGraph> => {
export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager): Promise<Graph> => {
const generationMode = manager.getGenerationMode();
const { bbox, params } = state.canvasV2;
@ -246,5 +246,5 @@ export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager
});
g.setMetadataReceivingNode(canvasOutput);
return g.getGraph();
return g;
};