fix: SDXL Refiner not working with Canvas Inpaint & Outpaint

This commit is contained in:
blessedcoolant 2023-08-31 06:26:02 +12:00
parent 754666ed09
commit 97763f778a
10 changed files with 99 additions and 34 deletions

View File

@ -11,10 +11,10 @@ import {
METADATA_ACCUMULATOR, METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
REFINER_SEAMLESS,
SDXL_CANVAS_INPAINT_GRAPH, SDXL_CANVAS_INPAINT_GRAPH,
SDXL_CANVAS_OUTPAINT_GRAPH, SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SDXL_REFINER_INPAINT_CREATE_MASK,
SEAMLESS, SEAMLESS,
} from './constants'; } from './constants';
@ -41,7 +41,9 @@ export const addSDXLLoRAsToGraph = (
// Handle Seamless Plugs // Handle Seamless Plugs
const unetLoaderId = modelLoaderNodeId; const unetLoaderId = modelLoaderNodeId;
let clipLoaderId = modelLoaderNodeId; let clipLoaderId = modelLoaderNodeId;
if ([SEAMLESS, REFINER_SEAMLESS].includes(modelLoaderNodeId)) { if (
[SEAMLESS, SDXL_REFINER_INPAINT_CREATE_MASK].includes(modelLoaderNodeId)
) {
clipLoaderId = SDXL_MODEL_LOADER; clipLoaderId = SDXL_MODEL_LOADER;
} }

View File

@ -1,24 +1,28 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { import {
CreateDenoiseMaskInvocation,
ImageDTO,
MetadataAccumulatorInvocation, MetadataAccumulatorInvocation,
SeamlessModeInvocation, SeamlessModeInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { NonNullableGraph } from '../../types/types'; import { NonNullableGraph } from '../../types/types';
import { import {
CANVAS_OUTPUT, CANVAS_OUTPUT,
INPAINT_IMAGE_RESIZE_UP,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MASK_BLUR, MASK_BLUR,
METADATA_ACCUMULATOR, METADATA_ACCUMULATOR,
REFINER_SEAMLESS,
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH, SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
SDXL_CANVAS_INPAINT_GRAPH, SDXL_CANVAS_INPAINT_GRAPH,
SDXL_CANVAS_OUTPAINT_GRAPH, SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH, SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SDXL_REFINER_DENOISE_LATENTS, SDXL_REFINER_DENOISE_LATENTS,
SDXL_REFINER_INPAINT_CREATE_MASK,
SDXL_REFINER_MODEL_LOADER, SDXL_REFINER_MODEL_LOADER,
SDXL_REFINER_NEGATIVE_CONDITIONING, SDXL_REFINER_NEGATIVE_CONDITIONING,
SDXL_REFINER_POSITIVE_CONDITIONING, SDXL_REFINER_POSITIVE_CONDITIONING,
SDXL_REFINER_SEAMLESS,
} from './constants'; } from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt'; import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -26,7 +30,8 @@ export const addSDXLRefinerToGraph = (
state: RootState, state: RootState,
graph: NonNullableGraph, graph: NonNullableGraph,
baseNodeId: string, baseNodeId: string,
modelLoaderNodeId?: string modelLoaderNodeId?: string,
canvasInitImage?: ImageDTO
): void => { ): void => {
const { const {
refinerModel, refinerModel,
@ -38,7 +43,8 @@ export const addSDXLRefinerToGraph = (
refinerStart, refinerStart,
} = state.sdxl; } = state.sdxl;
const { seamlessXAxis, seamlessYAxis } = state.generation; const { seamlessXAxis, seamlessYAxis, vaePrecision } = state.generation;
const { boundingBoxScaleMethod } = state.canvas;
if (!refinerModel) { if (!refinerModel) {
return; return;
@ -108,8 +114,8 @@ export const addSDXLRefinerToGraph = (
// Add Seamless To Refiner // Add Seamless To Refiner
if (seamlessXAxis || seamlessYAxis) { if (seamlessXAxis || seamlessYAxis) {
graph.nodes[REFINER_SEAMLESS] = { graph.nodes[SDXL_REFINER_SEAMLESS] = {
id: REFINER_SEAMLESS, id: SDXL_REFINER_SEAMLESS,
type: 'seamless', type: 'seamless',
seamless_x: seamlessXAxis, seamless_x: seamlessXAxis,
seamless_y: seamlessYAxis, seamless_y: seamlessYAxis,
@ -122,7 +128,7 @@ export const addSDXLRefinerToGraph = (
field: 'unet', field: 'unet',
}, },
destination: { destination: {
node_id: REFINER_SEAMLESS, node_id: SDXL_REFINER_SEAMLESS,
field: 'unet', field: 'unet',
}, },
}, },
@ -132,13 +138,13 @@ export const addSDXLRefinerToGraph = (
field: 'vae', field: 'vae',
}, },
destination: { destination: {
node_id: REFINER_SEAMLESS, node_id: SDXL_REFINER_SEAMLESS,
field: 'vae', field: 'vae',
}, },
}, },
{ {
source: { source: {
node_id: REFINER_SEAMLESS, node_id: SDXL_REFINER_SEAMLESS,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -244,15 +250,54 @@ export const addSDXLRefinerToGraph = (
graph.id === SDXL_CANVAS_INPAINT_GRAPH || graph.id === SDXL_CANVAS_INPAINT_GRAPH ||
graph.id === SDXL_CANVAS_OUTPAINT_GRAPH graph.id === SDXL_CANVAS_OUTPAINT_GRAPH
) { ) {
graph.edges.push({ graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
source: { type: 'create_denoise_mask',
node_id: MASK_BLUR, id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'image', is_intermediate: true,
fp32: vaePrecision === 'fp32' ? true : false,
};
if (['auto', 'manual'].includes(boundingBoxScaleMethod)) {
graph.edges.push({
source: {
node_id: INPAINT_IMAGE_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'image',
},
});
} else {
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
...(graph.nodes[
SDXL_REFINER_INPAINT_CREATE_MASK
] as CreateDenoiseMaskInvocation),
image: canvasInitImage,
};
}
graph.edges.push(
{
source: {
node_id: MASK_BLUR,
field: 'image',
},
destination: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'mask',
},
}, },
destination: { {
node_id: SDXL_REFINER_DENOISE_LATENTS, source: {
field: 'mask', node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
}, field: 'denoise_mask',
}); },
destination: {
node_id: SDXL_REFINER_DENOISE_LATENTS,
field: 'denoise_mask',
},
}
);
} }
}; };

View File

@ -20,6 +20,7 @@ import {
SDXL_CANVAS_OUTPAINT_GRAPH, SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH, SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
SDXL_IMAGE_TO_IMAGE_GRAPH, SDXL_IMAGE_TO_IMAGE_GRAPH,
SDXL_REFINER_INPAINT_CREATE_MASK,
SDXL_TEXT_TO_IMAGE_GRAPH, SDXL_TEXT_TO_IMAGE_GRAPH,
TEXT_TO_IMAGE_GRAPH, TEXT_TO_IMAGE_GRAPH,
VAE_LOADER, VAE_LOADER,
@ -32,6 +33,7 @@ export const addVAEToGraph = (
): void => { ): void => {
const { vae } = state.generation; const { vae } = state.generation;
const { boundingBoxScaleMethod } = state.canvas; const { boundingBoxScaleMethod } = state.canvas;
const { shouldUseSDXLRefiner } = state.sdxl;
const isUsingScaledDimensions = ['auto', 'manual'].includes( const isUsingScaledDimensions = ['auto', 'manual'].includes(
boundingBoxScaleMethod boundingBoxScaleMethod
@ -146,6 +148,19 @@ export const addVAEToGraph = (
); );
} }
if (shouldUseSDXLRefiner) {
graph.edges.push({
source: {
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
},
destination: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'vae',
},
});
}
if (vae && metadataAccumulator) { if (vae && metadataAccumulator) {
metadataAccumulator.vae = vae; metadataAccumulator.vae = vae;
} }

View File

@ -20,10 +20,10 @@ import {
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
REFINER_SEAMLESS,
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH, SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
SDXL_DENOISE_LATENTS, SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SDXL_REFINER_SEAMLESS,
SEAMLESS, SEAMLESS,
} from './constants'; } from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt'; import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -369,7 +369,7 @@ export const buildCanvasSDXLImageToImageGraph = (
if (shouldUseSDXLRefiner) { if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS); addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
if (seamlessXAxis || seamlessYAxis) { if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = REFINER_SEAMLESS; modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
} }
} }

View File

@ -36,10 +36,10 @@ import {
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT, RANDOM_INT,
RANGE_OF_SIZE, RANGE_OF_SIZE,
REFINER_SEAMLESS,
SDXL_CANVAS_INPAINT_GRAPH, SDXL_CANVAS_INPAINT_GRAPH,
SDXL_DENOISE_LATENTS, SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SDXL_REFINER_SEAMLESS,
SEAMLESS, SEAMLESS,
} from './constants'; } from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt'; import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -628,10 +628,11 @@ export const buildCanvasSDXLInpaintGraph = (
state, state,
graph, graph,
CANVAS_COHERENCE_DENOISE_LATENTS, CANVAS_COHERENCE_DENOISE_LATENTS,
modelLoaderNodeId modelLoaderNodeId,
canvasInitImage
); );
if (seamlessXAxis || seamlessYAxis) { if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = REFINER_SEAMLESS; modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
} }
} }

View File

@ -41,10 +41,10 @@ import {
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT, RANDOM_INT,
RANGE_OF_SIZE, RANGE_OF_SIZE,
REFINER_SEAMLESS,
SDXL_CANVAS_OUTPAINT_GRAPH, SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_DENOISE_LATENTS, SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SDXL_REFINER_SEAMLESS,
SEAMLESS, SEAMLESS,
} from './constants'; } from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt'; import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -766,10 +766,11 @@ export const buildCanvasSDXLOutpaintGraph = (
state, state,
graph, graph,
CANVAS_COHERENCE_DENOISE_LATENTS, CANVAS_COHERENCE_DENOISE_LATENTS,
modelLoaderNodeId modelLoaderNodeId,
canvasInitImage
); );
if (seamlessXAxis || seamlessYAxis) { if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = REFINER_SEAMLESS; modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
} }
} }

View File

@ -22,10 +22,10 @@ import {
NOISE, NOISE,
ONNX_MODEL_LOADER, ONNX_MODEL_LOADER,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
REFINER_SEAMLESS,
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH, SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
SDXL_DENOISE_LATENTS, SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SDXL_REFINER_SEAMLESS,
SEAMLESS, SEAMLESS,
} from './constants'; } from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt'; import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -349,7 +349,7 @@ export const buildCanvasSDXLTextToImageGraph = (
if (shouldUseSDXLRefiner) { if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS); addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
if (seamlessXAxis || seamlessYAxis) { if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = REFINER_SEAMLESS; modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
} }
} }

View File

@ -21,11 +21,11 @@ import {
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
REFINER_SEAMLESS,
RESIZE, RESIZE,
SDXL_DENOISE_LATENTS, SDXL_DENOISE_LATENTS,
SDXL_IMAGE_TO_IMAGE_GRAPH, SDXL_IMAGE_TO_IMAGE_GRAPH,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SDXL_REFINER_SEAMLESS,
SEAMLESS, SEAMLESS,
} from './constants'; } from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt'; import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -369,7 +369,7 @@ export const buildLinearSDXLImageToImageGraph = (
if (shouldUseSDXLRefiner) { if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS); addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
if (seamlessXAxis || seamlessYAxis) { if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = REFINER_SEAMLESS; modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
} }
} }

View File

@ -16,9 +16,9 @@ import {
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
REFINER_SEAMLESS,
SDXL_DENOISE_LATENTS, SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SDXL_REFINER_SEAMLESS,
SDXL_TEXT_TO_IMAGE_GRAPH, SDXL_TEXT_TO_IMAGE_GRAPH,
SEAMLESS, SEAMLESS,
} from './constants'; } from './constants';
@ -262,7 +262,7 @@ export const buildLinearSDXLTextToImageGraph = (
if (shouldUseSDXLRefiner) { if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS); addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
if (seamlessXAxis || seamlessYAxis) { if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = REFINER_SEAMLESS; modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
} }
} }

View File

@ -56,8 +56,9 @@ export const SDXL_REFINER_POSITIVE_CONDITIONING =
export const SDXL_REFINER_NEGATIVE_CONDITIONING = export const SDXL_REFINER_NEGATIVE_CONDITIONING =
'sdxl_refiner_negative_conditioning'; 'sdxl_refiner_negative_conditioning';
export const SDXL_REFINER_DENOISE_LATENTS = 'sdxl_refiner_denoise_latents'; export const SDXL_REFINER_DENOISE_LATENTS = 'sdxl_refiner_denoise_latents';
export const SDXL_REFINER_INPAINT_CREATE_MASK = 'refiner_inpaint_create_mask';
export const SEAMLESS = 'seamless'; export const SEAMLESS = 'seamless';
export const REFINER_SEAMLESS = 'refiner_seamless'; export const SDXL_REFINER_SEAMLESS = 'refiner_seamless';
// friendly graph ids // friendly graph ids
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph'; export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';