mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): batch building after removing canvas files
This commit is contained in:
parent
0c26d28278
commit
ec6361e5cb
@ -19,7 +19,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
|||||||
const model = state.canvasV2.params.model;
|
const model = state.canvasV2.params.model;
|
||||||
const { prepend } = action.payload;
|
const { prepend } = action.payload;
|
||||||
|
|
||||||
let graph;
|
let g;
|
||||||
|
|
||||||
const manager = getNodeManager();
|
const manager = getNodeManager();
|
||||||
assert(model, 'No model found in state');
|
assert(model, 'No model found in state');
|
||||||
@ -29,14 +29,14 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
|||||||
manager.getImageSourceImage({ bbox: state.canvasV2.bbox, preview: true });
|
manager.getImageSourceImage({ bbox: state.canvasV2.bbox, preview: true });
|
||||||
|
|
||||||
if (base === 'sdxl') {
|
if (base === 'sdxl') {
|
||||||
graph = await buildSDXLGraph(state, manager);
|
g = await buildSDXLGraph(state, manager);
|
||||||
} else if (base === 'sd-1' || base === 'sd-2') {
|
} else if (base === 'sd-1' || base === 'sd-2') {
|
||||||
graph = await buildSD1Graph(state, manager);
|
g = await buildSD1Graph(state, manager);
|
||||||
} else {
|
} else {
|
||||||
assert(false, `No graph builders for base ${base}`);
|
assert(false, `No graph builders for base ${base}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
|
const batchConfig = prepareLinearUIBatch(state, g, prepend);
|
||||||
|
|
||||||
const req = dispatch(
|
const req = dispatch(
|
||||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||||
|
@ -1,14 +1,13 @@
|
|||||||
import { NUMPY_RAND_MAX } from 'app/constants';
|
|
||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { generateSeeds } from 'common/util/generateSeeds';
|
import { generateSeeds } from 'common/util/generateSeeds';
|
||||||
|
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||||
import { range } from 'lodash-es';
|
import { range } from 'lodash-es';
|
||||||
import type { components } from 'services/api/schema';
|
import type { components } from 'services/api/schema';
|
||||||
import type { Batch, BatchConfig, NonNullableGraph } from 'services/api/types';
|
import type { Batch, BatchConfig } from 'services/api/types';
|
||||||
|
|
||||||
import { getHasMetadata, removeMetadata } from './canvas/metadata';
|
import { NOISE, POSITIVE_CONDITIONING } from './constants';
|
||||||
import { CANVAS_COHERENCE_NOISE, METADATA, NOISE, POSITIVE_CONDITIONING } from './constants';
|
|
||||||
|
|
||||||
export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph, prepend: boolean): BatchConfig => {
|
export const prepareLinearUIBatch = (state: RootState, g: Graph, prepend: boolean): BatchConfig => {
|
||||||
const { iterations, model, shouldRandomizeSeed, seed, shouldConcatPrompts } = state.canvasV2.params;
|
const { iterations, model, shouldRandomizeSeed, seed, shouldConcatPrompts } = state.canvasV2.params;
|
||||||
const { prompts, seedBehaviour } = state.dynamicPrompts;
|
const { prompts, seedBehaviour } = state.dynamicPrompts;
|
||||||
|
|
||||||
@ -23,7 +22,7 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
|
|||||||
start: shouldRandomizeSeed ? undefined : seed,
|
start: shouldRandomizeSeed ? undefined : seed,
|
||||||
});
|
});
|
||||||
|
|
||||||
if (graph.nodes[NOISE]) {
|
if (g.hasNode(NOISE)) {
|
||||||
firstBatchDatumList.push({
|
firstBatchDatumList.push({
|
||||||
node_path: NOISE,
|
node_path: NOISE,
|
||||||
field_name: 'seed',
|
field_name: 'seed',
|
||||||
@ -32,22 +31,12 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// add to metadata
|
// add to metadata
|
||||||
if (getHasMetadata(graph)) {
|
g.removeMetadata(['seed']);
|
||||||
removeMetadata(graph, 'seed');
|
firstBatchDatumList.push({
|
||||||
firstBatchDatumList.push({
|
node_path: g.getMetadataNode().id,
|
||||||
node_path: METADATA,
|
field_name: 'seed',
|
||||||
field_name: 'seed',
|
items: seeds,
|
||||||
items: seeds,
|
});
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (graph.nodes[CANVAS_COHERENCE_NOISE]) {
|
|
||||||
firstBatchDatumList.push({
|
|
||||||
node_path: CANVAS_COHERENCE_NOISE,
|
|
||||||
field_name: 'seed',
|
|
||||||
items: seeds.map((seed) => (seed + 1) % NUMPY_RAND_MAX),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// seedBehaviour = SeedBehaviour.PerRun
|
// seedBehaviour = SeedBehaviour.PerRun
|
||||||
const seeds = generateSeeds({
|
const seeds = generateSeeds({
|
||||||
@ -55,7 +44,7 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
|
|||||||
start: shouldRandomizeSeed ? undefined : seed,
|
start: shouldRandomizeSeed ? undefined : seed,
|
||||||
});
|
});
|
||||||
|
|
||||||
if (graph.nodes[NOISE]) {
|
if (g.hasNode(NOISE)) {
|
||||||
secondBatchDatumList.push({
|
secondBatchDatumList.push({
|
||||||
node_path: NOISE,
|
node_path: NOISE,
|
||||||
field_name: 'seed',
|
field_name: 'seed',
|
||||||
@ -64,29 +53,19 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// add to metadata
|
// add to metadata
|
||||||
if (getHasMetadata(graph)) {
|
g.removeMetadata(['seed']);
|
||||||
removeMetadata(graph, 'seed');
|
secondBatchDatumList.push({
|
||||||
secondBatchDatumList.push({
|
node_path: g.getMetadataNode().id,
|
||||||
node_path: METADATA,
|
field_name: 'seed',
|
||||||
field_name: 'seed',
|
items: seeds,
|
||||||
items: seeds,
|
});
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (graph.nodes[CANVAS_COHERENCE_NOISE]) {
|
|
||||||
secondBatchDatumList.push({
|
|
||||||
node_path: CANVAS_COHERENCE_NOISE,
|
|
||||||
field_name: 'seed',
|
|
||||||
items: seeds.map((seed) => (seed + 1) % NUMPY_RAND_MAX),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
data.push(secondBatchDatumList);
|
data.push(secondBatchDatumList);
|
||||||
}
|
}
|
||||||
|
|
||||||
const extendedPrompts = seedBehaviour === 'PER_PROMPT' ? range(iterations).flatMap(() => prompts) : prompts;
|
const extendedPrompts = seedBehaviour === 'PER_PROMPT' ? range(iterations).flatMap(() => prompts) : prompts;
|
||||||
|
|
||||||
// zipped batch of prompts
|
// zipped batch of prompts
|
||||||
if (graph.nodes[POSITIVE_CONDITIONING]) {
|
if (g.hasNode(POSITIVE_CONDITIONING)) {
|
||||||
firstBatchDatumList.push({
|
firstBatchDatumList.push({
|
||||||
node_path: POSITIVE_CONDITIONING,
|
node_path: POSITIVE_CONDITIONING,
|
||||||
field_name: 'prompt',
|
field_name: 'prompt',
|
||||||
@ -95,17 +74,15 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// add to metadata
|
// add to metadata
|
||||||
if (getHasMetadata(graph)) {
|
g.removeMetadata(['positive_prompt']);
|
||||||
removeMetadata(graph, 'positive_prompt');
|
firstBatchDatumList.push({
|
||||||
firstBatchDatumList.push({
|
node_path: g.getMetadataNode().id,
|
||||||
node_path: METADATA,
|
field_name: 'positive_prompt',
|
||||||
field_name: 'positive_prompt',
|
items: extendedPrompts,
|
||||||
items: extendedPrompts,
|
});
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (shouldConcatPrompts && model?.base === 'sdxl') {
|
if (shouldConcatPrompts && model?.base === 'sdxl') {
|
||||||
if (graph.nodes[POSITIVE_CONDITIONING]) {
|
if (g.hasNode(POSITIVE_CONDITIONING)) {
|
||||||
firstBatchDatumList.push({
|
firstBatchDatumList.push({
|
||||||
node_path: POSITIVE_CONDITIONING,
|
node_path: POSITIVE_CONDITIONING,
|
||||||
field_name: 'style',
|
field_name: 'style',
|
||||||
@ -114,14 +91,12 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// add to metadata
|
// add to metadata
|
||||||
if (getHasMetadata(graph)) {
|
g.removeMetadata(['positive_style_prompt']);
|
||||||
removeMetadata(graph, 'positive_style_prompt');
|
firstBatchDatumList.push({
|
||||||
firstBatchDatumList.push({
|
node_path: g.getMetadataNode().id,
|
||||||
node_path: METADATA,
|
field_name: 'positive_style_prompt',
|
||||||
field_name: 'positive_style_prompt',
|
items: extendedPrompts,
|
||||||
items: extendedPrompts,
|
});
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
data.push(firstBatchDatumList);
|
data.push(firstBatchDatumList);
|
||||||
@ -129,7 +104,7 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
|
|||||||
const enqueueBatchArg: BatchConfig = {
|
const enqueueBatchArg: BatchConfig = {
|
||||||
prepend,
|
prepend,
|
||||||
batch: {
|
batch: {
|
||||||
graph,
|
graph: g.getGraph(),
|
||||||
runs: 1,
|
runs: 1,
|
||||||
data,
|
data,
|
||||||
},
|
},
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import { getModelMetadataField } from 'features/nodes/util/graph/canvas/metadata';
|
|
||||||
import {
|
import {
|
||||||
SDXL_REFINER_DENOISE_LATENTS,
|
SDXL_REFINER_DENOISE_LATENTS,
|
||||||
SDXL_REFINER_MODEL_LOADER,
|
SDXL_REFINER_MODEL_LOADER,
|
||||||
@ -8,7 +7,7 @@ import {
|
|||||||
SDXL_REFINER_POSITIVE_CONDITIONING,
|
SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||||
SDXL_REFINER_SEAMLESS,
|
SDXL_REFINER_SEAMLESS,
|
||||||
} from 'features/nodes/util/graph/constants';
|
} from 'features/nodes/util/graph/constants';
|
||||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||||
import type { Invocation } from 'services/api/types';
|
import type { Invocation } from 'services/api/types';
|
||||||
import { isRefinerMainModelModelConfig } from 'services/api/types';
|
import { isRefinerMainModelModelConfig } from 'services/api/types';
|
||||||
import { assert } from 'tsafe';
|
import { assert } from 'tsafe';
|
||||||
@ -89,7 +88,7 @@ export const addSDXLRefiner = async (
|
|||||||
g.addEdge(refinerDenoise, 'latents', l2i, 'latents');
|
g.addEdge(refinerDenoise, 'latents', l2i, 'latents');
|
||||||
|
|
||||||
g.upsertMetadata({
|
g.upsertMetadata({
|
||||||
refiner_model: getModelMetadataField(modelConfig),
|
refiner_model: Graph.getModelMetadataField(modelConfig),
|
||||||
refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
|
refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
|
||||||
refiner_negative_aesthetic_score: refinerNegativeAestheticScore,
|
refiner_negative_aesthetic_score: refinerNegativeAestheticScore,
|
||||||
refiner_cfg_scale: refinerCFGScale,
|
refiner_cfg_scale: refinerCFGScale,
|
||||||
|
@ -26,7 +26,6 @@ import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
|
|||||||
import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
|
import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
|
||||||
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
|
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
|
||||||
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
|
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
|
||||||
import type { GraphType } from 'features/nodes/util/graph/generation/Graph';
|
|
||||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||||
import { getBoardField, getSizes } from 'features/nodes/util/graph/graphBuilderUtils';
|
import { getBoardField, getSizes } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||||
import type { Invocation } from 'services/api/types';
|
import type { Invocation } from 'services/api/types';
|
||||||
@ -35,7 +34,7 @@ import { assert } from 'tsafe';
|
|||||||
|
|
||||||
import { addRegions } from './addRegions';
|
import { addRegions } from './addRegions';
|
||||||
|
|
||||||
export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager): Promise<GraphType> => {
|
export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager): Promise<Graph> => {
|
||||||
const generationMode = manager.getGenerationMode();
|
const generationMode = manager.getGenerationMode();
|
||||||
|
|
||||||
const { bbox, params } = state.canvasV2;
|
const { bbox, params } = state.canvasV2;
|
||||||
@ -248,5 +247,5 @@ export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager)
|
|||||||
});
|
});
|
||||||
|
|
||||||
g.setMetadataReceivingNode(canvasOutput);
|
g.setMetadataReceivingNode(canvasOutput);
|
||||||
return g.getGraph();
|
return g;
|
||||||
};
|
};
|
||||||
|
@ -27,13 +27,13 @@ import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToIm
|
|||||||
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
|
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
|
||||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||||
import { getBoardField, getSDXLStylePrompts, getSizes } from 'features/nodes/util/graph/graphBuilderUtils';
|
import { getBoardField, getSDXLStylePrompts, getSizes } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||||
import type { Invocation, NonNullableGraph } from 'services/api/types';
|
import type { Invocation } from 'services/api/types';
|
||||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||||
import { assert } from 'tsafe';
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
import { addRegions } from './addRegions';
|
import { addRegions } from './addRegions';
|
||||||
|
|
||||||
export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager): Promise<NonNullableGraph> => {
|
export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager): Promise<Graph> => {
|
||||||
const generationMode = manager.getGenerationMode();
|
const generationMode = manager.getGenerationMode();
|
||||||
|
|
||||||
const { bbox, params } = state.canvasV2;
|
const { bbox, params } = state.canvasV2;
|
||||||
@ -246,5 +246,5 @@ export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager
|
|||||||
});
|
});
|
||||||
|
|
||||||
g.setMetadataReceivingNode(canvasOutput);
|
g.setMetadataReceivingNode(canvasOutput);
|
||||||
return g.getGraph();
|
return g;
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user