fix(ui): upscale tab graph

This commit is contained in:
psychedelicious 2024-08-23 18:23:45 +10:00
parent fadd20fb8e
commit 71e742e238
2 changed files with 57 additions and 54 deletions

View File

@ -14,9 +14,9 @@ export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening)
const { shouldShowProgressInViewer } = state.ui; const { shouldShowProgressInViewer } = state.ui;
const { prepend } = action.payload; const { prepend } = action.payload;
const graph = await buildMultidiffusionUpscaleGraph(state); const { g, noise, posCond } = await buildMultidiffusionUpscaleGraph(state);
const batchConfig = prepareLinearUIBatch(state, graph, prepend); const batchConfig = prepareLinearUIBatch(state, g, prepend, noise, posCond);
const req = dispatch( const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, { queueApi.endpoints.enqueueBatch.initiate(batchConfig, {

View File

@ -3,13 +3,16 @@ import { getPrefixedId } from 'features/controlLayers/konva/util';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addSDXLLoRAs } from 'features/nodes/util/graph/generation/addSDXLLoRAs'; import { addSDXLLoRAs } from 'features/nodes/util/graph/generation/addSDXLLoRAs';
import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig, isSpandrelImageToImageModelConfig } from 'services/api/types'; import { isNonRefinerMainModelConfig, isSpandrelImageToImageModelConfig } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
import { addLoRAs } from './generation/addLoRAs'; import { addLoRAs } from './generation/addLoRAs';
import { getBoardField, getPresetModifiedPrompts } from './graphBuilderUtils'; import { getBoardField, getPresetModifiedPrompts } from './graphBuilderUtils';
export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise<Graph> => { export const buildMultidiffusionUpscaleGraph = async (
state: RootState
): Promise<{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'compel' | 'sdxl_compel_prompt'> }> => {
const { model, cfgScale: cfg_scale, scheduler, steps, vaePrecision, seed, vae } = state.canvasV2.params; const { model, cfgScale: cfg_scale, scheduler, steps, vaePrecision, seed, vae } = state.canvasV2.params;
const { upscaleModel, upscaleInitialImage, structure, creativity, tileControlnetModel, scale } = state.upscale; const { upscaleModel, upscaleInitialImage, structure, creativity, tileControlnetModel, scale } = state.upscale;
@ -20,7 +23,7 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
const g = new Graph(); const g = new Graph();
const upscaleNode = g.addNode({ const spandrelAutoscale = g.addNode({
type: 'spandrel_image_to_image_autoscale', type: 'spandrel_image_to_image_autoscale',
id: getPrefixedId('spandrel_autoscale'), id: getPrefixedId('spandrel_autoscale'),
image: upscaleInitialImage, image: upscaleInitialImage,
@ -29,34 +32,34 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
scale, scale,
}); });
const unsharpMaskNode2 = g.addNode({ const unsharpMask = g.addNode({
type: 'unsharp_mask', type: 'unsharp_mask',
id: getPrefixedId('unsharp_2'), id: getPrefixedId('unsharp_2'),
radius: 2, radius: 2,
strength: 60, strength: 60,
}); });
g.addEdge(upscaleNode, 'image', unsharpMaskNode2, 'image'); g.addEdge(spandrelAutoscale, 'image', unsharpMask, 'image');
const noiseNode = g.addNode({ const noise = g.addNode({
type: 'noise', type: 'noise',
id: getPrefixedId('noise'), id: getPrefixedId('noise'),
seed, seed,
}); });
g.addEdge(unsharpMaskNode2, 'width', noiseNode, 'width'); g.addEdge(unsharpMask, 'width', noise, 'width');
g.addEdge(unsharpMaskNode2, 'height', noiseNode, 'height'); g.addEdge(unsharpMask, 'height', noise, 'height');
const i2lNode = g.addNode({ const i2l = g.addNode({
type: 'i2l', type: 'i2l',
id: getPrefixedId('i2l'), id: getPrefixedId('i2l'),
fp32: vaePrecision === 'fp32', fp32: vaePrecision === 'fp32',
tiled: true, tiled: true,
}); });
g.addEdge(unsharpMaskNode2, 'image', i2lNode, 'image'); g.addEdge(unsharpMask, 'image', i2l, 'image');
const l2iNode = g.addNode({ const l2i = g.addNode({
type: 'l2i', type: 'l2i',
id: getPrefixedId('l2i'), id: getPrefixedId('l2i'),
fp32: vaePrecision === 'fp32', fp32: vaePrecision === 'fp32',
@ -65,7 +68,7 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
is_intermediate: false, is_intermediate: false,
}); });
const tiledMultidiffusionNode = g.addNode({ const tiledMultidiffusion = g.addNode({
type: 'tiled_multi_diffusion_denoise_latents', type: 'tiled_multi_diffusion_denoise_latents',
id: getPrefixedId('tiled_multidiffusion_denoise_latents'), id: getPrefixedId('tiled_multidiffusion_denoise_latents'),
tile_height: 1024, // is this dependent on base model tile_height: 1024, // is this dependent on base model
@ -78,37 +81,37 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
denoising_end: 1, denoising_end: 1,
}); });
let posCondNode; let posCond;
let negCondNode; let negCond;
let modelNode; let modelLoader;
if (model.base === 'sdxl') { if (model.base === 'sdxl') {
const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } = const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } =
getPresetModifiedPrompts(state); getPresetModifiedPrompts(state);
posCondNode = g.addNode({ posCond = g.addNode({
type: 'sdxl_compel_prompt', type: 'sdxl_compel_prompt',
id: getPrefixedId('pos_cond'), id: getPrefixedId('pos_cond'),
prompt: positivePrompt, prompt: positivePrompt,
style: positiveStylePrompt, style: positiveStylePrompt,
}); });
negCondNode = g.addNode({ negCond = g.addNode({
type: 'sdxl_compel_prompt', type: 'sdxl_compel_prompt',
id: getPrefixedId('neg_cond'), id: getPrefixedId('neg_cond'),
prompt: negativePrompt, prompt: negativePrompt,
style: negativeStylePrompt, style: negativeStylePrompt,
}); });
modelNode = g.addNode({ modelLoader = g.addNode({
type: 'sdxl_model_loader', type: 'sdxl_model_loader',
id: getPrefixedId('sdxl_model_loader'), id: getPrefixedId('sdxl_model_loader'),
model, model,
}); });
g.addEdge(modelNode, 'clip', posCondNode, 'clip'); g.addEdge(modelLoader, 'clip', posCond, 'clip');
g.addEdge(modelNode, 'clip', negCondNode, 'clip'); g.addEdge(modelLoader, 'clip', negCond, 'clip');
g.addEdge(modelNode, 'clip2', posCondNode, 'clip2'); g.addEdge(modelLoader, 'clip2', posCond, 'clip2');
g.addEdge(modelNode, 'clip2', negCondNode, 'clip2'); g.addEdge(modelLoader, 'clip2', negCond, 'clip2');
g.addEdge(modelNode, 'unet', tiledMultidiffusionNode, 'unet'); g.addEdge(modelLoader, 'unet', tiledMultidiffusion, 'unet');
addSDXLLoRAs(state, g, tiledMultidiffusionNode, modelNode, null, posCondNode, negCondNode); addSDXLLoRAs(state, g, tiledMultidiffusion, modelLoader, null, posCond, negCond);
g.upsertMetadata({ g.upsertMetadata({
positive_prompt: positivePrompt, positive_prompt: positivePrompt,
@ -119,17 +122,17 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
} else { } else {
const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state); const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
posCondNode = g.addNode({ posCond = g.addNode({
type: 'compel', type: 'compel',
id: getPrefixedId('pos_cond'), id: getPrefixedId('pos_cond'),
prompt: positivePrompt, prompt: positivePrompt,
}); });
negCondNode = g.addNode({ negCond = g.addNode({
type: 'compel', type: 'compel',
id: getPrefixedId('neg_cond'), id: getPrefixedId('neg_cond'),
prompt: negativePrompt, prompt: negativePrompt,
}); });
modelNode = g.addNode({ modelLoader = g.addNode({
type: 'main_model_loader', type: 'main_model_loader',
id: getPrefixedId('sd1_model_loader'), id: getPrefixedId('sd1_model_loader'),
model, model,
@ -139,11 +142,11 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
id: getPrefixedId('clip_skip'), id: getPrefixedId('clip_skip'),
}); });
g.addEdge(modelNode, 'clip', clipSkipNode, 'clip'); g.addEdge(modelLoader, 'clip', clipSkipNode, 'clip');
g.addEdge(clipSkipNode, 'clip', posCondNode, 'clip'); g.addEdge(clipSkipNode, 'clip', posCond, 'clip');
g.addEdge(clipSkipNode, 'clip', negCondNode, 'clip'); g.addEdge(clipSkipNode, 'clip', negCond, 'clip');
g.addEdge(modelNode, 'unet', tiledMultidiffusionNode, 'unet'); g.addEdge(modelLoader, 'unet', tiledMultidiffusion, 'unet');
addLoRAs(state, g, tiledMultidiffusionNode, modelNode, null, clipSkipNode, posCondNode, negCondNode); addLoRAs(state, g, tiledMultidiffusion, modelLoader, null, clipSkipNode, posCond, negCond);
g.upsertMetadata({ g.upsertMetadata({
positive_prompt: positivePrompt, positive_prompt: positivePrompt,
@ -172,30 +175,30 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
upscale_scale: scale, upscale_scale: scale,
}); });
g.setMetadataReceivingNode(l2iNode); g.setMetadataReceivingNode(l2i);
g.addEdgeToMetadata(upscaleNode, 'width', 'width'); g.addEdgeToMetadata(spandrelAutoscale, 'width', 'width');
g.addEdgeToMetadata(upscaleNode, 'height', 'height'); g.addEdgeToMetadata(spandrelAutoscale, 'height', 'height');
let vaeNode; let vaeLoader;
if (vae) { if (vae) {
vaeNode = g.addNode({ vaeLoader = g.addNode({
type: 'vae_loader', type: 'vae_loader',
id: getPrefixedId('vae'), id: getPrefixedId('vae'),
vae_model: vae, vae_model: vae,
}); });
} }
g.addEdge(vaeNode || modelNode, 'vae', i2lNode, 'vae'); g.addEdge(vaeLoader || modelLoader, 'vae', i2l, 'vae');
g.addEdge(vaeNode || modelNode, 'vae', l2iNode, 'vae'); g.addEdge(vaeLoader || modelLoader, 'vae', l2i, 'vae');
g.addEdge(noiseNode, 'noise', tiledMultidiffusionNode, 'noise'); g.addEdge(noise, 'noise', tiledMultidiffusion, 'noise');
g.addEdge(i2lNode, 'latents', tiledMultidiffusionNode, 'latents'); g.addEdge(i2l, 'latents', tiledMultidiffusion, 'latents');
g.addEdge(posCondNode, 'conditioning', tiledMultidiffusionNode, 'positive_conditioning'); g.addEdge(posCond, 'conditioning', tiledMultidiffusion, 'positive_conditioning');
g.addEdge(negCondNode, 'conditioning', tiledMultidiffusionNode, 'negative_conditioning'); g.addEdge(negCond, 'conditioning', tiledMultidiffusion, 'negative_conditioning');
g.addEdge(tiledMultidiffusionNode, 'latents', l2iNode, 'latents'); g.addEdge(tiledMultidiffusion, 'latents', l2i, 'latents');
const controlnetNode1 = g.addNode({ const controlNet1 = g.addNode({
id: 'controlnet_1', id: 'controlnet_1',
type: 'controlnet', type: 'controlnet',
control_model: tileControlnetModel, control_model: tileControlnetModel,
@ -206,9 +209,9 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
end_step_percent: (structure + 10) * 0.025 + 0.3, end_step_percent: (structure + 10) * 0.025 + 0.3,
}); });
g.addEdge(unsharpMaskNode2, 'image', controlnetNode1, 'image'); g.addEdge(unsharpMask, 'image', controlNet1, 'image');
const controlnetNode2 = g.addNode({ const controlNet2 = g.addNode({
id: 'controlnet_2', id: 'controlnet_2',
type: 'controlnet', type: 'controlnet',
control_model: tileControlnetModel, control_model: tileControlnetModel,
@ -219,16 +222,16 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
end_step_percent: 0.85, end_step_percent: 0.85,
}); });
g.addEdge(unsharpMaskNode2, 'image', controlnetNode2, 'image'); g.addEdge(unsharpMask, 'image', controlNet2, 'image');
const collectNode = g.addNode({ const controlNetCollector = g.addNode({
type: 'collect', type: 'collect',
id: getPrefixedId('controlnet_collector'), id: getPrefixedId('controlnet_collector'),
}); });
g.addEdge(controlnetNode1, 'control', collectNode, 'item'); g.addEdge(controlNet1, 'control', controlNetCollector, 'item');
g.addEdge(controlnetNode2, 'control', collectNode, 'item'); g.addEdge(controlNet2, 'control', controlNetCollector, 'item');
g.addEdge(collectNode, 'collection', tiledMultidiffusionNode, 'control'); g.addEdge(controlNetCollector, 'collection', tiledMultidiffusion, 'control');
return g; return { g, noise, posCond };
}; };