feat(ui): port NSFW and watermark nodes to graph builder

This commit is contained in:
psychedelicious 2024-05-13 18:21:55 +10:00
parent 04d12a1e98
commit 8d39520232
4 changed files with 76 additions and 16 deletions

View File

@ -65,7 +65,7 @@ function calculateHrfRes(
* @param noise The noise node
* @param l2i The l2i node
* @param vaeSource The VAE source node (may be a model loader, VAE loader, or seamless node)
* @returns
* @returns The HRF image output node.
*/
export const addGenerationTabHRF = (
state: RootState,
@ -74,11 +74,7 @@ export const addGenerationTabHRF = (
noise: Invocation<'noise'>,
l2i: Invocation<'l2i'>,
vaeSource: Invocation<'vae_loader'> | Invocation<'main_model_loader'> | Invocation<'seamless'>
): void => {
if (!state.hrf.hrfEnabled || state.config.disabledSDFeatures.includes('hrf')) {
return;
}
): Invocation<'l2i'> => {
const { hrfStrength, hrfEnabled, hrfMethod } = state.hrf;
const { width, height } = state.controlLayers.present.size;
const optimalDimension = selectOptimalDimension(state);
@ -167,4 +163,6 @@ export const addGenerationTabHRF = (
hrf_method: hrfMethod,
});
MetadataUtil.setMetadataReceivingNode(g, l2iHrfHR);
return l2iHrfHR;
};

View File

@ -0,0 +1,31 @@
import type { Graph } from 'features/nodes/util/graph/Graph';
import type { Invocation } from 'services/api/types';
import { NSFW_CHECKER } from './constants';
/**
* Adds the NSFW checker to the output image
* @param g The graph
* @param imageOutput The current image output node
* @returns The nsfw checker node
*/
export const addGenerationTabNSFWChecker = (
g: Graph,
imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'>
): Invocation<'img_nsfw'> => {
const nsfw = g.addNode({
id: NSFW_CHECKER,
type: 'img_nsfw',
is_intermediate: imageOutput.is_intermediate,
board: imageOutput.board,
use_cache: false,
});
imageOutput.is_intermediate = true;
imageOutput.use_cache = true;
imageOutput.board = undefined;
g.addEdge(imageOutput, 'image', nsfw, 'image');
return nsfw;
};

View File

@ -0,0 +1,31 @@
import type { Graph } from 'features/nodes/util/graph/Graph';
import type { Invocation } from 'services/api/types';
import { WATERMARKER } from './constants';
/**
* Adds a watermark to the output image
* @param g The graph
* @param imageOutput The image output node
* @returns The watermark node
*/
export const addGenerationTabWatermarker = (
g: Graph,
imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'>
): Invocation<'img_watermark'> => {
const watermark = g.addNode({
id: WATERMARKER,
type: 'img_watermark',
is_intermediate: imageOutput.is_intermediate,
board: imageOutput.board,
use_cache: false,
});
imageOutput.is_intermediate = true;
imageOutput.use_cache = true;
imageOutput.board = undefined;
g.addEdge(imageOutput, 'image', watermark, 'image');
return watermark;
};

View File

@ -5,15 +5,16 @@ import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetch
import { addGenerationTabControlLayers } from 'features/nodes/util/graph/addGenerationTabControlLayers';
import { addGenerationTabHRF } from 'features/nodes/util/graph/addGenerationTabHRF';
import { addGenerationTabLoRAs } from 'features/nodes/util/graph/addGenerationTabLoRAs';
import { addGenerationTabNSFWChecker } from 'features/nodes/util/graph/addGenerationTabNSFWChecker';
import { addGenerationTabSeamless } from 'features/nodes/util/graph/addGenerationTabSeamless';
import { addGenerationTabWatermarker } from 'features/nodes/util/graph/addGenerationTabWatermarker';
import type { GraphType } from 'features/nodes/util/graph/Graph';
import { Graph } from 'features/nodes/util/graph/Graph';
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
CLIP_SKIP,
CONTROL_LAYERS_GRAPH,
@ -116,6 +117,8 @@ export const buildGenerationTabGraph2 = async (state: RootState): Promise<GraphT
})
: null;
let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> = l2i;
g.addEdge(modelLoader, 'unet', denoise, 'unet');
g.addEdge(modelLoader, 'clip', clipSkip, 'clip');
g.addEdge(clipSkip, 'clip', posCond, 'clip');
@ -145,7 +148,6 @@ export const buildGenerationTabGraph2 = async (state: RootState): Promise<GraphT
clip_skip: skipped_layers,
vae: vae ?? undefined,
});
MetadataUtil.setMetadataReceivingNode(g, l2i);
g.validate();
const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader, vaeLoader);
@ -172,20 +174,18 @@ export const buildGenerationTabGraph2 = async (state: RootState): Promise<GraphT
g.validate();
const isHRFAllowed = !addedLayers.some((l) => isInitialImageLayer(l) || isRegionalGuidanceLayer(l));
if (isHRFAllowed) {
addGenerationTabHRF(state, g, denoise, noise, l2i, vaeSource);
if (isHRFAllowed && state.hrf.hrfEnabled) {
imageOutput = addGenerationTabHRF(state, g, denoise, noise, l2i, vaeSource);
}
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph);
imageOutput = addGenerationTabNSFWChecker(g, imageOutput);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph);
imageOutput = addGenerationTabWatermarker(g, imageOutput);
}
MetadataUtil.setMetadataReceivingNode(g, imageOutput);
return g.getGraph();
};