fix: Update SDXL Refiner graphs to use Gradient Mask

This commit is contained in:
blessedcoolant 2024-04-05 22:49:55 +05:30 committed by psychedelicious
parent b58494c420
commit 381b41a56e
2 changed files with 14 additions and 53 deletions

View File

@ -1,16 +1,15 @@
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { import type {
type CreateDenoiseMaskInvocation, CreateGradientMaskInvocation,
type ImageDTO, ImageDTO,
isRefinerMainModelModelConfig, NonNullableGraph,
type NonNullableGraph, SeamlessModeInvocation,
type SeamlessModeInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { isRefinerMainModelModelConfig } from 'services/api/types';
import { import {
CANVAS_OUTPUT, CANVAS_OUTPUT,
INPAINT_IMAGE_RESIZE_UP,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MASK_COMBINE, MASK_COMBINE,
MASK_RESIZE_UP, MASK_RESIZE_UP,
@ -47,15 +46,15 @@ export const addSDXLRefinerToGraph = async (
refinerStart, refinerStart,
} = state.sdxl; } = state.sdxl;
const { canvasCoherenceEdgeSize, canvasCoherenceMinDenoise, canvasCoherenceMode } = state.generation;
if (!refinerModel) { if (!refinerModel) {
return; return;
} }
const { seamlessXAxis, seamlessYAxis, vaePrecision } = state.generation; const { seamlessXAxis, seamlessYAxis } = state.generation;
const { boundingBoxScaleMethod } = state.canvas; const { boundingBoxScaleMethod } = state.canvas;
const fp32 = vaePrecision === 'fp32';
const isUsingScaledDimensions = ['auto', 'manual'].includes(boundingBoxScaleMethod); const isUsingScaledDimensions = ['auto', 'manual'].includes(boundingBoxScaleMethod);
const modelConfig = await fetchModelConfigWithTypeGuard(refinerModel.key, isRefinerMainModelModelConfig); const modelConfig = await fetchModelConfigWithTypeGuard(refinerModel.key, isRefinerMainModelModelConfig);
@ -215,30 +214,14 @@ export const addSDXLRefinerToGraph = async (
if (graph.id === SDXL_CANVAS_INPAINT_GRAPH || graph.id === SDXL_CANVAS_OUTPAINT_GRAPH) { if (graph.id === SDXL_CANVAS_INPAINT_GRAPH || graph.id === SDXL_CANVAS_OUTPAINT_GRAPH) {
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = { graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
type: 'create_denoise_mask', type: 'create_gradient_mask',
id: SDXL_REFINER_INPAINT_CREATE_MASK, id: SDXL_REFINER_INPAINT_CREATE_MASK,
is_intermediate: true, is_intermediate: true,
fp32, edge_radius: canvasCoherenceEdgeSize,
coherence_mode: canvasCoherenceMode,
minimum_denoise: canvasCoherenceMinDenoise,
}; };
if (isUsingScaledDimensions) {
graph.edges.push({
source: {
node_id: INPAINT_IMAGE_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'image',
},
});
} else {
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
...(graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] as CreateDenoiseMaskInvocation),
image: canvasInitImage,
};
}
if (graph.id === SDXL_CANVAS_INPAINT_GRAPH) { if (graph.id === SDXL_CANVAS_INPAINT_GRAPH) {
if (isUsingScaledDimensions) { if (isUsingScaledDimensions) {
graph.edges.push({ graph.edges.push({
@ -253,7 +236,7 @@ export const addSDXLRefinerToGraph = async (
}); });
} else { } else {
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = { graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
...(graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] as CreateDenoiseMaskInvocation), ...(graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] as CreateGradientMaskInvocation),
mask: canvasMaskImage, mask: canvasMaskImage,
}; };
} }

View File

@ -17,7 +17,6 @@ import {
SDXL_CANVAS_OUTPAINT_GRAPH, SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH, SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
SDXL_IMAGE_TO_IMAGE_GRAPH, SDXL_IMAGE_TO_IMAGE_GRAPH,
SDXL_REFINER_INPAINT_CREATE_MASK,
SDXL_REFINER_SEAMLESS, SDXL_REFINER_SEAMLESS,
SDXL_TEXT_TO_IMAGE_GRAPH, SDXL_TEXT_TO_IMAGE_GRAPH,
SEAMLESS, SEAMLESS,
@ -166,27 +165,6 @@ export const addVAEToGraph = async (
); );
} }
if (refinerModel) {
if (graph.id === SDXL_CANVAS_INPAINT_GRAPH || graph.id === SDXL_CANVAS_OUTPAINT_GRAPH) {
graph.edges.push({
source: {
node_id: isSeamlessEnabled
? isUsingRefiner
? SDXL_REFINER_SEAMLESS
: SEAMLESS
: isAutoVae
? modelLoaderNodeId
: VAE_LOADER,
field: 'vae',
},
destination: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'vae',
},
});
}
}
if (vae) { if (vae) {
upsertMetadata(graph, { vae }); upsertMetadata(graph, { vae });
} }