fix: SDXL Refiner not working with Canvas Inpaint & Outpaint

This commit is contained in:
blessedcoolant 2023-08-31 06:26:02 +12:00
parent 754666ed09
commit 97763f778a
10 changed files with 99 additions and 34 deletions

View File

@ -11,10 +11,10 @@ import {
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING,
REFINER_SEAMLESS,
SDXL_CANVAS_INPAINT_GRAPH,
SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_MODEL_LOADER,
SDXL_REFINER_INPAINT_CREATE_MASK,
SEAMLESS,
} from './constants';
@ -41,7 +41,9 @@ export const addSDXLLoRAsToGraph = (
// Handle Seamless Plugs
const unetLoaderId = modelLoaderNodeId;
let clipLoaderId = modelLoaderNodeId;
if ([SEAMLESS, REFINER_SEAMLESS].includes(modelLoaderNodeId)) {
if (
[SEAMLESS, SDXL_REFINER_INPAINT_CREATE_MASK].includes(modelLoaderNodeId)
) {
clipLoaderId = SDXL_MODEL_LOADER;
}

View File

@ -1,24 +1,28 @@
import { RootState } from 'app/store/store';
import {
CreateDenoiseMaskInvocation,
ImageDTO,
MetadataAccumulatorInvocation,
SeamlessModeInvocation,
} from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import {
CANVAS_OUTPUT,
INPAINT_IMAGE_RESIZE_UP,
LATENTS_TO_IMAGE,
MASK_BLUR,
METADATA_ACCUMULATOR,
REFINER_SEAMLESS,
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
SDXL_CANVAS_INPAINT_GRAPH,
SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
SDXL_MODEL_LOADER,
SDXL_REFINER_DENOISE_LATENTS,
SDXL_REFINER_INPAINT_CREATE_MASK,
SDXL_REFINER_MODEL_LOADER,
SDXL_REFINER_NEGATIVE_CONDITIONING,
SDXL_REFINER_POSITIVE_CONDITIONING,
SDXL_REFINER_SEAMLESS,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -26,7 +30,8 @@ export const addSDXLRefinerToGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string,
modelLoaderNodeId?: string
modelLoaderNodeId?: string,
canvasInitImage?: ImageDTO
): void => {
const {
refinerModel,
@ -38,7 +43,8 @@ export const addSDXLRefinerToGraph = (
refinerStart,
} = state.sdxl;
const { seamlessXAxis, seamlessYAxis } = state.generation;
const { seamlessXAxis, seamlessYAxis, vaePrecision } = state.generation;
const { boundingBoxScaleMethod } = state.canvas;
if (!refinerModel) {
return;
@ -108,8 +114,8 @@ export const addSDXLRefinerToGraph = (
// Add Seamless To Refiner
if (seamlessXAxis || seamlessYAxis) {
graph.nodes[REFINER_SEAMLESS] = {
id: REFINER_SEAMLESS,
graph.nodes[SDXL_REFINER_SEAMLESS] = {
id: SDXL_REFINER_SEAMLESS,
type: 'seamless',
seamless_x: seamlessXAxis,
seamless_y: seamlessYAxis,
@ -122,7 +128,7 @@ export const addSDXLRefinerToGraph = (
field: 'unet',
},
destination: {
node_id: REFINER_SEAMLESS,
node_id: SDXL_REFINER_SEAMLESS,
field: 'unet',
},
},
@ -132,13 +138,13 @@ export const addSDXLRefinerToGraph = (
field: 'vae',
},
destination: {
node_id: REFINER_SEAMLESS,
node_id: SDXL_REFINER_SEAMLESS,
field: 'vae',
},
},
{
source: {
node_id: REFINER_SEAMLESS,
node_id: SDXL_REFINER_SEAMLESS,
field: 'unet',
},
destination: {
@ -244,15 +250,54 @@ export const addSDXLRefinerToGraph = (
graph.id === SDXL_CANVAS_INPAINT_GRAPH ||
graph.id === SDXL_CANVAS_OUTPAINT_GRAPH
) {
graph.edges.push({
source: {
node_id: MASK_BLUR,
field: 'image',
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
type: 'create_denoise_mask',
id: SDXL_REFINER_INPAINT_CREATE_MASK,
is_intermediate: true,
fp32: vaePrecision === 'fp32' ? true : false,
};
if (['auto', 'manual'].includes(boundingBoxScaleMethod)) {
graph.edges.push({
source: {
node_id: INPAINT_IMAGE_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'image',
},
});
} else {
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
...(graph.nodes[
SDXL_REFINER_INPAINT_CREATE_MASK
] as CreateDenoiseMaskInvocation),
image: canvasInitImage,
};
}
graph.edges.push(
{
source: {
node_id: MASK_BLUR,
field: 'image',
},
destination: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'mask',
},
},
destination: {
node_id: SDXL_REFINER_DENOISE_LATENTS,
field: 'mask',
},
});
{
source: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'denoise_mask',
},
destination: {
node_id: SDXL_REFINER_DENOISE_LATENTS,
field: 'denoise_mask',
},
}
);
}
};

View File

@ -20,6 +20,7 @@ import {
SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
SDXL_IMAGE_TO_IMAGE_GRAPH,
SDXL_REFINER_INPAINT_CREATE_MASK,
SDXL_TEXT_TO_IMAGE_GRAPH,
TEXT_TO_IMAGE_GRAPH,
VAE_LOADER,
@ -32,6 +33,7 @@ export const addVAEToGraph = (
): void => {
const { vae } = state.generation;
const { boundingBoxScaleMethod } = state.canvas;
const { shouldUseSDXLRefiner } = state.sdxl;
const isUsingScaledDimensions = ['auto', 'manual'].includes(
boundingBoxScaleMethod
@ -146,6 +148,19 @@ export const addVAEToGraph = (
);
}
if (shouldUseSDXLRefiner) {
graph.edges.push({
source: {
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
},
destination: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'vae',
},
});
}
if (vae && metadataAccumulator) {
metadataAccumulator.vae = vae;
}

View File

@ -20,10 +20,10 @@ import {
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
REFINER_SEAMLESS,
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER,
SDXL_REFINER_SEAMLESS,
SEAMLESS,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -369,7 +369,7 @@ export const buildCanvasSDXLImageToImageGraph = (
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = REFINER_SEAMLESS;
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}

View File

@ -36,10 +36,10 @@ import {
POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
REFINER_SEAMLESS,
SDXL_CANVAS_INPAINT_GRAPH,
SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER,
SDXL_REFINER_SEAMLESS,
SEAMLESS,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -628,10 +628,11 @@ export const buildCanvasSDXLInpaintGraph = (
state,
graph,
CANVAS_COHERENCE_DENOISE_LATENTS,
modelLoaderNodeId
modelLoaderNodeId,
canvasInitImage
);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = REFINER_SEAMLESS;
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}

View File

@ -41,10 +41,10 @@ import {
POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
REFINER_SEAMLESS,
SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER,
SDXL_REFINER_SEAMLESS,
SEAMLESS,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -766,10 +766,11 @@ export const buildCanvasSDXLOutpaintGraph = (
state,
graph,
CANVAS_COHERENCE_DENOISE_LATENTS,
modelLoaderNodeId
modelLoaderNodeId,
canvasInitImage
);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = REFINER_SEAMLESS;
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}

View File

@ -22,10 +22,10 @@ import {
NOISE,
ONNX_MODEL_LOADER,
POSITIVE_CONDITIONING,
REFINER_SEAMLESS,
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER,
SDXL_REFINER_SEAMLESS,
SEAMLESS,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -349,7 +349,7 @@ export const buildCanvasSDXLTextToImageGraph = (
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = REFINER_SEAMLESS;
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}

View File

@ -21,11 +21,11 @@ import {
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
REFINER_SEAMLESS,
RESIZE,
SDXL_DENOISE_LATENTS,
SDXL_IMAGE_TO_IMAGE_GRAPH,
SDXL_MODEL_LOADER,
SDXL_REFINER_SEAMLESS,
SEAMLESS,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -369,7 +369,7 @@ export const buildLinearSDXLImageToImageGraph = (
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = REFINER_SEAMLESS;
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}

View File

@ -16,9 +16,9 @@ import {
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
REFINER_SEAMLESS,
SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER,
SDXL_REFINER_SEAMLESS,
SDXL_TEXT_TO_IMAGE_GRAPH,
SEAMLESS,
} from './constants';
@ -262,7 +262,7 @@ export const buildLinearSDXLTextToImageGraph = (
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = REFINER_SEAMLESS;
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}

View File

@ -56,8 +56,9 @@ export const SDXL_REFINER_POSITIVE_CONDITIONING =
export const SDXL_REFINER_NEGATIVE_CONDITIONING =
'sdxl_refiner_negative_conditioning';
export const SDXL_REFINER_DENOISE_LATENTS = 'sdxl_refiner_denoise_latents';
export const SDXL_REFINER_INPAINT_CREATE_MASK = 'refiner_inpaint_create_mask';
export const SEAMLESS = 'seamless';
export const REFINER_SEAMLESS = 'refiner_seamless';
export const SDXL_REFINER_SEAMLESS = 'refiner_seamless';
// friendly graph ids
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';