mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): batch building after removing canvas files
This commit is contained in:
parent
8f5f9bd44e
commit
788bad61d0
@ -19,7 +19,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
const model = state.canvasV2.params.model;
|
||||
const { prepend } = action.payload;
|
||||
|
||||
let graph;
|
||||
let g;
|
||||
|
||||
const manager = getNodeManager();
|
||||
assert(model, 'No model found in state');
|
||||
@ -29,14 +29,14 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
manager.getImageSourceImage({ bbox: state.canvasV2.bbox, preview: true });
|
||||
|
||||
if (base === 'sdxl') {
|
||||
graph = await buildSDXLGraph(state, manager);
|
||||
g = await buildSDXLGraph(state, manager);
|
||||
} else if (base === 'sd-1' || base === 'sd-2') {
|
||||
graph = await buildSD1Graph(state, manager);
|
||||
g = await buildSD1Graph(state, manager);
|
||||
} else {
|
||||
assert(false, `No graph builders for base ${base}`);
|
||||
}
|
||||
|
||||
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
|
||||
const batchConfig = prepareLinearUIBatch(state, g, prepend);
|
||||
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||
|
@ -1,14 +1,13 @@
|
||||
import { NUMPY_RAND_MAX } from 'app/constants';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { generateSeeds } from 'common/util/generateSeeds';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { range } from 'lodash-es';
|
||||
import type { components } from 'services/api/schema';
|
||||
import type { Batch, BatchConfig, NonNullableGraph } from 'services/api/types';
|
||||
import type { Batch, BatchConfig } from 'services/api/types';
|
||||
|
||||
import { getHasMetadata, removeMetadata } from './canvas/metadata';
|
||||
import { CANVAS_COHERENCE_NOISE, METADATA, NOISE, POSITIVE_CONDITIONING } from './constants';
|
||||
import { NOISE, POSITIVE_CONDITIONING } from './constants';
|
||||
|
||||
export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph, prepend: boolean): BatchConfig => {
|
||||
export const prepareLinearUIBatch = (state: RootState, g: Graph, prepend: boolean): BatchConfig => {
|
||||
const { iterations, model, shouldRandomizeSeed, seed, shouldConcatPrompts } = state.canvasV2.params;
|
||||
const { prompts, seedBehaviour } = state.dynamicPrompts;
|
||||
|
||||
@ -23,7 +22,7 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
|
||||
start: shouldRandomizeSeed ? undefined : seed,
|
||||
});
|
||||
|
||||
if (graph.nodes[NOISE]) {
|
||||
if (g.hasNode(NOISE)) {
|
||||
firstBatchDatumList.push({
|
||||
node_path: NOISE,
|
||||
field_name: 'seed',
|
||||
@ -32,22 +31,12 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
|
||||
}
|
||||
|
||||
// add to metadata
|
||||
if (getHasMetadata(graph)) {
|
||||
removeMetadata(graph, 'seed');
|
||||
g.removeMetadata(['seed']);
|
||||
firstBatchDatumList.push({
|
||||
node_path: METADATA,
|
||||
node_path: g.getMetadataNode().id,
|
||||
field_name: 'seed',
|
||||
items: seeds,
|
||||
});
|
||||
}
|
||||
|
||||
if (graph.nodes[CANVAS_COHERENCE_NOISE]) {
|
||||
firstBatchDatumList.push({
|
||||
node_path: CANVAS_COHERENCE_NOISE,
|
||||
field_name: 'seed',
|
||||
items: seeds.map((seed) => (seed + 1) % NUMPY_RAND_MAX),
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// seedBehaviour = SeedBehaviour.PerRun
|
||||
const seeds = generateSeeds({
|
||||
@ -55,7 +44,7 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
|
||||
start: shouldRandomizeSeed ? undefined : seed,
|
||||
});
|
||||
|
||||
if (graph.nodes[NOISE]) {
|
||||
if (g.hasNode(NOISE)) {
|
||||
secondBatchDatumList.push({
|
||||
node_path: NOISE,
|
||||
field_name: 'seed',
|
||||
@ -64,29 +53,19 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
|
||||
}
|
||||
|
||||
// add to metadata
|
||||
if (getHasMetadata(graph)) {
|
||||
removeMetadata(graph, 'seed');
|
||||
g.removeMetadata(['seed']);
|
||||
secondBatchDatumList.push({
|
||||
node_path: METADATA,
|
||||
node_path: g.getMetadataNode().id,
|
||||
field_name: 'seed',
|
||||
items: seeds,
|
||||
});
|
||||
}
|
||||
|
||||
if (graph.nodes[CANVAS_COHERENCE_NOISE]) {
|
||||
secondBatchDatumList.push({
|
||||
node_path: CANVAS_COHERENCE_NOISE,
|
||||
field_name: 'seed',
|
||||
items: seeds.map((seed) => (seed + 1) % NUMPY_RAND_MAX),
|
||||
});
|
||||
}
|
||||
data.push(secondBatchDatumList);
|
||||
}
|
||||
|
||||
const extendedPrompts = seedBehaviour === 'PER_PROMPT' ? range(iterations).flatMap(() => prompts) : prompts;
|
||||
|
||||
// zipped batch of prompts
|
||||
if (graph.nodes[POSITIVE_CONDITIONING]) {
|
||||
if (g.hasNode(POSITIVE_CONDITIONING)) {
|
||||
firstBatchDatumList.push({
|
||||
node_path: POSITIVE_CONDITIONING,
|
||||
field_name: 'prompt',
|
||||
@ -95,17 +74,15 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
|
||||
}
|
||||
|
||||
// add to metadata
|
||||
if (getHasMetadata(graph)) {
|
||||
removeMetadata(graph, 'positive_prompt');
|
||||
g.removeMetadata(['positive_prompt']);
|
||||
firstBatchDatumList.push({
|
||||
node_path: METADATA,
|
||||
node_path: g.getMetadataNode().id,
|
||||
field_name: 'positive_prompt',
|
||||
items: extendedPrompts,
|
||||
});
|
||||
}
|
||||
|
||||
if (shouldConcatPrompts && model?.base === 'sdxl') {
|
||||
if (graph.nodes[POSITIVE_CONDITIONING]) {
|
||||
if (g.hasNode(POSITIVE_CONDITIONING)) {
|
||||
firstBatchDatumList.push({
|
||||
node_path: POSITIVE_CONDITIONING,
|
||||
field_name: 'style',
|
||||
@ -114,22 +91,20 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph,
|
||||
}
|
||||
|
||||
// add to metadata
|
||||
if (getHasMetadata(graph)) {
|
||||
removeMetadata(graph, 'positive_style_prompt');
|
||||
g.removeMetadata(['positive_style_prompt']);
|
||||
firstBatchDatumList.push({
|
||||
node_path: METADATA,
|
||||
node_path: g.getMetadataNode().id,
|
||||
field_name: 'positive_style_prompt',
|
||||
items: extendedPrompts,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
data.push(firstBatchDatumList);
|
||||
|
||||
const enqueueBatchArg: BatchConfig = {
|
||||
prepend,
|
||||
batch: {
|
||||
graph,
|
||||
graph: g.getGraph(),
|
||||
runs: 1,
|
||||
data,
|
||||
},
|
||||
|
@ -1,6 +1,5 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { getModelMetadataField } from 'features/nodes/util/graph/canvas/metadata';
|
||||
import {
|
||||
SDXL_REFINER_DENOISE_LATENTS,
|
||||
SDXL_REFINER_MODEL_LOADER,
|
||||
@ -8,7 +7,7 @@ import {
|
||||
SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||
SDXL_REFINER_SEAMLESS,
|
||||
} from 'features/nodes/util/graph/constants';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import { isRefinerMainModelModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
@ -89,7 +88,7 @@ export const addSDXLRefiner = async (
|
||||
g.addEdge(refinerDenoise, 'latents', l2i, 'latents');
|
||||
|
||||
g.upsertMetadata({
|
||||
refiner_model: getModelMetadataField(modelConfig),
|
||||
refiner_model: Graph.getModelMetadataField(modelConfig),
|
||||
refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
|
||||
refiner_negative_aesthetic_score: refinerNegativeAestheticScore,
|
||||
refiner_cfg_scale: refinerCFGScale,
|
||||
|
@ -26,7 +26,6 @@ import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
|
||||
import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
|
||||
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
|
||||
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
|
||||
import type { GraphType } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { getBoardField, getSizes } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
@ -35,7 +34,7 @@ import { assert } from 'tsafe';
|
||||
|
||||
import { addRegions } from './addRegions';
|
||||
|
||||
export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager): Promise<GraphType> => {
|
||||
export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager): Promise<Graph> => {
|
||||
const generationMode = manager.getGenerationMode();
|
||||
|
||||
const { bbox, params } = state.canvasV2;
|
||||
@ -248,5 +247,5 @@ export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager)
|
||||
});
|
||||
|
||||
g.setMetadataReceivingNode(canvasOutput);
|
||||
return g.getGraph();
|
||||
return g;
|
||||
};
|
||||
|
@ -27,13 +27,13 @@ import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToIm
|
||||
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { getBoardField, getSDXLStylePrompts, getSizes } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { Invocation, NonNullableGraph } from 'services/api/types';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { addRegions } from './addRegions';
|
||||
|
||||
export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager): Promise<NonNullableGraph> => {
|
||||
export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager): Promise<Graph> => {
|
||||
const generationMode = manager.getGenerationMode();
|
||||
|
||||
const { bbox, params } = state.canvasV2;
|
||||
@ -246,5 +246,5 @@ export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager
|
||||
});
|
||||
|
||||
g.setMetadataReceivingNode(canvasOutput);
|
||||
return g.getGraph();
|
||||
return g;
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user