fix(ui): missing vae precision in graph builders

psychedelicious 2024-08-24 18:35:30 +10:00
parent 6af5a22d3a
commit a9032a34f2
5 changed files with 32 additions and 26 deletions
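In short: the graph-builder helpers (addImageToImage, addInpaint, addOutpaint) previously received the UI-level vaePrecision setting (or, for addImageToImage, no precision at all) and each call site compared it against 'fp32' where needed, while some i2l nodes were created without any precision flag at all — the "missing vae precision" in the title. After this commit, each builder computes a single fp32 boolean and threads it through to the helpers and to every VAE-adjacent node (i2l, l2i, create_gradient_mask). A minimal sketch of the signature change; the types and helper names below are stand-ins for illustration, not verbatim code from the repo:

  // Stand-in for the union type imported from parameterSchemas (assumed shape).
  type ParameterPrecision = 'fp16' | 'fp32';

  // Before: helpers took the raw setting and compared it locally,
  // and i2l (image-to-latents) nodes got no precision flag.
  declare function addInpaintBefore(
    /* ...other args elided... */
    denoising_start: number,
    vaePrecision: ParameterPrecision
  ): Promise<unknown>;

  // After: the builder does the comparison once and passes a plain boolean.
  declare function addInpaintAfter(
    /* ...other args elided... */
    denoising_start: number,
    fp32: boolean
  ): Promise<unknown>;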

View File

@@ -14,7 +14,8 @@ export const addImageToImage = async (
   originalSize: Dimensions,
   scaledSize: Dimensions,
   bbox: CanvasV2State['bbox'],
-  denoising_start: number
+  denoising_start: number,
+  fp32: boolean
 ): Promise<Invocation<'img_resize' | 'l2i'>> => {
   denoise.denoising_start = denoising_start;
@@ -28,7 +29,7 @@ export const addImageToImage = async (
       image: { image_name },
       ...scaledSize,
     });
-    const i2l = g.addNode({ id: 'i2l', type: 'i2l' });
+    const i2l = g.addNode({ id: 'i2l', type: 'i2l', fp32 });
     const resizeImageToOriginalSize = g.addNode({
       type: 'img_resize',
       id: getPrefixedId('initial_image_resize_out'),
@@ -43,8 +44,8 @@ export const addImageToImage = async (
     // This is the new output node
     return resizeImageToOriginalSize;
   } else {
-    // No need to resize, just denoise
-    const i2l = g.addNode({ id: 'i2l', type: 'i2l', image: { image_name } });
+    // No need to resize, just decode
+    const i2l = g.addNode({ id: 'i2l', type: 'i2l', image: { image_name }, fp32 });
     g.addEdge(vaeSource, 'vae', i2l, 'vae');
     g.addEdge(i2l, 'latents', denoise, 'latents');
     return l2i;

View File

@@ -3,7 +3,6 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
 import { getPrefixedId } from 'features/controlLayers/konva/util';
 import type { CanvasV2State, Dimensions } from 'features/controlLayers/store/types';
 import type { Graph } from 'features/nodes/util/graph/generation/Graph';
-import type { ParameterPrecision } from 'features/parameters/types/parameterSchemas';
 import { isEqual } from 'lodash-es';
 import type { Invocation } from 'services/api/types';
@@ -20,7 +19,7 @@ export const addInpaint = async (
   bbox: CanvasV2State['bbox'],
   compositing: CanvasV2State['compositing'],
   denoising_start: number,
-  vaePrecision: ParameterPrecision
+  fp32: boolean
 ): Promise<Invocation<'canvas_v2_mask_and_crop'>> => {
   denoise.denoising_start = denoising_start;
@@ -30,7 +29,7 @@ export const addInpaint = async (
   if (!isEqual(scaledSize, originalSize)) {
     // Scale before processing requires some resizing
-    const i2l = g.addNode({ id: getPrefixedId('i2l'), type: 'i2l' });
+    const i2l = g.addNode({ id: getPrefixedId('i2l'), type: 'i2l', fp32 });
     const resizeImageToScaledSize = g.addNode({
       type: 'img_resize',
       id: getPrefixedId('resize_image_to_scaled_size'),
@@ -64,7 +63,7 @@ export const addInpaint = async (
       coherence_mode: compositing.canvasCoherenceMode,
       minimum_denoise: compositing.canvasCoherenceMinDenoise,
       edge_radius: compositing.canvasCoherenceEdgeSize,
-      fp32: vaePrecision === 'fp32',
+      fp32,
     });
     const canvasPasteBack = g.addNode({
       id: getPrefixedId('canvas_v2_mask_and_crop'),
@@ -100,7 +99,12 @@ export const addInpaint = async (
     return canvasPasteBack;
   } else {
     // No scale before processing, much simpler
-    const i2l = g.addNode({ id: getPrefixedId('i2l'), type: 'i2l', image: { image_name: initialImage.image_name } });
+    const i2l = g.addNode({
+      id: getPrefixedId('i2l'),
+      type: 'i2l',
+      image: { image_name: initialImage.image_name },
+      fp32,
+    });
     const alphaToMask = g.addNode({
       id: getPrefixedId('alpha_to_mask'),
       type: 'tomask',
@@ -113,7 +117,7 @@ export const addInpaint = async (
       coherence_mode: compositing.canvasCoherenceMode,
       minimum_denoise: compositing.canvasCoherenceMinDenoise,
       edge_radius: compositing.canvasCoherenceEdgeSize,
-      fp32: vaePrecision === 'fp32',
+      fp32,
       image: { image_name: initialImage.image_name },
     });
     const canvasPasteBack = g.addNode({

View File

@@ -4,7 +4,6 @@ import { getPrefixedId } from 'features/controlLayers/konva/util';
 import type { CanvasV2State, Dimensions } from 'features/controlLayers/store/types';
 import type { Graph } from 'features/nodes/util/graph/generation/Graph';
 import { getInfill } from 'features/nodes/util/graph/graphBuilderUtils';
-import type { ParameterPrecision } from 'features/parameters/types/parameterSchemas';
 import { isEqual } from 'lodash-es';
 import type { Invocation } from 'services/api/types';
@@ -21,7 +20,7 @@ export const addOutpaint = async (
   bbox: CanvasV2State['bbox'],
   compositing: CanvasV2State['compositing'],
   denoising_start: number,
-  vaePrecision: ParameterPrecision
+  fp32: boolean
 ): Promise<Invocation<'canvas_v2_mask_and_crop'>> => {
   denoise.denoising_start = denoising_start;
@@ -76,7 +75,7 @@ export const addOutpaint = async (
       coherence_mode: compositing.canvasCoherenceMode,
       minimum_denoise: compositing.canvasCoherenceMinDenoise,
       edge_radius: compositing.canvasCoherenceEdgeSize,
-      fp32: vaePrecision === 'fp32',
+      fp32,
     });
     g.addEdge(infill, 'image', createGradientMask, 'image');
     g.addEdge(resizeInputMaskToScaledSize, 'image', createGradientMask, 'mask');
@@ -85,7 +84,7 @@ export const addOutpaint = async (
     g.addEdge(createGradientMask, 'denoise_mask', denoise, 'denoise_mask');
     // Decode infilled image and connect to denoise
-    const i2l = g.addNode({ id: getPrefixedId('i2l'), type: 'i2l' });
+    const i2l = g.addNode({ id: getPrefixedId('i2l'), type: 'i2l', fp32 });
     g.addEdge(infill, 'image', i2l, 'image');
     g.addEdge(vaeSource, 'vae', i2l, 'vae');
     g.addEdge(i2l, 'latents', denoise, 'latents');
@@ -125,7 +124,7 @@ export const addOutpaint = async (
   } else {
     infill.image = { image_name: initialImage.image_name };
     // No scale before processing, much simpler
-    const i2l = g.addNode({ id: getPrefixedId('i2l'), type: 'i2l' });
+    const i2l = g.addNode({ id: getPrefixedId('i2l'), type: 'i2l', fp32 });
     const maskAlphaToMask = g.addNode({
       id: getPrefixedId('mask_alpha_to_mask'),
       type: 'tomask',
@@ -147,7 +146,7 @@ export const addOutpaint = async (
       coherence_mode: compositing.canvasCoherenceMode,
       minimum_denoise: compositing.canvasCoherenceMinDenoise,
       edge_radius: compositing.canvasCoherenceEdgeSize,
-      fp32: vaePrecision === 'fp32',
+      fp32,
       image: { image_name: initialImage.image_name },
     });
     const canvasPasteBack = g.addNode({

View File

@@ -48,8 +48,8 @@ export const buildSD1Graph = async (
   assert(model, 'No model found in state');
+  const fp32 = vaePrecision === 'fp32';
   const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
   const { originalSize, scaledSize } = getSizes(bbox);
   const g = new Graph(getPrefixedId('sd1_graph'));
@@ -102,7 +102,7 @@ export const buildSD1Graph = async (
   const l2i = g.addNode({
     type: 'l2i',
     id: getPrefixedId('l2i'),
-    fp32: vaePrecision === 'fp32',
+    fp32,
   });
   const vaeLoader =
     vae?.base === model.base
@@ -168,7 +168,8 @@ export const buildSD1Graph = async (
       originalSize,
       scaledSize,
       bbox,
-      1 - params.img2imgStrength
+      1 - params.img2imgStrength,
+      vaePrecision === 'fp32'
     );
   } else if (generationMode === 'inpaint') {
     const { compositing } = state.canvasV2;
@@ -185,7 +186,7 @@ export const buildSD1Graph = async (
       bbox,
       compositing,
       1 - params.img2imgStrength,
-      vaePrecision
+      vaePrecision === 'fp32'
     );
   } else if (generationMode === 'outpaint') {
     const { compositing } = state.canvasV2;
@@ -202,7 +203,7 @@ export const buildSD1Graph = async (
       bbox,
       compositing,
       1 - params.img2imgStrength,
-      vaePrecision
+      fp32
     );
   }

View File

@@ -49,8 +49,8 @@ export const buildSDXLGraph = async (
   assert(model, 'No model found in state');
+  const fp32 = vaePrecision === 'fp32';
   const { originalSize, scaledSize } = getSizes(bbox);
   const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } = getPresetModifiedPrompts(state);
   const g = new Graph(getPrefixedId('sdxl_graph'));
@@ -100,7 +100,7 @@ export const buildSDXLGraph = async (
   const l2i = g.addNode({
     type: 'l2i',
     id: getPrefixedId('l2i'),
-    fp32: vaePrecision === 'fp32',
+    fp32,
   });
   const vaeLoader =
     vae?.base === model.base
@@ -171,7 +171,8 @@ export const buildSDXLGraph = async (
      originalSize,
      scaledSize,
      bbox,
-      refinerModel ? Math.min(refinerStart, 1 - params.img2imgStrength) : 1 - params.img2imgStrength
+      refinerModel ? Math.min(refinerStart, 1 - params.img2imgStrength) : 1 - params.img2imgStrength,
+      fp32
     );
   } else if (generationMode === 'inpaint') {
     const { compositing } = state.canvasV2;
@@ -188,7 +189,7 @@ export const buildSDXLGraph = async (
       bbox,
       compositing,
       refinerModel ? Math.min(refinerStart, 1 - params.img2imgStrength) : 1 - params.img2imgStrength,
-      vaePrecision
+      fp32
     );
   } else if (generationMode === 'outpaint') {
     const { compositing } = state.canvasV2;
@@ -205,7 +206,7 @@ export const buildSDXLGraph = async (
       bbox,
       compositing,
       refinerModel ? Math.min(refinerStart, 1 - params.img2imgStrength) : 1 - params.img2imgStrength,
-      vaePrecision
+      fp32
     );
   }
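
Taken together, the builder-side flow after this commit looks roughly like the sketch below. It is condensed from the hunks above, not a verbatim excerpt; the variable receiving the helper's result (canvasOutput) and the elided arguments are assumptions for illustration. Note that in the SD1 builder the img2img and inpaint call sites still inline vaePrecision === 'fp32' rather than reusing the local fp32, while the outpaint call site and the SDXL builder pass fp32 everywhere.

  // Computed once per graph build from the UI setting.
  const fp32 = vaePrecision === 'fp32';

  // The latents-to-image decode node gets the flag directly.
  const l2i = g.addNode({ type: 'l2i', id: getPrefixedId('l2i'), fp32 });

  // Each generation mode forwards the same flag to its helper, which in turn
  // sets it on the i2l and create_gradient_mask nodes it adds.
  if (generationMode === 'img2img') {
    canvasOutput = await addImageToImage(/* ... */ 1 - params.img2imgStrength, fp32);
  } else if (generationMode === 'inpaint') {
    canvasOutput = await addInpaint(/* ... */ 1 - params.img2imgStrength, fp32);
  } else if (generationMode === 'outpaint') {
    canvasOutput = await addOutpaint(/* ... */ 1 - params.img2imgStrength, fp32);
  }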