mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): port NSFW and watermark nodes to graph builder
This commit is contained in:
parent
04d12a1e98
commit
8d39520232
@ -65,7 +65,7 @@ function calculateHrfRes(
|
||||
* @param noise The noise node
|
||||
* @param l2i The l2i node
|
||||
* @param vaeSource The VAE source node (may be a model loader, VAE loader, or seamless node)
|
||||
* @returns
|
||||
* @returns The HRF image output node.
|
||||
*/
|
||||
export const addGenerationTabHRF = (
|
||||
state: RootState,
|
||||
@ -74,11 +74,7 @@ export const addGenerationTabHRF = (
|
||||
noise: Invocation<'noise'>,
|
||||
l2i: Invocation<'l2i'>,
|
||||
vaeSource: Invocation<'vae_loader'> | Invocation<'main_model_loader'> | Invocation<'seamless'>
|
||||
): void => {
|
||||
if (!state.hrf.hrfEnabled || state.config.disabledSDFeatures.includes('hrf')) {
|
||||
return;
|
||||
}
|
||||
|
||||
): Invocation<'l2i'> => {
|
||||
const { hrfStrength, hrfEnabled, hrfMethod } = state.hrf;
|
||||
const { width, height } = state.controlLayers.present.size;
|
||||
const optimalDimension = selectOptimalDimension(state);
|
||||
@ -167,4 +163,6 @@ export const addGenerationTabHRF = (
|
||||
hrf_method: hrfMethod,
|
||||
});
|
||||
MetadataUtil.setMetadataReceivingNode(g, l2iHrfHR);
|
||||
|
||||
return l2iHrfHR;
|
||||
};
|
||||
|
@ -0,0 +1,31 @@
|
||||
import type { Graph } from 'features/nodes/util/graph/Graph';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
|
||||
import { NSFW_CHECKER } from './constants';
|
||||
|
||||
/**
|
||||
* Adds the NSFW checker to the output image
|
||||
* @param g The graph
|
||||
* @param imageOutput The current image output node
|
||||
* @returns The nsfw checker node
|
||||
*/
|
||||
export const addGenerationTabNSFWChecker = (
|
||||
g: Graph,
|
||||
imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'>
|
||||
): Invocation<'img_nsfw'> => {
|
||||
const nsfw = g.addNode({
|
||||
id: NSFW_CHECKER,
|
||||
type: 'img_nsfw',
|
||||
is_intermediate: imageOutput.is_intermediate,
|
||||
board: imageOutput.board,
|
||||
use_cache: false,
|
||||
});
|
||||
|
||||
imageOutput.is_intermediate = true;
|
||||
imageOutput.use_cache = true;
|
||||
imageOutput.board = undefined;
|
||||
|
||||
g.addEdge(imageOutput, 'image', nsfw, 'image');
|
||||
|
||||
return nsfw;
|
||||
};
|
@ -0,0 +1,31 @@
|
||||
import type { Graph } from 'features/nodes/util/graph/Graph';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
|
||||
import { WATERMARKER } from './constants';
|
||||
|
||||
/**
|
||||
* Adds a watermark to the output image
|
||||
* @param g The graph
|
||||
* @param imageOutput The image output node
|
||||
* @returns The watermark node
|
||||
*/
|
||||
export const addGenerationTabWatermarker = (
|
||||
g: Graph,
|
||||
imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'>
|
||||
): Invocation<'img_watermark'> => {
|
||||
const watermark = g.addNode({
|
||||
id: WATERMARKER,
|
||||
type: 'img_watermark',
|
||||
is_intermediate: imageOutput.is_intermediate,
|
||||
board: imageOutput.board,
|
||||
use_cache: false,
|
||||
});
|
||||
|
||||
imageOutput.is_intermediate = true;
|
||||
imageOutput.use_cache = true;
|
||||
imageOutput.board = undefined;
|
||||
|
||||
g.addEdge(imageOutput, 'image', watermark, 'image');
|
||||
|
||||
return watermark;
|
||||
};
|
@ -5,15 +5,16 @@ import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetch
|
||||
import { addGenerationTabControlLayers } from 'features/nodes/util/graph/addGenerationTabControlLayers';
|
||||
import { addGenerationTabHRF } from 'features/nodes/util/graph/addGenerationTabHRF';
|
||||
import { addGenerationTabLoRAs } from 'features/nodes/util/graph/addGenerationTabLoRAs';
|
||||
import { addGenerationTabNSFWChecker } from 'features/nodes/util/graph/addGenerationTabNSFWChecker';
|
||||
import { addGenerationTabSeamless } from 'features/nodes/util/graph/addGenerationTabSeamless';
|
||||
import { addGenerationTabWatermarker } from 'features/nodes/util/graph/addGenerationTabWatermarker';
|
||||
import type { GraphType } from 'features/nodes/util/graph/Graph';
|
||||
import { Graph } from 'features/nodes/util/graph/Graph';
|
||||
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
CLIP_SKIP,
|
||||
CONTROL_LAYERS_GRAPH,
|
||||
@ -116,6 +117,8 @@ export const buildGenerationTabGraph2 = async (state: RootState): Promise<GraphT
|
||||
})
|
||||
: null;
|
||||
|
||||
let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> = l2i;
|
||||
|
||||
g.addEdge(modelLoader, 'unet', denoise, 'unet');
|
||||
g.addEdge(modelLoader, 'clip', clipSkip, 'clip');
|
||||
g.addEdge(clipSkip, 'clip', posCond, 'clip');
|
||||
@ -145,7 +148,6 @@ export const buildGenerationTabGraph2 = async (state: RootState): Promise<GraphT
|
||||
clip_skip: skipped_layers,
|
||||
vae: vae ?? undefined,
|
||||
});
|
||||
MetadataUtil.setMetadataReceivingNode(g, l2i);
|
||||
g.validate();
|
||||
|
||||
const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader, vaeLoader);
|
||||
@ -172,20 +174,18 @@ export const buildGenerationTabGraph2 = async (state: RootState): Promise<GraphT
|
||||
g.validate();
|
||||
|
||||
const isHRFAllowed = !addedLayers.some((l) => isInitialImageLayer(l) || isRegionalGuidanceLayer(l));
|
||||
if (isHRFAllowed) {
|
||||
addGenerationTabHRF(state, g, denoise, noise, l2i, vaeSource);
|
||||
if (isHRFAllowed && state.hrf.hrfEnabled) {
|
||||
imageOutput = addGenerationTabHRF(state, g, denoise, noise, l2i, vaeSource);
|
||||
}
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
// must add before watermarker!
|
||||
addNSFWCheckerToGraph(state, graph);
|
||||
imageOutput = addGenerationTabNSFWChecker(g, imageOutput);
|
||||
}
|
||||
|
||||
if (state.system.shouldUseWatermarker) {
|
||||
// must add after nsfw checker!
|
||||
addWatermarkerToGraph(state, graph);
|
||||
imageOutput = addGenerationTabWatermarker(g, imageOutput);
|
||||
}
|
||||
|
||||
MetadataUtil.setMetadataReceivingNode(g, imageOutput);
|
||||
return g.getGraph();
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user