feat: Make Refiner work with Canvas

This commit is contained in:
blessedcoolant 2023-08-13 03:53:40 +12:00
parent 500cd552bc
commit c33acf951e

View File

@ -2,10 +2,14 @@ import { RootState } from 'app/store/store';
import { MetadataAccumulatorInvocation } from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import {
IMAGE_TO_LATENTS,
CANVAS_OUTPUT,
LATENTS_TO_IMAGE,
MASK_BLUR,
METADATA_ACCUMULATOR,
SDXL_DENOISE_LATENTS,
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
SDXL_CANVAS_INPAINT_GRAPH,
SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
SDXL_MODEL_LOADER,
SDXL_REFINER_DENOISE_LATENTS,
SDXL_REFINER_MODEL_LOADER,
@ -59,21 +63,6 @@ export const addSDXLRefinerToGraph = (
)
);
// connect the VAE back to the i2l, which we just removed in the filter
// but only if we are doing l2l
if (baseNodeId === SDXL_DENOISE_LATENTS) {
graph.edges.push({
source: {
node_id: SDXL_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'vae',
},
});
}
graph.nodes[SDXL_REFINER_MODEL_LOADER] = {
type: 'sdxl_refiner_model_loader',
id: SDXL_REFINER_MODEL_LOADER,
@ -112,16 +101,6 @@ export const addSDXLRefinerToGraph = (
field: 'unet',
},
},
{
source: {
node_id: SDXL_REFINER_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{
source: {
node_id: SDXL_REFINER_MODEL_LOADER,
@ -171,8 +150,25 @@ export const addSDXLRefinerToGraph = (
node_id: SDXL_REFINER_DENOISE_LATENTS,
field: 'latents',
},
},
{
}
);
if (
graph.id === SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH ||
graph.id === SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH
) {
graph.edges.push({
source: {
node_id: SDXL_REFINER_DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: CANVAS_OUTPUT,
field: 'latents',
},
});
} else {
graph.edges.push({
source: {
node_id: SDXL_REFINER_DENOISE_LATENTS,
field: 'latents',
@ -181,6 +177,22 @@ export const addSDXLRefinerToGraph = (
node_id: LATENTS_TO_IMAGE,
field: 'latents',
},
}
);
});
}
if (
graph.id === SDXL_CANVAS_INPAINT_GRAPH ||
graph.id === SDXL_CANVAS_OUTPAINT_GRAPH
) {
graph.edges.push({
source: {
node_id: MASK_BLUR,
field: 'image',
},
destination: {
node_id: SDXL_REFINER_DENOISE_LATENTS,
field: 'mask',
},
});
}
};