feat(ui): revise graphs to not use LinearUIOutputInvocation

See this comment for context: https://github.com/invoke-ai/InvokeAI/pull/5491#discussion_r1480760629

- Remove this now-unnecessary node from all graphs
- Update graphs' terminal image-outputting nodes' `is_intermediate` and `board` fields appropriately
- Add util function to prepare the `board` field, tidy the utils
- Update `socketInvocationComplete` listener to work correctly with this change

I've manually tested all graph permutations that were changed (I think this is all...) to ensure images go to the gallery as expected:
- ad-hoc upscaling
- t2i w/ sd1.5
- t2i w/ sd1.5 & hrf
- t2i w/ sdxl
- t2i w/ sdxl + refiner
- i2i w/ sd1.5
- i2i w/ sdxl
- i2i w/ sdxl + refiner
- canvas t2i w/ sd1.5
- canvas t2i w/ sdxl
- canvas t2i w/ sdxl + refiner
- canvas i2i w/ sd1.5
- canvas i2i w/ sdxl
- canvas i2i w/ sdxl + refiner
- canvas inpaint w/ sd1.5
- canvas inpaint w/ sdxl
- canvas inpaint w/ sdxl + refiner
- canvas outpaint w/ sd1.5
- canvas outpaint w/ sdxl
- canvas outpaint w/ sdxl + refiner
This commit is contained in:
psychedelicious 2024-02-07 16:41:24 +11:00 committed by Brandon Rising
parent c16f77bb23
commit bb8c71f706
24 changed files with 108 additions and 197 deletions

View File

@ -4,7 +4,7 @@ import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { isImageOutput } from 'features/nodes/types/common';
import { LINEAR_UI_OUTPUT, nodeIDDenyList } from 'features/nodes/util/graph/constants';
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { imagesAdapter } from 'services/api/util';
@ -24,10 +24,9 @@ export const addInvocationCompleteEventListener = () => {
const { data } = action.payload;
log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`);
const { result, node, queue_batch_id, source_node_id } = data;
const { result, node, queue_batch_id } = data;
// This complete event has an associated image output
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type) && !nodeIDDenyList.includes(source_node_id)) {
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) {
const { image_name } = result.image;
const { canvas, gallery } = getState();
@ -42,7 +41,7 @@ export const addInvocationCompleteEventListener = () => {
imageDTORequest.unsubscribe();
// Add canvas images to the staging area
if (canvas.batchIds.includes(queue_batch_id) && [LINEAR_UI_OUTPUT].includes(data.source_node_id)) {
if (canvas.batchIds.includes(queue_batch_id) && data.source_node_id === CANVAS_OUTPUT) {
dispatch(addImageToStagingArea(imageDTO));
}

View File

@ -39,16 +39,12 @@ export const addUpscaleRequestedListener = () => {
return;
}
const { esrganModelName } = state.postprocessing;
const { autoAddBoardId } = state.gallery;
const enqueueBatchArg: BatchConfig = {
prepend: true,
batch: {
graph: buildAdHocUpscaleGraph({
image_name,
esrganModelName,
autoAddBoardId,
state,
}),
runs: 1,
},

View File

@ -1,6 +1,7 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import type {
DenoiseLatentsInvocation,
@ -322,7 +323,8 @@ export const addHrfToGraph = (state: RootState, graph: NonNullableGraph): void =
type: 'l2i',
id: LATENTS_TO_IMAGE_HRF_HR,
fp32: originalLatentsToImageNode?.fp32,
is_intermediate: true,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
};
graph.edges.push(
{

View File

@ -1,78 +0,0 @@
import type { RootState } from 'app/store/store';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import type { LinearUIOutputInvocation, NonNullableGraph } from 'services/api/types';
import {
CANVAS_OUTPUT,
LATENTS_TO_IMAGE,
LATENTS_TO_IMAGE_HRF_HR,
LINEAR_UI_OUTPUT,
NSFW_CHECKER,
WATERMARKER,
} from './constants';
/**
* Set the `use_cache` field on the linear/canvas graph's final image output node to False.
*/
export const addLinearUIOutputNode = (state: RootState, graph: NonNullableGraph): void => {
const activeTabName = activeTabNameSelector(state);
const is_intermediate = activeTabName === 'unifiedCanvas' ? !state.canvas.shouldAutoSave : false;
const { autoAddBoardId } = state.gallery;
const linearUIOutputNode: LinearUIOutputInvocation = {
id: LINEAR_UI_OUTPUT,
type: 'linear_ui_output',
is_intermediate,
use_cache: false,
board: autoAddBoardId === 'none' ? undefined : { board_id: autoAddBoardId },
};
graph.nodes[LINEAR_UI_OUTPUT] = linearUIOutputNode;
const destination = {
node_id: LINEAR_UI_OUTPUT,
field: 'image',
};
if (WATERMARKER in graph.nodes) {
graph.edges.push({
source: {
node_id: WATERMARKER,
field: 'image',
},
destination,
});
} else if (NSFW_CHECKER in graph.nodes) {
graph.edges.push({
source: {
node_id: NSFW_CHECKER,
field: 'image',
},
destination,
});
} else if (CANVAS_OUTPUT in graph.nodes) {
graph.edges.push({
source: {
node_id: CANVAS_OUTPUT,
field: 'image',
},
destination,
});
} else if (LATENTS_TO_IMAGE_HRF_HR in graph.nodes) {
graph.edges.push({
source: {
node_id: LATENTS_TO_IMAGE_HRF_HR,
field: 'image',
},
destination,
});
} else if (LATENTS_TO_IMAGE in graph.nodes) {
graph.edges.push({
source: {
node_id: LATENTS_TO_IMAGE,
field: 'image',
},
destination,
});
}
};

View File

@ -2,6 +2,7 @@ import type { RootState } from 'app/store/store';
import type { ImageNSFWBlurInvocation, LatentsToImageInvocation, NonNullableGraph } from 'services/api/types';
import { LATENTS_TO_IMAGE, NSFW_CHECKER } from './constants';
import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
export const addNSFWCheckerToGraph = (
state: RootState,
@ -21,7 +22,8 @@ export const addNSFWCheckerToGraph = (
const nsfwCheckerNode: ImageNSFWBlurInvocation = {
id: NSFW_CHECKER,
type: 'img_nsfw',
is_intermediate: true,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
};
graph.nodes[NSFW_CHECKER] = nsfwCheckerNode as ImageNSFWBlurInvocation;

View File

@ -24,7 +24,7 @@ import {
SDXL_REFINER_POSITIVE_CONDITIONING,
SDXL_REFINER_SEAMLESS,
} from './constants';
import { getSDXLStylePrompts } from './getSDXLStylePrompt';
import { getSDXLStylePrompts } from './graphBuilderUtils';
import { upsertMetadata } from './metadata';
export const addSDXLRefinerToGraph = (

View File

@ -1,5 +1,4 @@
import type { RootState } from 'app/store/store';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import type {
ImageNSFWBlurInvocation,
ImageWatermarkInvocation,
@ -8,16 +7,13 @@ import type {
} from 'services/api/types';
import { LATENTS_TO_IMAGE, NSFW_CHECKER, WATERMARKER } from './constants';
import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
export const addWatermarkerToGraph = (
state: RootState,
graph: NonNullableGraph,
nodeIdToAddTo = LATENTS_TO_IMAGE
): void => {
const activeTabName = activeTabNameSelector(state);
const is_intermediate = activeTabName === 'unifiedCanvas' ? !state.canvas.shouldAutoSave : false;
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as LatentsToImageInvocation | undefined;
const nsfwCheckerNode = graph.nodes[NSFW_CHECKER] as ImageNSFWBlurInvocation | undefined;
@ -30,7 +26,8 @@ export const addWatermarkerToGraph = (
const watermarkerNode: ImageWatermarkInvocation = {
id: WATERMARKER,
type: 'img_watermark',
is_intermediate,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
};
graph.nodes[WATERMARKER] = watermarkerNode;

View File

@ -1,51 +1,33 @@
import type { BoardId } from 'features/gallery/store/types';
import type { ParamESRGANModelName } from 'features/parameters/store/postprocessingSlice';
import type { ESRGANInvocation, Graph, LinearUIOutputInvocation, NonNullableGraph } from 'services/api/types';
import type { RootState } from 'app/store/store';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type { ESRGANInvocation, Graph, NonNullableGraph } from 'services/api/types';
import { ESRGAN, LINEAR_UI_OUTPUT } from './constants';
import { ESRGAN } from './constants';
import { addCoreMetadataNode, upsertMetadata } from './metadata';
type Arg = {
image_name: string;
esrganModelName: ParamESRGANModelName;
autoAddBoardId: BoardId;
state: RootState;
};
export const buildAdHocUpscaleGraph = ({ image_name, esrganModelName, autoAddBoardId }: Arg): Graph => {
export const buildAdHocUpscaleGraph = ({ image_name, state }: Arg): Graph => {
const { esrganModelName } = state.postprocessing;
const realesrganNode: ESRGANInvocation = {
id: ESRGAN,
type: 'esrgan',
image: { image_name },
model_name: esrganModelName,
is_intermediate: true,
};
const linearUIOutputNode: LinearUIOutputInvocation = {
id: LINEAR_UI_OUTPUT,
type: 'linear_ui_output',
use_cache: false,
is_intermediate: false,
board: autoAddBoardId === 'none' ? undefined : { board_id: autoAddBoardId },
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
};
const graph: NonNullableGraph = {
id: `adhoc-esrgan-graph`,
nodes: {
[ESRGAN]: realesrganNode,
[LINEAR_UI_OUTPUT]: linearUIOutputNode,
},
edges: [
{
source: {
node_id: ESRGAN,
field: 'image',
},
destination: {
node_id: LINEAR_UI_OUTPUT,
field: 'image',
},
},
],
edges: [],
};
addCoreMetadataNode(graph, {}, ESRGAN);

View File

@ -1,10 +1,10 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
@ -132,7 +132,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
[CANVAS_OUTPUT]: {
type: 'l2i',
id: CANVAS_OUTPUT,
is_intermediate,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
use_cache: false,
},
},
@ -242,7 +243,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
graph.nodes[CANVAS_OUTPUT] = {
id: CANVAS_OUTPUT,
type: 'img_resize',
is_intermediate,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
width: width,
height: height,
use_cache: false,
@ -284,7 +286,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
graph.nodes[CANVAS_OUTPUT] = {
type: 'l2i',
id: CANVAS_OUTPUT,
is_intermediate,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
fp32,
use_cache: false,
};
@ -355,7 +358,5 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addLinearUIOutputNode(state, graph);
return graph;
};

View File

@ -1,5 +1,6 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type {
CreateDenoiseMaskInvocation,
ImageBlurInvocation,
@ -12,7 +13,6 @@ import type {
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
@ -191,7 +191,8 @@ export const buildCanvasInpaintGraph = (
[CANVAS_OUTPUT]: {
type: 'color_correct',
id: CANVAS_OUTPUT,
is_intermediate,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
reference: canvasInitImage,
use_cache: false,
},
@ -663,7 +664,5 @@ export const buildCanvasInpaintGraph = (
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addLinearUIOutputNode(state, graph);
return graph;
};

View File

@ -1,5 +1,6 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type {
ImageDTO,
ImageToLatentsInvocation,
@ -11,7 +12,6 @@ import type {
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
@ -200,7 +200,8 @@ export const buildCanvasOutpaintGraph = (
[CANVAS_OUTPUT]: {
type: 'color_correct',
id: CANVAS_OUTPUT,
is_intermediate,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
use_cache: false,
},
},
@ -769,7 +770,5 @@ export const buildCanvasOutpaintGraph = (
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addLinearUIOutputNode(state, graph);
return graph;
};

View File

@ -4,7 +4,6 @@ import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'servi
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
@ -26,7 +25,7 @@ import {
SDXL_REFINER_SEAMLESS,
SEAMLESS,
} from './constants';
import { getSDXLStylePrompts } from './getSDXLStylePrompt';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
import { addCoreMetadataNode } from './metadata';
/**
@ -246,7 +245,8 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage:
graph.nodes[CANVAS_OUTPUT] = {
id: CANVAS_OUTPUT,
type: 'img_resize',
is_intermediate,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
width: width,
height: height,
use_cache: false,
@ -368,7 +368,5 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage:
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addLinearUIOutputNode(state, graph);
return graph;
};

View File

@ -12,7 +12,6 @@ import type {
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
@ -44,7 +43,7 @@ import {
SDXL_REFINER_SEAMLESS,
SEAMLESS,
} from './constants';
import { getSDXLStylePrompts } from './getSDXLStylePrompt';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
/**
* Builds the Canvas tab's Inpaint graph.
@ -190,7 +189,8 @@ export const buildCanvasSDXLInpaintGraph = (
[CANVAS_OUTPUT]: {
type: 'color_correct',
id: CANVAS_OUTPUT,
is_intermediate,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
reference: canvasInitImage,
use_cache: false,
},
@ -687,7 +687,5 @@ export const buildCanvasSDXLInpaintGraph = (
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addLinearUIOutputNode(state, graph);
return graph;
};

View File

@ -11,7 +11,6 @@ import type {
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
@ -46,7 +45,7 @@ import {
SDXL_REFINER_SEAMLESS,
SEAMLESS,
} from './constants';
import { getSDXLStylePrompts } from './getSDXLStylePrompt';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
/**
* Builds the Canvas tab's Outpaint graph.
@ -199,7 +198,8 @@ export const buildCanvasSDXLOutpaintGraph = (
[CANVAS_OUTPUT]: {
type: 'color_correct',
id: CANVAS_OUTPUT,
is_intermediate,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
use_cache: false,
},
},
@ -786,7 +786,5 @@ export const buildCanvasSDXLOutpaintGraph = (
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addLinearUIOutputNode(state, graph);
return graph;
};

View File

@ -4,7 +4,6 @@ import type { NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
@ -24,7 +23,7 @@ import {
SDXL_REFINER_SEAMLESS,
SEAMLESS,
} from './constants';
import { getSDXLStylePrompts } from './getSDXLStylePrompt';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
import { addCoreMetadataNode } from './metadata';
/**
@ -222,7 +221,8 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr
graph.nodes[CANVAS_OUTPUT] = {
id: CANVAS_OUTPUT,
type: 'img_resize',
is_intermediate,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
width: width,
height: height,
use_cache: false,
@ -254,7 +254,8 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr
graph.nodes[CANVAS_OUTPUT] = {
type: 'l2i',
id: CANVAS_OUTPUT,
is_intermediate,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
fp32,
use_cache: false,
};
@ -330,7 +331,5 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addLinearUIOutputNode(state, graph);
return graph;
};

View File

@ -1,10 +1,10 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type { NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
@ -211,7 +211,8 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph
graph.nodes[CANVAS_OUTPUT] = {
id: CANVAS_OUTPUT,
type: 'img_resize',
is_intermediate,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
width: width,
height: height,
use_cache: false,
@ -243,7 +244,8 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph
graph.nodes[CANVAS_OUTPUT] = {
type: 'l2i',
id: CANVAS_OUTPUT,
is_intermediate,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
fp32,
use_cache: false,
};
@ -310,7 +312,5 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addLinearUIOutputNode(state, graph);
return graph;
};

View File

@ -1,10 +1,10 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
@ -117,7 +117,8 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32,
is_intermediate,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
},
[DENOISE_LATENTS]: {
type: 'denoise_latents',
@ -358,7 +359,5 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph
addWatermarkerToGraph(state, graph);
}
addLinearUIOutputNode(state, graph);
return graph;
};

View File

@ -4,7 +4,6 @@ import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
@ -25,7 +24,7 @@ import {
SDXL_REFINER_SEAMLESS,
SEAMLESS,
} from './constants';
import { getSDXLStylePrompts } from './getSDXLStylePrompt';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
import { addCoreMetadataNode } from './metadata';
/**
@ -120,7 +119,8 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32,
is_intermediate,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
},
[SDXL_DENOISE_LATENTS]: {
type: 'denoise_latents',
@ -380,7 +380,5 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG
addWatermarkerToGraph(state, graph);
}
addLinearUIOutputNode(state, graph);
return graph;
};

View File

@ -4,7 +4,6 @@ import type { NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
@ -23,7 +22,7 @@ import {
SDXL_TEXT_TO_IMAGE_GRAPH,
SEAMLESS,
} from './constants';
import { getSDXLStylePrompts } from './getSDXLStylePrompt';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
import { addCoreMetadataNode } from './metadata';
export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGraph => {
@ -120,7 +119,8 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32,
is_intermediate,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
use_cache: false,
},
},
@ -281,7 +281,5 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr
addWatermarkerToGraph(state, graph);
}
addLinearUIOutputNode(state, graph);
return graph;
};

View File

@ -1,11 +1,11 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type { NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addHrfToGraph } from './addHrfToGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
@ -119,7 +119,8 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32,
is_intermediate,
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
use_cache: false,
},
},
@ -267,7 +268,5 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph
addWatermarkerToGraph(state, graph);
}
addLinearUIOutputNode(state, graph);
return graph;
};

View File

@ -9,7 +9,6 @@ export const LATENTS_TO_IMAGE_HRF_LR = 'latents_to_image_hrf_lr';
export const IMAGE_TO_LATENTS_HRF = 'image_to_latents_hrf';
export const RESIZE_HRF = 'resize_hrf';
export const ESRGAN_HRF = 'esrgan_hrf';
export const LINEAR_UI_OUTPUT = 'linear_ui_output';
export const NSFW_CHECKER = 'nsfw_checker';
export const WATERMARKER = 'invisible_watermark';
export const NOISE = 'noise';

View File

@ -1,11 +0,0 @@
import type { RootState } from 'app/store/store';
export const getSDXLStylePrompts = (state: RootState): { positiveStylePrompt: string; negativeStylePrompt: string } => {
const { positivePrompt, negativePrompt } = state.generation;
const { positiveStylePrompt, negativeStylePrompt, shouldConcatSDXLStylePrompt } = state.sdxl;
return {
positiveStylePrompt: shouldConcatSDXLStylePrompt ? positivePrompt : positiveStylePrompt,
negativeStylePrompt: shouldConcatSDXLStylePrompt ? negativePrompt : negativeStylePrompt,
};
};

View File

@ -0,0 +1,38 @@
import type { RootState } from 'app/store/store';
import type { BoardField } from 'features/nodes/types/common';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
/**
* Gets the board field, based on the autoAddBoardId setting.
*/
export const getBoardField = (state: RootState): BoardField | undefined => {
const { autoAddBoardId } = state.gallery;
if (autoAddBoardId === 'none') {
return undefined;
}
return { board_id: autoAddBoardId };
};
/**
* Gets the SDXL style prompts, based on the concat setting.
*/
export const getSDXLStylePrompts = (state: RootState): { positiveStylePrompt: string; negativeStylePrompt: string } => {
const { positivePrompt, negativePrompt } = state.generation;
const { positiveStylePrompt, negativeStylePrompt, shouldConcatSDXLStylePrompt } = state.sdxl;
return {
positiveStylePrompt: shouldConcatSDXLStylePrompt ? positivePrompt : positiveStylePrompt,
negativeStylePrompt: shouldConcatSDXLStylePrompt ? negativePrompt : negativeStylePrompt,
};
};
/**
* Gets the is_intermediate field, based on the active tab and shouldAutoSave setting.
*/
export const getIsIntermediate = (state: RootState) => {
const activeTabName = activeTabNameSelector(state);
if (activeTabName === 'unifiedCanvas') {
return !state.canvas.shouldAutoSave;
}
return false;
};

View File

@ -132,7 +132,6 @@ export type DivideInvocation = s['DivideInvocation'];
export type ImageNSFWBlurInvocation = s['ImageNSFWBlurInvocation'];
export type ImageWatermarkInvocation = s['ImageWatermarkInvocation'];
export type SeamlessModeInvocation = s['SeamlessModeInvocation'];
export type LinearUIOutputInvocation = s['LinearUIOutputInvocation'];
export type MetadataInvocation = s['MetadataInvocation'];
export type CoreMetadataInvocation = s['CoreMetadataInvocation'];
export type MetadataItemInvocation = s['MetadataItemInvocation'];