mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): upscale tab graph
This commit is contained in:
parent
fadd20fb8e
commit
71e742e238
@ -14,9 +14,9 @@ export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening)
|
|||||||
const { shouldShowProgressInViewer } = state.ui;
|
const { shouldShowProgressInViewer } = state.ui;
|
||||||
const { prepend } = action.payload;
|
const { prepend } = action.payload;
|
||||||
|
|
||||||
const graph = await buildMultidiffusionUpscaleGraph(state);
|
const { g, noise, posCond } = await buildMultidiffusionUpscaleGraph(state);
|
||||||
|
|
||||||
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
|
const batchConfig = prepareLinearUIBatch(state, g, prepend, noise, posCond);
|
||||||
|
|
||||||
const req = dispatch(
|
const req = dispatch(
|
||||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||||
|
@ -3,13 +3,16 @@ import { getPrefixedId } from 'features/controlLayers/konva/util';
|
|||||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import { addSDXLLoRAs } from 'features/nodes/util/graph/generation/addSDXLLoRAs';
|
import { addSDXLLoRAs } from 'features/nodes/util/graph/generation/addSDXLLoRAs';
|
||||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||||
|
import type { Invocation } from 'services/api/types';
|
||||||
import { isNonRefinerMainModelConfig, isSpandrelImageToImageModelConfig } from 'services/api/types';
|
import { isNonRefinerMainModelConfig, isSpandrelImageToImageModelConfig } from 'services/api/types';
|
||||||
import { assert } from 'tsafe';
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
import { addLoRAs } from './generation/addLoRAs';
|
import { addLoRAs } from './generation/addLoRAs';
|
||||||
import { getBoardField, getPresetModifiedPrompts } from './graphBuilderUtils';
|
import { getBoardField, getPresetModifiedPrompts } from './graphBuilderUtils';
|
||||||
|
|
||||||
export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise<Graph> => {
|
export const buildMultidiffusionUpscaleGraph = async (
|
||||||
|
state: RootState
|
||||||
|
): Promise<{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'compel' | 'sdxl_compel_prompt'> }> => {
|
||||||
const { model, cfgScale: cfg_scale, scheduler, steps, vaePrecision, seed, vae } = state.canvasV2.params;
|
const { model, cfgScale: cfg_scale, scheduler, steps, vaePrecision, seed, vae } = state.canvasV2.params;
|
||||||
const { upscaleModel, upscaleInitialImage, structure, creativity, tileControlnetModel, scale } = state.upscale;
|
const { upscaleModel, upscaleInitialImage, structure, creativity, tileControlnetModel, scale } = state.upscale;
|
||||||
|
|
||||||
@ -20,7 +23,7 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
|
|||||||
|
|
||||||
const g = new Graph();
|
const g = new Graph();
|
||||||
|
|
||||||
const upscaleNode = g.addNode({
|
const spandrelAutoscale = g.addNode({
|
||||||
type: 'spandrel_image_to_image_autoscale',
|
type: 'spandrel_image_to_image_autoscale',
|
||||||
id: getPrefixedId('spandrel_autoscale'),
|
id: getPrefixedId('spandrel_autoscale'),
|
||||||
image: upscaleInitialImage,
|
image: upscaleInitialImage,
|
||||||
@ -29,34 +32,34 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
|
|||||||
scale,
|
scale,
|
||||||
});
|
});
|
||||||
|
|
||||||
const unsharpMaskNode2 = g.addNode({
|
const unsharpMask = g.addNode({
|
||||||
type: 'unsharp_mask',
|
type: 'unsharp_mask',
|
||||||
id: getPrefixedId('unsharp_2'),
|
id: getPrefixedId('unsharp_2'),
|
||||||
radius: 2,
|
radius: 2,
|
||||||
strength: 60,
|
strength: 60,
|
||||||
});
|
});
|
||||||
|
|
||||||
g.addEdge(upscaleNode, 'image', unsharpMaskNode2, 'image');
|
g.addEdge(spandrelAutoscale, 'image', unsharpMask, 'image');
|
||||||
|
|
||||||
const noiseNode = g.addNode({
|
const noise = g.addNode({
|
||||||
type: 'noise',
|
type: 'noise',
|
||||||
id: getPrefixedId('noise'),
|
id: getPrefixedId('noise'),
|
||||||
seed,
|
seed,
|
||||||
});
|
});
|
||||||
|
|
||||||
g.addEdge(unsharpMaskNode2, 'width', noiseNode, 'width');
|
g.addEdge(unsharpMask, 'width', noise, 'width');
|
||||||
g.addEdge(unsharpMaskNode2, 'height', noiseNode, 'height');
|
g.addEdge(unsharpMask, 'height', noise, 'height');
|
||||||
|
|
||||||
const i2lNode = g.addNode({
|
const i2l = g.addNode({
|
||||||
type: 'i2l',
|
type: 'i2l',
|
||||||
id: getPrefixedId('i2l'),
|
id: getPrefixedId('i2l'),
|
||||||
fp32: vaePrecision === 'fp32',
|
fp32: vaePrecision === 'fp32',
|
||||||
tiled: true,
|
tiled: true,
|
||||||
});
|
});
|
||||||
|
|
||||||
g.addEdge(unsharpMaskNode2, 'image', i2lNode, 'image');
|
g.addEdge(unsharpMask, 'image', i2l, 'image');
|
||||||
|
|
||||||
const l2iNode = g.addNode({
|
const l2i = g.addNode({
|
||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
id: getPrefixedId('l2i'),
|
id: getPrefixedId('l2i'),
|
||||||
fp32: vaePrecision === 'fp32',
|
fp32: vaePrecision === 'fp32',
|
||||||
@ -65,7 +68,7 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
|
|||||||
is_intermediate: false,
|
is_intermediate: false,
|
||||||
});
|
});
|
||||||
|
|
||||||
const tiledMultidiffusionNode = g.addNode({
|
const tiledMultidiffusion = g.addNode({
|
||||||
type: 'tiled_multi_diffusion_denoise_latents',
|
type: 'tiled_multi_diffusion_denoise_latents',
|
||||||
id: getPrefixedId('tiled_multidiffusion_denoise_latents'),
|
id: getPrefixedId('tiled_multidiffusion_denoise_latents'),
|
||||||
tile_height: 1024, // is this dependent on base model
|
tile_height: 1024, // is this dependent on base model
|
||||||
@ -78,37 +81,37 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
|
|||||||
denoising_end: 1,
|
denoising_end: 1,
|
||||||
});
|
});
|
||||||
|
|
||||||
let posCondNode;
|
let posCond;
|
||||||
let negCondNode;
|
let negCond;
|
||||||
let modelNode;
|
let modelLoader;
|
||||||
|
|
||||||
if (model.base === 'sdxl') {
|
if (model.base === 'sdxl') {
|
||||||
const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } =
|
const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } =
|
||||||
getPresetModifiedPrompts(state);
|
getPresetModifiedPrompts(state);
|
||||||
|
|
||||||
posCondNode = g.addNode({
|
posCond = g.addNode({
|
||||||
type: 'sdxl_compel_prompt',
|
type: 'sdxl_compel_prompt',
|
||||||
id: getPrefixedId('pos_cond'),
|
id: getPrefixedId('pos_cond'),
|
||||||
prompt: positivePrompt,
|
prompt: positivePrompt,
|
||||||
style: positiveStylePrompt,
|
style: positiveStylePrompt,
|
||||||
});
|
});
|
||||||
negCondNode = g.addNode({
|
negCond = g.addNode({
|
||||||
type: 'sdxl_compel_prompt',
|
type: 'sdxl_compel_prompt',
|
||||||
id: getPrefixedId('neg_cond'),
|
id: getPrefixedId('neg_cond'),
|
||||||
prompt: negativePrompt,
|
prompt: negativePrompt,
|
||||||
style: negativeStylePrompt,
|
style: negativeStylePrompt,
|
||||||
});
|
});
|
||||||
modelNode = g.addNode({
|
modelLoader = g.addNode({
|
||||||
type: 'sdxl_model_loader',
|
type: 'sdxl_model_loader',
|
||||||
id: getPrefixedId('sdxl_model_loader'),
|
id: getPrefixedId('sdxl_model_loader'),
|
||||||
model,
|
model,
|
||||||
});
|
});
|
||||||
g.addEdge(modelNode, 'clip', posCondNode, 'clip');
|
g.addEdge(modelLoader, 'clip', posCond, 'clip');
|
||||||
g.addEdge(modelNode, 'clip', negCondNode, 'clip');
|
g.addEdge(modelLoader, 'clip', negCond, 'clip');
|
||||||
g.addEdge(modelNode, 'clip2', posCondNode, 'clip2');
|
g.addEdge(modelLoader, 'clip2', posCond, 'clip2');
|
||||||
g.addEdge(modelNode, 'clip2', negCondNode, 'clip2');
|
g.addEdge(modelLoader, 'clip2', negCond, 'clip2');
|
||||||
g.addEdge(modelNode, 'unet', tiledMultidiffusionNode, 'unet');
|
g.addEdge(modelLoader, 'unet', tiledMultidiffusion, 'unet');
|
||||||
addSDXLLoRAs(state, g, tiledMultidiffusionNode, modelNode, null, posCondNode, negCondNode);
|
addSDXLLoRAs(state, g, tiledMultidiffusion, modelLoader, null, posCond, negCond);
|
||||||
|
|
||||||
g.upsertMetadata({
|
g.upsertMetadata({
|
||||||
positive_prompt: positivePrompt,
|
positive_prompt: positivePrompt,
|
||||||
@ -119,17 +122,17 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
|
|||||||
} else {
|
} else {
|
||||||
const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
|
const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
|
||||||
|
|
||||||
posCondNode = g.addNode({
|
posCond = g.addNode({
|
||||||
type: 'compel',
|
type: 'compel',
|
||||||
id: getPrefixedId('pos_cond'),
|
id: getPrefixedId('pos_cond'),
|
||||||
prompt: positivePrompt,
|
prompt: positivePrompt,
|
||||||
});
|
});
|
||||||
negCondNode = g.addNode({
|
negCond = g.addNode({
|
||||||
type: 'compel',
|
type: 'compel',
|
||||||
id: getPrefixedId('neg_cond'),
|
id: getPrefixedId('neg_cond'),
|
||||||
prompt: negativePrompt,
|
prompt: negativePrompt,
|
||||||
});
|
});
|
||||||
modelNode = g.addNode({
|
modelLoader = g.addNode({
|
||||||
type: 'main_model_loader',
|
type: 'main_model_loader',
|
||||||
id: getPrefixedId('sd1_model_loader'),
|
id: getPrefixedId('sd1_model_loader'),
|
||||||
model,
|
model,
|
||||||
@ -139,11 +142,11 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
|
|||||||
id: getPrefixedId('clip_skip'),
|
id: getPrefixedId('clip_skip'),
|
||||||
});
|
});
|
||||||
|
|
||||||
g.addEdge(modelNode, 'clip', clipSkipNode, 'clip');
|
g.addEdge(modelLoader, 'clip', clipSkipNode, 'clip');
|
||||||
g.addEdge(clipSkipNode, 'clip', posCondNode, 'clip');
|
g.addEdge(clipSkipNode, 'clip', posCond, 'clip');
|
||||||
g.addEdge(clipSkipNode, 'clip', negCondNode, 'clip');
|
g.addEdge(clipSkipNode, 'clip', negCond, 'clip');
|
||||||
g.addEdge(modelNode, 'unet', tiledMultidiffusionNode, 'unet');
|
g.addEdge(modelLoader, 'unet', tiledMultidiffusion, 'unet');
|
||||||
addLoRAs(state, g, tiledMultidiffusionNode, modelNode, null, clipSkipNode, posCondNode, negCondNode);
|
addLoRAs(state, g, tiledMultidiffusion, modelLoader, null, clipSkipNode, posCond, negCond);
|
||||||
|
|
||||||
g.upsertMetadata({
|
g.upsertMetadata({
|
||||||
positive_prompt: positivePrompt,
|
positive_prompt: positivePrompt,
|
||||||
@ -172,30 +175,30 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
|
|||||||
upscale_scale: scale,
|
upscale_scale: scale,
|
||||||
});
|
});
|
||||||
|
|
||||||
g.setMetadataReceivingNode(l2iNode);
|
g.setMetadataReceivingNode(l2i);
|
||||||
g.addEdgeToMetadata(upscaleNode, 'width', 'width');
|
g.addEdgeToMetadata(spandrelAutoscale, 'width', 'width');
|
||||||
g.addEdgeToMetadata(upscaleNode, 'height', 'height');
|
g.addEdgeToMetadata(spandrelAutoscale, 'height', 'height');
|
||||||
|
|
||||||
let vaeNode;
|
let vaeLoader;
|
||||||
if (vae) {
|
if (vae) {
|
||||||
vaeNode = g.addNode({
|
vaeLoader = g.addNode({
|
||||||
type: 'vae_loader',
|
type: 'vae_loader',
|
||||||
id: getPrefixedId('vae'),
|
id: getPrefixedId('vae'),
|
||||||
vae_model: vae,
|
vae_model: vae,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
g.addEdge(vaeNode || modelNode, 'vae', i2lNode, 'vae');
|
g.addEdge(vaeLoader || modelLoader, 'vae', i2l, 'vae');
|
||||||
g.addEdge(vaeNode || modelNode, 'vae', l2iNode, 'vae');
|
g.addEdge(vaeLoader || modelLoader, 'vae', l2i, 'vae');
|
||||||
|
|
||||||
g.addEdge(noiseNode, 'noise', tiledMultidiffusionNode, 'noise');
|
g.addEdge(noise, 'noise', tiledMultidiffusion, 'noise');
|
||||||
g.addEdge(i2lNode, 'latents', tiledMultidiffusionNode, 'latents');
|
g.addEdge(i2l, 'latents', tiledMultidiffusion, 'latents');
|
||||||
g.addEdge(posCondNode, 'conditioning', tiledMultidiffusionNode, 'positive_conditioning');
|
g.addEdge(posCond, 'conditioning', tiledMultidiffusion, 'positive_conditioning');
|
||||||
g.addEdge(negCondNode, 'conditioning', tiledMultidiffusionNode, 'negative_conditioning');
|
g.addEdge(negCond, 'conditioning', tiledMultidiffusion, 'negative_conditioning');
|
||||||
|
|
||||||
g.addEdge(tiledMultidiffusionNode, 'latents', l2iNode, 'latents');
|
g.addEdge(tiledMultidiffusion, 'latents', l2i, 'latents');
|
||||||
|
|
||||||
const controlnetNode1 = g.addNode({
|
const controlNet1 = g.addNode({
|
||||||
id: 'controlnet_1',
|
id: 'controlnet_1',
|
||||||
type: 'controlnet',
|
type: 'controlnet',
|
||||||
control_model: tileControlnetModel,
|
control_model: tileControlnetModel,
|
||||||
@ -206,9 +209,9 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
|
|||||||
end_step_percent: (structure + 10) * 0.025 + 0.3,
|
end_step_percent: (structure + 10) * 0.025 + 0.3,
|
||||||
});
|
});
|
||||||
|
|
||||||
g.addEdge(unsharpMaskNode2, 'image', controlnetNode1, 'image');
|
g.addEdge(unsharpMask, 'image', controlNet1, 'image');
|
||||||
|
|
||||||
const controlnetNode2 = g.addNode({
|
const controlNet2 = g.addNode({
|
||||||
id: 'controlnet_2',
|
id: 'controlnet_2',
|
||||||
type: 'controlnet',
|
type: 'controlnet',
|
||||||
control_model: tileControlnetModel,
|
control_model: tileControlnetModel,
|
||||||
@ -219,16 +222,16 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
|
|||||||
end_step_percent: 0.85,
|
end_step_percent: 0.85,
|
||||||
});
|
});
|
||||||
|
|
||||||
g.addEdge(unsharpMaskNode2, 'image', controlnetNode2, 'image');
|
g.addEdge(unsharpMask, 'image', controlNet2, 'image');
|
||||||
|
|
||||||
const collectNode = g.addNode({
|
const controlNetCollector = g.addNode({
|
||||||
type: 'collect',
|
type: 'collect',
|
||||||
id: getPrefixedId('controlnet_collector'),
|
id: getPrefixedId('controlnet_collector'),
|
||||||
});
|
});
|
||||||
g.addEdge(controlnetNode1, 'control', collectNode, 'item');
|
g.addEdge(controlNet1, 'control', controlNetCollector, 'item');
|
||||||
g.addEdge(controlnetNode2, 'control', collectNode, 'item');
|
g.addEdge(controlNet2, 'control', controlNetCollector, 'item');
|
||||||
|
|
||||||
g.addEdge(collectNode, 'collection', tiledMultidiffusionNode, 'control');
|
g.addEdge(controlNetCollector, 'collection', tiledMultidiffusion, 'control');
|
||||||
|
|
||||||
return g;
|
return { g, noise, posCond };
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user