mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
feat(ui): use graph utils in builders (wip)
This commit is contained in:
parent 8f6078d007
commit dbe22be598
features/nodes/util/graph/addInitialImageToGenerationTabGraph.ts (new file)
@@ -0,0 +1,79 @@
import type { RootState } from 'app/store/store';
import { isInitialImageLayer } from 'features/controlLayers/store/controlLayersSlice';
import type { ImageField } from 'features/nodes/types/common';
import type { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import type { Invocation } from 'services/api/types';

import { IMAGE_TO_LATENTS, RESIZE } from './constants';

/**
 * Adds the initial image to the graph and connects it to the denoise and noise nodes.
 * @param state The current Redux state
 * @param g The graph to add the initial image to
 * @param denoise The denoise node in the graph
 * @param noise The noise node in the graph
 * @returns Whether the initial image was added to the graph
 */
export const addInitialImageToGenerationTabGraph = (
  state: RootState,
  g: Graph,
  denoise: Invocation<'denoise_latents'>,
  noise: Invocation<'noise'>
): boolean => {
  const { img2imgStrength, vaePrecision, model } = state.generation;
  const { refinerModel, refinerStart } = state.sdxl;
  const { width, height } = state.controlLayers.present.size;
  const initialImageLayer = state.controlLayers.present.layers.find(isInitialImageLayer);
  const initialImage = initialImageLayer?.isEnabled ? initialImageLayer?.image : null;

  if (!initialImage) {
    return false;
  }

  const isSDXL = model?.base === 'sdxl';
  const useRefinerStartEnd = isSDXL && Boolean(refinerModel);
  const image: ImageField = {
    image_name: initialImage.imageName,
  };

  denoise.denoising_start = useRefinerStartEnd ? Math.min(refinerStart, 1 - img2imgStrength) : 1 - img2imgStrength;
  denoise.denoising_end = useRefinerStartEnd ? refinerStart : 1;

  const i2l = g.addNode({
    type: 'i2l',
    id: IMAGE_TO_LATENTS,
    fp32: vaePrecision === 'fp32',
  });
  g.addEdge(i2l, 'latents', denoise, 'latents');

  if (initialImage.width !== width || initialImage.height !== height) {
    // The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
    const resize = g.addNode({
      id: RESIZE,
      type: 'img_resize',
      image,
      width,
      height,
    });
    // The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
    g.addEdge(resize, 'image', i2l, 'image');
    // The `RESIZE` node also passes its width and height to `NOISE`
    g.addEdge(resize, 'width', noise, 'width');
    g.addEdge(resize, 'height', noise, 'height');
  } else {
    // We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
    i2l.image = image;
    g.addEdge(i2l, 'width', noise, 'width');
    g.addEdge(i2l, 'height', noise, 'height');
  }

  MetadataUtil.add(g, {
    generation_mode: isSDXL ? 'sdxl_img2img' : 'img2img',
    strength: img2imgStrength,
    init_image: initialImage.imageName,
  });

  return true;
};
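For orientation, a minimal usage sketch (not part of this commit; `state`, `g`, `denoise`, and `noise` are assumed to already exist in a builder such as the one further below):

// If an enabled initial image layer exists, the graph is converted to img2img.
// Worked example: with img2imgStrength = 0.75 and no refiner, denoising_start = 1 - 0.75 = 0.25
// and denoising_end = 1, so the final 75% of the denoising schedule is applied to the encoded image.
const didAddInitialImage = addInitialImageToGenerationTabGraph(state, g, denoise, noise);
if (!didAddInitialImage) {
  // No enabled initial image layer; the graph stays txt2img.
}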
features/nodes/util/graph/addSeamlessToGenerationTabGraph.ts (new file)
@@ -0,0 +1,70 @@
import type { RootState } from 'app/store/store';
import type { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import type { Invocation } from 'services/api/types';

import { SEAMLESS, VAE_LOADER } from './constants';

/**
 * Adds the seamless node to the graph and connects it to the model loader and denoise node.
 * Because the seamless node may insert a VAE loader node between the model loader and itself,
 * this function returns the terminal model loader node in the graph.
 * @param state The current Redux state
 * @param g The graph to add the seamless node to
 * @param denoise The denoise node in the graph
 * @param modelLoader The model loader node in the graph
 * @returns The terminal model loader node in the graph
 */
export const addSeamlessToGenerationTabGraph = (
  state: RootState,
  g: Graph,
  denoise: Invocation<'denoise_latents'>,
  modelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'>
): Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'> | Invocation<'seamless'> => {
  const { seamlessXAxis, seamlessYAxis, vae } = state.generation;

  if (!seamlessXAxis && !seamlessYAxis) {
    return modelLoader;
  }

  const seamless = g.addNode({
    id: SEAMLESS,
    type: 'seamless',
    seamless_x: seamlessXAxis,
    seamless_y: seamlessYAxis,
  });

  const vaeLoader = vae
    ? g.addNode({
        type: 'vae_loader',
        id: VAE_LOADER,
        vae_model: vae,
      })
    : null;

  let terminalModelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'> | Invocation<'seamless'> =
    modelLoader;

  if (seamlessXAxis) {
    MetadataUtil.add(g, {
      seamless_x: seamlessXAxis,
    });
    terminalModelLoader = seamless;
  }
  if (seamlessYAxis) {
    MetadataUtil.add(g, {
      seamless_y: seamlessYAxis,
    });
    terminalModelLoader = seamless;
  }

  // Seamless slots into the graph between the model loader and the denoise node
  g.deleteEdgesFrom(modelLoader, 'unet');
  g.deleteEdgesFrom(modelLoader, 'vae');

  g.addEdge(modelLoader, 'unet', seamless, 'unet');
  g.addEdge(vaeLoader ?? modelLoader, 'vae', seamless, 'vae');
  g.addEdge(seamless, 'unet', denoise, 'unet');

  return terminalModelLoader;
};
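A brief usage sketch (illustrative, not part of this commit): because seamless may now sit between the model loader and the denoise node, later UNet consumers should connect from the returned terminal node rather than the original model loader.

// Hypothetical builder excerpt; `state`, `g`, `denoise`, and `modelLoader` are assumed to exist.
const terminalModelLoader = addSeamlessToGenerationTabGraph(state, g, denoise, modelLoader);
// terminalModelLoader is the seamless node when either axis is enabled, otherwise it is the
// original model loader; e.g. LoRA wiring should pull 'unet' from terminalModelLoader.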
buildGenerationTabGraph.ts (new file)
@@ -0,0 +1,167 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
import { addInitialImageToGenerationTabGraph } from 'features/nodes/util/graph/addInitialImageToGenerationTabGraph';
import { addSeamlessToGenerationTabGraph } from 'features/nodes/util/graph/addSeamlessToGenerationTabGraph';
import { Graph } from 'features/nodes/util/graph/Graph';
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import { isNonRefinerMainModelConfig } from 'services/api/types';

import { addHrfToGraph } from './addHrfToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
  CLIP_SKIP,
  CONTROL_LAYERS_GRAPH,
  DENOISE_LATENTS,
  LATENTS_TO_IMAGE,
  MAIN_MODEL_LOADER,
  NEGATIVE_CONDITIONING,
  NEGATIVE_CONDITIONING_COLLECT,
  NOISE,
  POSITIVE_CONDITIONING,
  POSITIVE_CONDITIONING_COLLECT,
} from './constants';
import { getModelMetadataField } from './metadata';

const log = logger('nodes');

export const buildGenerationTabGraph = async (state: RootState): Promise<Graph> => {
  const {
    model,
    cfgScale: cfg_scale,
    cfgRescaleMultiplier: cfg_rescale_multiplier,
    scheduler,
    steps,
    clipSkip: skipped_layers,
    shouldUseCpuNoise,
    vaePrecision,
    seamlessXAxis,
    seamlessYAxis,
    seed,
  } = state.generation;
  const { positivePrompt, negativePrompt } = state.controlLayers.present;
  const { width, height } = state.controlLayers.present.size;

  if (!model) {
    log.error('No model found in state');
    throw new Error('No model found in state');
  }

  const g = new Graph(CONTROL_LAYERS_GRAPH);
  const modelLoader = g.addNode({
    type: 'main_model_loader',
    id: MAIN_MODEL_LOADER,
    model,
  });
  const clipSkip = g.addNode({
    type: 'clip_skip',
    id: CLIP_SKIP,
    skipped_layers,
  });
  const posCond = g.addNode({
    type: 'compel',
    id: POSITIVE_CONDITIONING,
    prompt: positivePrompt,
  });
  const posCondCollect = g.addNode({
    type: 'collect',
    id: POSITIVE_CONDITIONING_COLLECT,
  });
  const negCond = g.addNode({
    type: 'compel',
    id: NEGATIVE_CONDITIONING,
    prompt: negativePrompt,
  });
  const negCondCollect = g.addNode({
    type: 'collect',
    id: NEGATIVE_CONDITIONING_COLLECT,
  });
  const noise = g.addNode({
    type: 'noise',
    id: NOISE,
    seed,
    width,
    height,
    use_cpu: shouldUseCpuNoise,
  });
  const denoise = g.addNode({
    type: 'denoise_latents',
    id: DENOISE_LATENTS,
    cfg_scale,
    cfg_rescale_multiplier,
    scheduler,
    steps,
    denoising_start: 0,
    denoising_end: 1,
  });
  const l2i = g.addNode({
    type: 'l2i',
    id: LATENTS_TO_IMAGE,
    fp32: vaePrecision === 'fp32',
    board: getBoardField(state),
    // This is the terminal node and must always save to gallery.
    is_intermediate: false,
    use_cache: false,
  });

  g.addEdge(modelLoader, 'unet', denoise, 'unet');
  g.addEdge(modelLoader, 'clip', clipSkip, 'clip');
  g.addEdge(clipSkip, 'clip', posCond, 'clip');
  g.addEdge(clipSkip, 'clip', negCond, 'clip');
  g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
  g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
  g.addEdge(noise, 'noise', denoise, 'noise');
  g.addEdge(denoise, 'latents', l2i, 'latents');

  const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);

  MetadataUtil.add(g, {
    generation_mode: 'txt2img',
    cfg_scale,
    cfg_rescale_multiplier,
    height,
    width,
    positive_prompt: positivePrompt,
    negative_prompt: negativePrompt,
    model: getModelMetadataField(modelConfig),
    seed,
    steps,
    rand_device: shouldUseCpuNoise ? 'cpu' : 'cuda',
    scheduler,
    clip_skip: skipped_layers,
  });
  MetadataUtil.setMetadataReceivingNode(g, l2i);

  const didAddInitialImage = addInitialImageToGenerationTabGraph(state, g, denoise, noise);
  const terminalModelLoader = addSeamlessToGenerationTabGraph(state, g, denoise, modelLoader);

  // Optionally add a custom VAE.
  // NOTE (wip): the legacy helpers below were written for the previous graph builder API;
  // passing the new Graph instance and the terminal model loader here is an interim wiring
  // until they are migrated.
  await addVAEToGraph(state, g, terminalModelLoader.id);

  // Add LoRA support.
  await addLoRAsToGraph(state, g, DENOISE_LATENTS, terminalModelLoader.id);

  await addControlLayersToGraph(state, g, DENOISE_LATENTS);

  // High resolution fix.
  if (state.hrf.hrfEnabled && !didAddInitialImage) {
    addHrfToGraph(state, g);
  }

  // NSFW checker & watermarker - must be the last things added to the graph.
  if (state.system.shouldUseNSFWChecker) {
    // Must be added before the watermarker!
    addNSFWCheckerToGraph(state, g);
  }

  if (state.system.shouldUseWatermarker) {
    // Must be added after the NSFW checker!
    addWatermarkerToGraph(state, g);
  }

  return g;
};
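A hedged sketch of a possible call site (illustrative only; the enqueue wiring and the accessor used to serialize the Graph util are assumptions, not part of this commit):

// Somewhere in an enqueue handler:
const state = getState();
const g = await buildGenerationTabGraph(state);
// The Graph util is assumed to expose the underlying API graph for enqueueing,
// e.g. via a getGraph()-style accessor; the exact method name may differ.
const graph = g.getGraph();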