diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts index ca0d574d61..65664c9f2d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts @@ -2,10 +2,14 @@ import { RootState } from 'app/store/store'; import { MetadataAccumulatorInvocation } from 'services/api/types'; import { NonNullableGraph } from '../../types/types'; import { - IMAGE_TO_LATENTS, + CANVAS_OUTPUT, LATENTS_TO_IMAGE, + MASK_BLUR, METADATA_ACCUMULATOR, - SDXL_DENOISE_LATENTS, + SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH, + SDXL_CANVAS_INPAINT_GRAPH, + SDXL_CANVAS_OUTPAINT_GRAPH, + SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH, SDXL_MODEL_LOADER, SDXL_REFINER_DENOISE_LATENTS, SDXL_REFINER_MODEL_LOADER, @@ -59,21 +63,6 @@ export const addSDXLRefinerToGraph = ( ) ); - // connect the VAE back to the i2l, which we just removed in the filter - // but only if we are doing l2l - if (baseNodeId === SDXL_DENOISE_LATENTS) { - graph.edges.push({ - source: { - node_id: SDXL_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: IMAGE_TO_LATENTS, - field: 'vae', - }, - }); - } - graph.nodes[SDXL_REFINER_MODEL_LOADER] = { type: 'sdxl_refiner_model_loader', id: SDXL_REFINER_MODEL_LOADER, @@ -112,16 +101,6 @@ export const addSDXLRefinerToGraph = ( field: 'unet', }, }, - { - source: { - node_id: SDXL_REFINER_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: LATENTS_TO_IMAGE, - field: 'vae', - }, - }, { source: { node_id: SDXL_REFINER_MODEL_LOADER, @@ -171,8 +150,25 @@ export const addSDXLRefinerToGraph = ( node_id: SDXL_REFINER_DENOISE_LATENTS, field: 'latents', }, - }, - { + } + ); + + if ( + graph.id === SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH || + graph.id === SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH + ) { + graph.edges.push({ + source: { + node_id: SDXL_REFINER_DENOISE_LATENTS, + field: 'latents', + }, + destination: { + node_id: CANVAS_OUTPUT, + field: 'latents', + }, + }); + } else { + graph.edges.push({ source: { node_id: SDXL_REFINER_DENOISE_LATENTS, field: 'latents', @@ -181,6 +177,22 @@ export const addSDXLRefinerToGraph = ( node_id: LATENTS_TO_IMAGE, field: 'latents', }, - } - ); + }); + } + + if ( + graph.id === SDXL_CANVAS_INPAINT_GRAPH || + graph.id === SDXL_CANVAS_OUTPAINT_GRAPH + ) { + graph.edges.push({ + source: { + node_id: MASK_BLUR, + field: 'image', + }, + destination: { + node_id: SDXL_REFINER_DENOISE_LATENTS, + field: 'mask', + }, + }); + } };