fix: Update SDXL Refiner graphs to use Gradient Mask

This commit is contained in:
blessedcoolant 2024-04-05 22:49:55 +05:30 committed by psychedelicious
parent b58494c420
commit 381b41a56e
2 changed files with 14 additions and 53 deletions

View File

@ -1,16 +1,15 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
type CreateDenoiseMaskInvocation,
type ImageDTO,
isRefinerMainModelModelConfig,
type NonNullableGraph,
type SeamlessModeInvocation,
import type {
CreateGradientMaskInvocation,
ImageDTO,
NonNullableGraph,
SeamlessModeInvocation,
} from 'services/api/types';
import { isRefinerMainModelModelConfig } from 'services/api/types';
import {
CANVAS_OUTPUT,
INPAINT_IMAGE_RESIZE_UP,
LATENTS_TO_IMAGE,
MASK_COMBINE,
MASK_RESIZE_UP,
@ -47,15 +46,15 @@ export const addSDXLRefinerToGraph = async (
refinerStart,
} = state.sdxl;
const { canvasCoherenceEdgeSize, canvasCoherenceMinDenoise, canvasCoherenceMode } = state.generation;
if (!refinerModel) {
return;
}
const { seamlessXAxis, seamlessYAxis, vaePrecision } = state.generation;
const { seamlessXAxis, seamlessYAxis } = state.generation;
const { boundingBoxScaleMethod } = state.canvas;
const fp32 = vaePrecision === 'fp32';
const isUsingScaledDimensions = ['auto', 'manual'].includes(boundingBoxScaleMethod);
const modelConfig = await fetchModelConfigWithTypeGuard(refinerModel.key, isRefinerMainModelModelConfig);
@ -215,30 +214,14 @@ export const addSDXLRefinerToGraph = async (
if (graph.id === SDXL_CANVAS_INPAINT_GRAPH || graph.id === SDXL_CANVAS_OUTPAINT_GRAPH) {
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
type: 'create_denoise_mask',
type: 'create_gradient_mask',
id: SDXL_REFINER_INPAINT_CREATE_MASK,
is_intermediate: true,
fp32,
edge_radius: canvasCoherenceEdgeSize,
coherence_mode: canvasCoherenceMode,
minimum_denoise: canvasCoherenceMinDenoise,
};
if (isUsingScaledDimensions) {
graph.edges.push({
source: {
node_id: INPAINT_IMAGE_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'image',
},
});
} else {
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
...(graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] as CreateDenoiseMaskInvocation),
image: canvasInitImage,
};
}
if (graph.id === SDXL_CANVAS_INPAINT_GRAPH) {
if (isUsingScaledDimensions) {
graph.edges.push({
@ -253,7 +236,7 @@ export const addSDXLRefinerToGraph = async (
});
} else {
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
...(graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] as CreateDenoiseMaskInvocation),
...(graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] as CreateGradientMaskInvocation),
mask: canvasMaskImage,
};
}

View File

@ -17,7 +17,6 @@ import {
SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
SDXL_IMAGE_TO_IMAGE_GRAPH,
SDXL_REFINER_INPAINT_CREATE_MASK,
SDXL_REFINER_SEAMLESS,
SDXL_TEXT_TO_IMAGE_GRAPH,
SEAMLESS,
@ -166,27 +165,6 @@ export const addVAEToGraph = async (
);
}
if (refinerModel) {
if (graph.id === SDXL_CANVAS_INPAINT_GRAPH || graph.id === SDXL_CANVAS_OUTPAINT_GRAPH) {
graph.edges.push({
source: {
node_id: isSeamlessEnabled
? isUsingRefiner
? SDXL_REFINER_SEAMLESS
: SEAMLESS
: isAutoVae
? modelLoaderNodeId
: VAE_LOADER,
field: 'vae',
},
destination: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'vae',
},
});
}
}
if (vae) {
upsertMetadata(graph, { vae });
}