fix(ui): upscale tab graph

This commit is contained in:
psychedelicious 2024-08-23 18:23:45 +10:00
parent fadd20fb8e
commit 71e742e238
2 changed files with 57 additions and 54 deletions

View File

@ -14,9 +14,9 @@ export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening)
const { shouldShowProgressInViewer } = state.ui;
const { prepend } = action.payload;
const graph = await buildMultidiffusionUpscaleGraph(state);
const { g, noise, posCond } = await buildMultidiffusionUpscaleGraph(state);
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
const batchConfig = prepareLinearUIBatch(state, g, prepend, noise, posCond);
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {

View File

@ -3,13 +3,16 @@ import { getPrefixedId } from 'features/controlLayers/konva/util';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addSDXLLoRAs } from 'features/nodes/util/graph/generation/addSDXLLoRAs';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig, isSpandrelImageToImageModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { addLoRAs } from './generation/addLoRAs';
import { getBoardField, getPresetModifiedPrompts } from './graphBuilderUtils';
export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise<Graph> => {
export const buildMultidiffusionUpscaleGraph = async (
state: RootState
): Promise<{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'compel' | 'sdxl_compel_prompt'> }> => {
const { model, cfgScale: cfg_scale, scheduler, steps, vaePrecision, seed, vae } = state.canvasV2.params;
const { upscaleModel, upscaleInitialImage, structure, creativity, tileControlnetModel, scale } = state.upscale;
@ -20,7 +23,7 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
const g = new Graph();
const upscaleNode = g.addNode({
const spandrelAutoscale = g.addNode({
type: 'spandrel_image_to_image_autoscale',
id: getPrefixedId('spandrel_autoscale'),
image: upscaleInitialImage,
@ -29,34 +32,34 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
scale,
});
const unsharpMaskNode2 = g.addNode({
const unsharpMask = g.addNode({
type: 'unsharp_mask',
id: getPrefixedId('unsharp_2'),
radius: 2,
strength: 60,
});
g.addEdge(upscaleNode, 'image', unsharpMaskNode2, 'image');
g.addEdge(spandrelAutoscale, 'image', unsharpMask, 'image');
const noiseNode = g.addNode({
const noise = g.addNode({
type: 'noise',
id: getPrefixedId('noise'),
seed,
});
g.addEdge(unsharpMaskNode2, 'width', noiseNode, 'width');
g.addEdge(unsharpMaskNode2, 'height', noiseNode, 'height');
g.addEdge(unsharpMask, 'width', noise, 'width');
g.addEdge(unsharpMask, 'height', noise, 'height');
const i2lNode = g.addNode({
const i2l = g.addNode({
type: 'i2l',
id: getPrefixedId('i2l'),
fp32: vaePrecision === 'fp32',
tiled: true,
});
g.addEdge(unsharpMaskNode2, 'image', i2lNode, 'image');
g.addEdge(unsharpMask, 'image', i2l, 'image');
const l2iNode = g.addNode({
const l2i = g.addNode({
type: 'l2i',
id: getPrefixedId('l2i'),
fp32: vaePrecision === 'fp32',
@ -65,7 +68,7 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
is_intermediate: false,
});
const tiledMultidiffusionNode = g.addNode({
const tiledMultidiffusion = g.addNode({
type: 'tiled_multi_diffusion_denoise_latents',
id: getPrefixedId('tiled_multidiffusion_denoise_latents'),
tile_height: 1024, // is this dependent on base model
@ -78,37 +81,37 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
denoising_end: 1,
});
let posCondNode;
let negCondNode;
let modelNode;
let posCond;
let negCond;
let modelLoader;
if (model.base === 'sdxl') {
const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } =
getPresetModifiedPrompts(state);
posCondNode = g.addNode({
posCond = g.addNode({
type: 'sdxl_compel_prompt',
id: getPrefixedId('pos_cond'),
prompt: positivePrompt,
style: positiveStylePrompt,
});
negCondNode = g.addNode({
negCond = g.addNode({
type: 'sdxl_compel_prompt',
id: getPrefixedId('neg_cond'),
prompt: negativePrompt,
style: negativeStylePrompt,
});
modelNode = g.addNode({
modelLoader = g.addNode({
type: 'sdxl_model_loader',
id: getPrefixedId('sdxl_model_loader'),
model,
});
g.addEdge(modelNode, 'clip', posCondNode, 'clip');
g.addEdge(modelNode, 'clip', negCondNode, 'clip');
g.addEdge(modelNode, 'clip2', posCondNode, 'clip2');
g.addEdge(modelNode, 'clip2', negCondNode, 'clip2');
g.addEdge(modelNode, 'unet', tiledMultidiffusionNode, 'unet');
addSDXLLoRAs(state, g, tiledMultidiffusionNode, modelNode, null, posCondNode, negCondNode);
g.addEdge(modelLoader, 'clip', posCond, 'clip');
g.addEdge(modelLoader, 'clip', negCond, 'clip');
g.addEdge(modelLoader, 'clip2', posCond, 'clip2');
g.addEdge(modelLoader, 'clip2', negCond, 'clip2');
g.addEdge(modelLoader, 'unet', tiledMultidiffusion, 'unet');
addSDXLLoRAs(state, g, tiledMultidiffusion, modelLoader, null, posCond, negCond);
g.upsertMetadata({
positive_prompt: positivePrompt,
@ -119,17 +122,17 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
} else {
const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
posCondNode = g.addNode({
posCond = g.addNode({
type: 'compel',
id: getPrefixedId('pos_cond'),
prompt: positivePrompt,
});
negCondNode = g.addNode({
negCond = g.addNode({
type: 'compel',
id: getPrefixedId('neg_cond'),
prompt: negativePrompt,
});
modelNode = g.addNode({
modelLoader = g.addNode({
type: 'main_model_loader',
id: getPrefixedId('sd1_model_loader'),
model,
@ -139,11 +142,11 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
id: getPrefixedId('clip_skip'),
});
g.addEdge(modelNode, 'clip', clipSkipNode, 'clip');
g.addEdge(clipSkipNode, 'clip', posCondNode, 'clip');
g.addEdge(clipSkipNode, 'clip', negCondNode, 'clip');
g.addEdge(modelNode, 'unet', tiledMultidiffusionNode, 'unet');
addLoRAs(state, g, tiledMultidiffusionNode, modelNode, null, clipSkipNode, posCondNode, negCondNode);
g.addEdge(modelLoader, 'clip', clipSkipNode, 'clip');
g.addEdge(clipSkipNode, 'clip', posCond, 'clip');
g.addEdge(clipSkipNode, 'clip', negCond, 'clip');
g.addEdge(modelLoader, 'unet', tiledMultidiffusion, 'unet');
addLoRAs(state, g, tiledMultidiffusion, modelLoader, null, clipSkipNode, posCond, negCond);
g.upsertMetadata({
positive_prompt: positivePrompt,
@ -172,30 +175,30 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
upscale_scale: scale,
});
g.setMetadataReceivingNode(l2iNode);
g.addEdgeToMetadata(upscaleNode, 'width', 'width');
g.addEdgeToMetadata(upscaleNode, 'height', 'height');
g.setMetadataReceivingNode(l2i);
g.addEdgeToMetadata(spandrelAutoscale, 'width', 'width');
g.addEdgeToMetadata(spandrelAutoscale, 'height', 'height');
let vaeNode;
let vaeLoader;
if (vae) {
vaeNode = g.addNode({
vaeLoader = g.addNode({
type: 'vae_loader',
id: getPrefixedId('vae'),
vae_model: vae,
});
}
g.addEdge(vaeNode || modelNode, 'vae', i2lNode, 'vae');
g.addEdge(vaeNode || modelNode, 'vae', l2iNode, 'vae');
g.addEdge(vaeLoader || modelLoader, 'vae', i2l, 'vae');
g.addEdge(vaeLoader || modelLoader, 'vae', l2i, 'vae');
g.addEdge(noiseNode, 'noise', tiledMultidiffusionNode, 'noise');
g.addEdge(i2lNode, 'latents', tiledMultidiffusionNode, 'latents');
g.addEdge(posCondNode, 'conditioning', tiledMultidiffusionNode, 'positive_conditioning');
g.addEdge(negCondNode, 'conditioning', tiledMultidiffusionNode, 'negative_conditioning');
g.addEdge(noise, 'noise', tiledMultidiffusion, 'noise');
g.addEdge(i2l, 'latents', tiledMultidiffusion, 'latents');
g.addEdge(posCond, 'conditioning', tiledMultidiffusion, 'positive_conditioning');
g.addEdge(negCond, 'conditioning', tiledMultidiffusion, 'negative_conditioning');
g.addEdge(tiledMultidiffusionNode, 'latents', l2iNode, 'latents');
g.addEdge(tiledMultidiffusion, 'latents', l2i, 'latents');
const controlnetNode1 = g.addNode({
const controlNet1 = g.addNode({
id: 'controlnet_1',
type: 'controlnet',
control_model: tileControlnetModel,
@ -206,9 +209,9 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
end_step_percent: (structure + 10) * 0.025 + 0.3,
});
g.addEdge(unsharpMaskNode2, 'image', controlnetNode1, 'image');
g.addEdge(unsharpMask, 'image', controlNet1, 'image');
const controlnetNode2 = g.addNode({
const controlNet2 = g.addNode({
id: 'controlnet_2',
type: 'controlnet',
control_model: tileControlnetModel,
@ -219,16 +222,16 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
end_step_percent: 0.85,
});
g.addEdge(unsharpMaskNode2, 'image', controlnetNode2, 'image');
g.addEdge(unsharpMask, 'image', controlNet2, 'image');
const collectNode = g.addNode({
const controlNetCollector = g.addNode({
type: 'collect',
id: getPrefixedId('controlnet_collector'),
});
g.addEdge(controlnetNode1, 'control', collectNode, 'item');
g.addEdge(controlnetNode2, 'control', collectNode, 'item');
g.addEdge(controlNet1, 'control', controlNetCollector, 'item');
g.addEdge(controlNet2, 'control', controlNetCollector, 'item');
g.addEdge(collectNode, 'collection', tiledMultidiffusionNode, 'control');
g.addEdge(controlNetCollector, 'collection', tiledMultidiffusion, 'control');
return g;
return { g, noise, posCond };
};