feat: Make Refiner work with Canvas

This commit is contained in:
blessedcoolant 2023-08-13 03:53:40 +12:00
parent 500cd552bc
commit c33acf951e

View File

@ -2,10 +2,14 @@ import { RootState } from 'app/store/store';
import { MetadataAccumulatorInvocation } from 'services/api/types'; import { MetadataAccumulatorInvocation } from 'services/api/types';
import { NonNullableGraph } from '../../types/types'; import { NonNullableGraph } from '../../types/types';
import { import {
IMAGE_TO_LATENTS, CANVAS_OUTPUT,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MASK_BLUR,
METADATA_ACCUMULATOR, METADATA_ACCUMULATOR,
SDXL_DENOISE_LATENTS, SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
SDXL_CANVAS_INPAINT_GRAPH,
SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SDXL_REFINER_DENOISE_LATENTS, SDXL_REFINER_DENOISE_LATENTS,
SDXL_REFINER_MODEL_LOADER, SDXL_REFINER_MODEL_LOADER,
@ -59,21 +63,6 @@ export const addSDXLRefinerToGraph = (
) )
); );
// connect the VAE back to the i2l, which we just removed in the filter
// but only if we are doing l2l
if (baseNodeId === SDXL_DENOISE_LATENTS) {
graph.edges.push({
source: {
node_id: SDXL_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'vae',
},
});
}
graph.nodes[SDXL_REFINER_MODEL_LOADER] = { graph.nodes[SDXL_REFINER_MODEL_LOADER] = {
type: 'sdxl_refiner_model_loader', type: 'sdxl_refiner_model_loader',
id: SDXL_REFINER_MODEL_LOADER, id: SDXL_REFINER_MODEL_LOADER,
@ -112,16 +101,6 @@ export const addSDXLRefinerToGraph = (
field: 'unet', field: 'unet',
}, },
}, },
{
source: {
node_id: SDXL_REFINER_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{ {
source: { source: {
node_id: SDXL_REFINER_MODEL_LOADER, node_id: SDXL_REFINER_MODEL_LOADER,
@ -171,8 +150,25 @@ export const addSDXLRefinerToGraph = (
node_id: SDXL_REFINER_DENOISE_LATENTS, node_id: SDXL_REFINER_DENOISE_LATENTS,
field: 'latents', field: 'latents',
}, },
}
);
if (
graph.id === SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH ||
graph.id === SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH
) {
graph.edges.push({
source: {
node_id: SDXL_REFINER_DENOISE_LATENTS,
field: 'latents',
}, },
{ destination: {
node_id: CANVAS_OUTPUT,
field: 'latents',
},
});
} else {
graph.edges.push({
source: { source: {
node_id: SDXL_REFINER_DENOISE_LATENTS, node_id: SDXL_REFINER_DENOISE_LATENTS,
field: 'latents', field: 'latents',
@ -181,6 +177,22 @@ export const addSDXLRefinerToGraph = (
node_id: LATENTS_TO_IMAGE, node_id: LATENTS_TO_IMAGE,
field: 'latents', field: 'latents',
}, },
});
}
if (
graph.id === SDXL_CANVAS_INPAINT_GRAPH ||
graph.id === SDXL_CANVAS_OUTPAINT_GRAPH
) {
graph.edges.push({
source: {
node_id: MASK_BLUR,
field: 'image',
},
destination: {
node_id: SDXL_REFINER_DENOISE_LATENTS,
field: 'mask',
},
});
} }
);
}; };