From dc81357152764f7078dfdcf305621455a5b5c34d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 2 May 2024 22:29:39 +1000 Subject: [PATCH] feat(ui): add img2img via control layers to graph builders --- .../graph/addInitialImageToLinearGraph.ts | 117 ++++++++++++++++++ .../util/graph/addSeamlessToLinearGraph.ts | 2 + .../nodes/util/graph/addVAEToGraph.ts | 7 +- .../graph/buildLinearSDXLTextToImageGraph.ts | 7 +- .../util/graph/buildLinearTextToImageGraph.ts | 7 +- .../features/nodes/util/graph/constants.ts | 2 + 6 files changed, 136 insertions(+), 6 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/addInitialImageToLinearGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addInitialImageToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addInitialImageToLinearGraph.ts new file mode 100644 index 0000000000..4334a7cd31 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addInitialImageToLinearGraph.ts @@ -0,0 +1,117 @@ +import type { RootState } from 'app/store/store'; +import { isInitialImageLayer } from 'features/controlLayers/store/controlLayersSlice'; +import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types'; +import { assert } from 'tsafe'; + +import { IMAGE_TO_LATENTS, NOISE, RESIZE } from './constants'; + +export const addInitialImageToLinearGraph = ( + state: RootState, + graph: NonNullableGraph, + denoiseNodeId: string +): void => { + // Remove Existing UNet Connections + const { img2imgStrength, vaePrecision, model } = state.generation; + const { refinerModel, refinerStart } = state.sdxl; + const { width, height } = state.controlLayers.present.size; + const initialImageLayer = state.controlLayers.present.layers.find(isInitialImageLayer); + const initialImage = initialImageLayer?.isEnabled ? initialImageLayer?.image : null; + + if (!initialImage) { + return; + } + + const useRefinerStartEnd = model?.base === 'sdxl' && Boolean(refinerModel); + + const denoiseNode = graph.nodes[denoiseNodeId]; + assert(denoiseNode?.type === 'denoise_latents', `Missing denoise node or incorrect type: ${denoiseNode?.type}`); + + denoiseNode.denoising_start = useRefinerStartEnd ? Math.min(refinerStart, 1 - img2imgStrength) : 1 - img2imgStrength; + denoiseNode.denoising_end = useRefinerStartEnd ? refinerStart : 1; + + // We conditionally hook the image in depending on if a resize is needed + const i2lNode: ImageToLatentsInvocation = { + type: 'i2l', + id: IMAGE_TO_LATENTS, + is_intermediate: true, + use_cache: true, + fp32: vaePrecision === 'fp32', + }; + + graph.nodes[i2lNode.id] = i2lNode; + graph.edges.push({ + source: { + node_id: IMAGE_TO_LATENTS, + field: 'latents', + }, + destination: { + node_id: denoiseNode.id, + field: 'latents', + }, + }); + + if (initialImage.width !== width || initialImage.height !== height) { + // The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS` + + // Create a resize node, explicitly setting its image + const resizeNode: ImageResizeInvocation = { + id: RESIZE, + type: 'img_resize', + image: { + image_name: initialImage.imageName, + }, + is_intermediate: true, + width, + height, + }; + + graph.nodes[RESIZE] = resizeNode; + + // The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS` + graph.edges.push({ + source: { node_id: RESIZE, field: 'image' }, + destination: { + node_id: IMAGE_TO_LATENTS, + field: 'image', + }, + }); + + // The `RESIZE` node also passes its width and height to `NOISE` + graph.edges.push({ + source: { node_id: RESIZE, field: 'width' }, + destination: { + node_id: NOISE, + field: 'width', + }, + }); + + graph.edges.push({ + source: { node_id: RESIZE, field: 'height' }, + destination: { + node_id: NOISE, + field: 'height', + }, + }); + } else { + // We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly + i2lNode.image = { + image_name: initialImage.imageName, + }; + + // Pass the image's dimensions to the `NOISE` node + graph.edges.push({ + source: { node_id: IMAGE_TO_LATENTS, field: 'width' }, + destination: { + node_id: NOISE, + field: 'width', + }, + }); + graph.edges.push({ + source: { node_id: IMAGE_TO_LATENTS, field: 'height' }, + destination: { + node_id: NOISE, + field: 'height', + }, + }); + } +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addSeamlessToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addSeamlessToLinearGraph.ts index d986130d64..24e8be6546 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addSeamlessToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addSeamlessToLinearGraph.ts @@ -7,6 +7,7 @@ import { SDXL_CANVAS_INPAINT_GRAPH, SDXL_CANVAS_OUTPAINT_GRAPH, SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH, + SDXL_CONTROL_LAYERS_GRAPH, SDXL_DENOISE_LATENTS, SDXL_IMAGE_TO_IMAGE_GRAPH, SDXL_TEXT_TO_IMAGE_GRAPH, @@ -54,6 +55,7 @@ export const addSeamlessToLinearGraph = ( let denoisingNodeId = DENOISE_LATENTS; if ( + graph.id === SDXL_CONTROL_LAYERS_GRAPH || graph.id === SDXL_TEXT_TO_IMAGE_GRAPH || graph.id === SDXL_IMAGE_TO_IMAGE_GRAPH || graph.id === SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH || diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addVAEToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addVAEToGraph.ts index 347027c539..ed705a08e6 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addVAEToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addVAEToGraph.ts @@ -7,6 +7,7 @@ import { CANVAS_OUTPAINT_GRAPH, CANVAS_OUTPUT, CANVAS_TEXT_TO_IMAGE_GRAPH, + CONTROL_LAYERS_GRAPH, IMAGE_TO_IMAGE_GRAPH, IMAGE_TO_LATENTS, INPAINT_CREATE_MASK, @@ -17,11 +18,11 @@ import { SDXL_CANVAS_INPAINT_GRAPH, SDXL_CANVAS_OUTPAINT_GRAPH, SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH, + SDXL_CONTROL_LAYERS_GRAPH, SDXL_IMAGE_TO_IMAGE_GRAPH, SDXL_REFINER_SEAMLESS, SDXL_TEXT_TO_IMAGE_GRAPH, SEAMLESS, - TEXT_TO_IMAGE_GRAPH, VAE_LOADER, } from './constants'; import { upsertMetadata } from './metadata'; @@ -52,7 +53,8 @@ export const addVAEToGraph = async ( } if ( - graph.id === TEXT_TO_IMAGE_GRAPH || + graph.id === CONTROL_LAYERS_GRAPH || + graph.id === SDXL_CONTROL_LAYERS_GRAPH || graph.id === IMAGE_TO_IMAGE_GRAPH || graph.id === SDXL_TEXT_TO_IMAGE_GRAPH || graph.id === SDXL_IMAGE_TO_IMAGE_GRAPH @@ -100,6 +102,7 @@ export const addVAEToGraph = async ( } if ( + graph.id === SDXL_CONTROL_LAYERS_GRAPH || graph.id === IMAGE_TO_IMAGE_GRAPH || graph.id === SDXL_IMAGE_TO_IMAGE_GRAPH || graph.id === CANVAS_IMAGE_TO_IMAGE_GRAPH || diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts index 9134ef9de7..d3ee9e4a51 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts @@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph'; +import { addInitialImageToLinearGraph } from 'features/nodes/util/graph/addInitialImageToLinearGraph'; import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; @@ -15,10 +16,10 @@ import { NEGATIVE_CONDITIONING, NOISE, POSITIVE_CONDITIONING, + SDXL_CONTROL_LAYERS_GRAPH, SDXL_DENOISE_LATENTS, SDXL_MODEL_LOADER, SDXL_REFINER_SEAMLESS, - SDXL_TEXT_TO_IMAGE_GRAPH, SEAMLESS, } from './constants'; import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; @@ -70,7 +71,7 @@ export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise // copy-pasted graph from node editor, filled in with state values & friendly node ids const graph: NonNullableGraph = { - id: SDXL_TEXT_TO_IMAGE_GRAPH, + id: SDXL_CONTROL_LAYERS_GRAPH, nodes: { [modelLoaderNodeId]: { type: 'sdxl_model_loader', @@ -241,6 +242,8 @@ export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise LATENTS_TO_IMAGE ); + addInitialImageToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); + // Add Seamless To Graph if (seamlessXAxis || seamlessYAxis) { addSeamlessToLinearGraph(state, graph, modelLoaderNodeId); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts index 340a24bca4..c2ad5384e0 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts @@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph'; +import { addInitialImageToLinearGraph } from 'features/nodes/util/graph/addInitialImageToLinearGraph'; import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types'; @@ -13,6 +14,7 @@ import { addVAEToGraph } from './addVAEToGraph'; import { addWatermarkerToGraph } from './addWatermarkerToGraph'; import { CLIP_SKIP, + CONTROL_LAYERS_GRAPH, DENOISE_LATENTS, LATENTS_TO_IMAGE, MAIN_MODEL_LOADER, @@ -20,7 +22,6 @@ import { NOISE, POSITIVE_CONDITIONING, SEAMLESS, - TEXT_TO_IMAGE_GRAPH, } from './constants'; import { addCoreMetadataNode, getModelMetadataField } from './metadata'; @@ -66,7 +67,7 @@ export const buildLinearTextToImageGraph = async (state: RootState): Promise