feat(ui): use graph utils in builders (wip)

This commit is contained in:
psychedelicious 2024-05-05 16:26:30 +10:00
parent 8f6078d007
commit dbe22be598
3 changed files with 316 additions and 0 deletions

View File

@ -0,0 +1,79 @@
import type { RootState } from 'app/store/store';
import { isInitialImageLayer } from 'features/controlLayers/store/controlLayersSlice';
import type { ImageField } from 'features/nodes/types/common';
import type { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import type { Invocation } from 'services/api/types';
import { IMAGE_TO_LATENTS, RESIZE } from './constants';
/**
* Adds the initial image to the graph and connects it to the denoise and noise nodes.
* @param state The current Redux state
* @param g The graph to add the initial image to
* @param denoise The denoise node in the graph
* @param noise The noise node in the graph
* @returns Whether the initial image was added to the graph
*/
export const addInitialImageToGenerationTabGraph = (
state: RootState,
g: Graph,
denoise: Invocation<'denoise_latents'>,
noise: Invocation<'noise'>
): boolean => {
// Remove Existing UNet Connections
const { img2imgStrength, vaePrecision, model } = state.generation;
const { refinerModel, refinerStart } = state.sdxl;
const { width, height } = state.controlLayers.present.size;
const initialImageLayer = state.controlLayers.present.layers.find(isInitialImageLayer);
const initialImage = initialImageLayer?.isEnabled ? initialImageLayer?.image : null;
if (!initialImage) {
return false;
}
const isSDXL = model?.base === 'sdxl';
const useRefinerStartEnd = isSDXL && Boolean(refinerModel);
const image: ImageField = {
image_name: initialImage.imageName,
};
denoise.denoising_start = useRefinerStartEnd ? Math.min(refinerStart, 1 - img2imgStrength) : 1 - img2imgStrength;
denoise.denoising_end = useRefinerStartEnd ? refinerStart : 1;
const i2l = g.addNode({
type: 'i2l',
id: IMAGE_TO_LATENTS,
fp32: vaePrecision === 'fp32',
});
g.addEdge(i2l, 'latents', denoise, 'latents');
if (initialImage.width !== width || initialImage.height !== height) {
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
const resize = g.addNode({
id: RESIZE,
type: 'img_resize',
image,
width,
height,
});
// The `RESIZE` node then passes its image, to `IMAGE_TO_LATENTS`
g.addEdge(resize, 'image', i2l, 'image');
// The `RESIZE` node also passes its width and height to `NOISE`
g.addEdge(resize, 'width', noise, 'width');
g.addEdge(resize, 'height', noise, 'height');
} else {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
i2l.image = image;
g.addEdge(i2l, 'width', noise, 'width');
g.addEdge(i2l, 'height', noise, 'height');
}
MetadataUtil.add(g, {
generation_mode: isSDXL ? 'sdxl_img2img' : 'img2img',
strength: img2imgStrength,
init_image: initialImage.imageName,
});
return true;
};

View File

@ -0,0 +1,70 @@
import type { RootState } from 'app/store/store';
import type { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import type { Invocation } from 'services/api/types';
import { SEAMLESS, VAE_LOADER } from './constants';
/**
* Adds the seamless node to the graph and connects it to the model loader and denoise node.
* Because the seamless node may insert a VAE loader node between the model loader and itself,
* this function returns the terminal model loader node in the graph.
* @param state The current Redux state
* @param g The graph to add the seamless node to
* @param denoise The denoise node in the graph
* @param modelLoader The model loader node in the graph
* @returns The terminal model loader node in the graph
*/
export const addSeamlessToGenerationTabGraph = (
state: RootState,
g: Graph,
denoise: Invocation<'denoise_latents'>,
modelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'>
): Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'> | Invocation<'seamless'> => {
const { seamlessXAxis, seamlessYAxis, vae } = state.generation;
if (!seamlessXAxis && !seamlessYAxis) {
return modelLoader;
}
const seamless = g.addNode({
id: SEAMLESS,
type: 'seamless',
seamless_x: seamlessXAxis,
seamless_y: seamlessYAxis,
});
const vaeLoader = vae
? g.addNode({
type: 'vae_loader',
id: VAE_LOADER,
vae_model: vae,
})
: null;
let terminalModelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'> | Invocation<'seamless'> =
modelLoader;
if (seamlessXAxis) {
MetadataUtil.add(g, {
seamless_x: seamlessXAxis,
});
terminalModelLoader = seamless;
}
if (seamlessYAxis) {
MetadataUtil.add(g, {
seamless_y: seamlessYAxis,
});
terminalModelLoader = seamless;
}
// Seamless slots into the graph between the model loader and the denoise node
g.deleteEdgesFrom(modelLoader, 'unet');
g.deleteEdgesFrom(modelLoader, 'clip');
g.addEdge(modelLoader, 'unet', seamless, 'unet');
g.addEdge(vaeLoader ?? modelLoader, 'vae', seamless, 'unet');
g.addEdge(seamless, 'unet', denoise, 'unet');
return terminalModelLoader;
};

View File

@ -0,0 +1,167 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
import { addInitialImageToGenerationTabGraph } from 'features/nodes/util/graph/addInitialImageToGenerationTabGraph';
import { addSeamlessToGenerationTabGraph } from 'features/nodes/util/graph/addSeamlessToGenerationTabGraph';
import { Graph } from 'features/nodes/util/graph/Graph';
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { addHrfToGraph } from './addHrfToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
CLIP_SKIP,
CONTROL_LAYERS_GRAPH,
DENOISE_LATENTS,
LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING,
NEGATIVE_CONDITIONING_COLLECT,
NOISE,
POSITIVE_CONDITIONING,
POSITIVE_CONDITIONING_COLLECT,
} from './constants';
import { getModelMetadataField } from './metadata';
const log = logger('nodes');
export const buildGenerationTabGraph = async (state: RootState): Promise<Graph> => {
const {
model,
cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler,
steps,
clipSkip: skipped_layers,
shouldUseCpuNoise,
vaePrecision,
seamlessXAxis,
seamlessYAxis,
seed,
} = state.generation;
const { positivePrompt, negativePrompt } = state.controlLayers.present;
const { width, height } = state.controlLayers.present.size;
if (!model) {
log.error('No model found in state');
throw new Error('No model found in state');
}
const g = new Graph(CONTROL_LAYERS_GRAPH);
const modelLoader = g.addNode({
type: 'main_model_loader',
id: MAIN_MODEL_LOADER,
model,
});
const clipSkip = g.addNode({
type: 'clip_skip',
id: CLIP_SKIP,
skipped_layers,
});
const posCond = g.addNode({
type: 'compel',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
});
const posCondCollect = g.addNode({
type: 'collect',
id: POSITIVE_CONDITIONING_COLLECT,
});
const negCond = g.addNode({
type: 'compel',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
});
const negCondCollect = g.addNode({
type: 'collect',
id: NEGATIVE_CONDITIONING_COLLECT,
});
const noise = g.addNode({
type: 'noise',
id: NOISE,
seed,
width,
height,
use_cpu: shouldUseCpuNoise,
});
const denoise = g.addNode({
type: 'denoise_latents',
id: DENOISE_LATENTS,
cfg_scale,
cfg_rescale_multiplier,
scheduler,
steps,
denoising_start: 0,
denoising_end: 1,
});
const l2i = g.addNode({
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32: vaePrecision === 'fp32',
board: getBoardField(state),
// This is the terminal node and must always save to gallery.
is_intermediate: false,
use_cache: false,
});
g.addEdge(modelLoader, 'unet', denoise, 'unet');
g.addEdge(modelLoader, 'clip', clipSkip, 'clip');
g.addEdge(clipSkip, 'clip', posCond, 'clip');
g.addEdge(clipSkip, 'clip', negCond, 'clip');
g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
g.addEdge(noise, 'noise', denoise, 'noise');
g.addEdge(denoise, 'latents', l2i, 'latents');
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
MetadataUtil.add(g, {
generation_mode: 'txt2img',
cfg_scale,
cfg_rescale_multiplier,
height,
width,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: getModelMetadataField(modelConfig),
seed,
steps,
rand_device: shouldUseCpuNoise ? 'cpu' : 'cuda',
scheduler,
clip_skip: skipped_layers,
});
MetadataUtil.setMetadataReceivingNode(g, l2i);
const didAddInitialImage = addInitialImageToGenerationTabGraph(state, g, denoise, noise);
const terminalModelLoader = addSeamlessToGenerationTabGraph(state, g, denoise, modelLoader);
// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
await addControlLayersToGraph(state, graph, DENOISE_LATENTS);
// High resolution fix.
if (state.hrf.hrfEnabled && !didAddInitialImage) {
addHrfToGraph(state, graph);
}
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph);
}
return graph;
};