From a953dc1dbd3a1cbf9edfa21314691e6cb32c4d45 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 24 Jun 2024 15:19:51 +1000 Subject: [PATCH] feat(ui): txt2img & img2img graphs --- .../listeners/enqueueRequestedLinear.ts | 17 +- .../util/graph/generation/addNSFWChecker.ts | 2 +- .../util/graph/generation/addSDXLLoRAs.ts | 2 +- .../util/graph/generation/addWatermarker.ts | 2 +- ...Graph.ts => buildImageToImageSDXLGraph.ts} | 41 +-- .../util/graph/generation/buildSD1Graph.ts | 246 ++++++++++++++++++ .../util/graph/generation/buildSDXLGraph.ts | 246 ++++++++++++++++++ ...raph.ts => buildTextToImageSD1SD2Graph.ts} | 36 ++- .../nodes/util/graph/graphBuilderUtils.ts | 10 +- 9 files changed, 562 insertions(+), 40 deletions(-) rename invokeai/frontend/web/src/features/nodes/util/graph/generation/{buildGenerationTabSDXLGraph.ts => buildImageToImageSDXLGraph.ts} (85%) create mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts create mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts rename invokeai/frontend/web/src/features/nodes/util/graph/generation/{buildGenerationTabGraph.ts => buildTextToImageSD1SD2Graph.ts} (85%) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts index 7748e35998..0e1544b17b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts @@ -3,9 +3,10 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware' import { getNodeManager } from 'features/controlLayers/konva/nodeManager'; import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice'; import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; -import { buildGenerationTabGraph } from 'features/nodes/util/graph/generation/buildGenerationTabGraph'; -import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/generation/buildGenerationTabSDXLGraph'; +import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph'; +import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph'; import { queueApi } from 'services/api/endpoints/queue'; +import { assert } from 'tsafe'; export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => { startAppListening({ @@ -20,13 +21,15 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) let graph; const manager = getNodeManager(); + assert(model, 'No model found in state'); + const base = model.base; - console.log('generation mode', manager.util.getGenerationMode()); - - if (model?.base === 'sdxl') { - graph = await buildGenerationTabSDXLGraph(state, manager); + if (base === 'sdxl') { + graph = await buildSDXLGraph(state, manager); + } else if (base === 'sd-1' || base === 'sd-2') { + graph = await buildSD1Graph(state, manager); } else { - graph = await buildGenerationTabGraph(state, manager); + assert(false, `No graph builders for base ${base}`); } const batchConfig = prepareLinearUIBatch(state, graph, prepend); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addNSFWChecker.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addNSFWChecker.ts index 7850413195..939aa6894c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addNSFWChecker.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addNSFWChecker.ts @@ -10,7 +10,7 @@ import type { Invocation } from 'services/api/types'; */ export const addNSFWChecker = ( g: Graph, - imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> + imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> | Invocation<'img_resize'> ): Invocation<'img_nsfw'> => { const nsfw = g.addNode({ id: NSFW_CHECKER, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts index d7377da4b0..f274ec9a09 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts @@ -4,7 +4,7 @@ import { LORA_LOADER } from 'features/nodes/util/graph/constants'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Invocation, S } from 'services/api/types'; -export const addSDXLLoRas = ( +export const addSDXLLoRAs = ( state: RootState, g: Graph, denoise: Invocation<'denoise_latents'> | Invocation<'tiled_multi_diffusion_denoise_latents'>, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addWatermarker.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addWatermarker.ts index 2a7af866f8..9111a77630 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addWatermarker.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addWatermarker.ts @@ -10,7 +10,7 @@ import type { Invocation } from 'services/api/types'; */ export const addWatermarker = ( g: Graph, - imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> + imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> | Invocation<'img_resize'> ): Invocation<'img_watermark'> => { const watermark = g.addNode({ id: WATERMARKER, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildGenerationTabSDXLGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildImageToImageSDXLGraph.ts similarity index 85% rename from invokeai/frontend/web/src/features/nodes/util/graph/generation/buildGenerationTabSDXLGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/generation/buildImageToImageSDXLGraph.ts index 9dc82bb237..4dd0e1e056 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildGenerationTabSDXLGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildImageToImageSDXLGraph.ts @@ -16,22 +16,23 @@ import { import { addControlAdapters } from 'features/nodes/util/graph/generation/addControlAdapters'; import { addIPAdapters } from 'features/nodes/util/graph/generation/addIPAdapters'; import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker'; -import { addSDXLLoRas } from 'features/nodes/util/graph/generation/addSDXLLoRAs'; +import { addSDXLLoRAs } from 'features/nodes/util/graph/generation/addSDXLLoRAs'; import { addSDXLRefiner } from 'features/nodes/util/graph/generation/addSDXLRefiner'; import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless'; import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; -import { getBoardField, getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils'; +import { getBoardField, getPresetModifiedPrompts , getSizes } from 'features/nodes/util/graph/graphBuilderUtils'; import type { Invocation, NonNullableGraph } from 'services/api/types'; import { isNonRefinerMainModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; import { addRegions } from './addRegions'; -export const buildGenerationTabSDXLGraph = async ( +export const buildImageToImageSDXLGraph = async ( state: RootState, manager: KonvaNodeManager ): Promise => { + const { bbox, params } = state.canvasV2; const { model, cfgScale: cfg_scale, @@ -42,17 +43,17 @@ export const buildGenerationTabSDXLGraph = async ( shouldUseCpuNoise, vaePrecision, vae, - positivePrompt, - negativePrompt, refinerModel, refinerStart, img2imgStrength, - } = state.canvasV2.params; - const { width, height } = state.canvasV2.bbox; + } = params; assert(model, 'No model found in state'); const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } = getPresetModifiedPrompts(state); + const { originalSize, scaledSize } = getSizes(bbox); + + const g = new Graph(SDXL_CONTROL_LAYERS_GRAPH); const modelLoader = g.addNode({ @@ -80,8 +81,14 @@ export const buildGenerationTabSDXLGraph = async ( type: 'collect', id: NEGATIVE_CONDITIONING_COLLECT, }); - const noise = g.addNode({ type: 'noise', id: NOISE, seed, width, height, use_cpu: shouldUseCpuNoise }); - const i2l = g.addNode({ type: 'i2l', id: 'i2l' }); + const noise = g.addNode({ + type: 'noise', + id: NOISE, + seed, + width: scaledSize.width, + height: scaledSize.height, + use_cpu: shouldUseCpuNoise, + }); const denoise = g.addNode({ type: 'denoise_latents', id: SDXL_DENOISE_LATENTS, @@ -110,7 +117,8 @@ export const buildGenerationTabSDXLGraph = async ( }) : null; - let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> = l2i; + let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> | Invocation<'img_resize'> = + l2i; g.addEdge(modelLoader, 'unet', denoise, 'unet'); g.addEdge(modelLoader, 'clip', posCond, 'clip'); @@ -122,7 +130,6 @@ export const buildGenerationTabSDXLGraph = async ( g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning'); g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning'); g.addEdge(noise, 'noise', denoise, 'noise'); - g.addEdge(i2l, 'latents', denoise, 'latents'); g.addEdge(denoise, 'latents', l2i, 'latents'); const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig); @@ -132,8 +139,8 @@ export const buildGenerationTabSDXLGraph = async ( generation_mode: 'sdxl_txt2img', cfg_scale, cfg_rescale_multiplier, - height, - width, + width: scaledSize.width, + height: scaledSize.height, positive_prompt: positivePrompt, negative_prompt: negativePrompt, model: Graph.getModelMetadataField(modelConfig), @@ -148,18 +155,19 @@ export const buildGenerationTabSDXLGraph = async ( const seamless = addSeamless(state, g, denoise, modelLoader, vaeLoader); - addSDXLLoRas(state, g, denoise, modelLoader, seamless, posCond, negCond); + addSDXLLoRAs(state, g, denoise, modelLoader, seamless, posCond, negCond); // We might get the VAE from the main model, custom VAE, or seamless node. const vaeSource = seamless ?? vaeLoader ?? modelLoader; g.addEdge(vaeSource, 'vae', l2i, 'vae'); - g.addEdge(vaeSource, 'vae', i2l, 'vae'); // Add Refiner if enabled if (refinerModel) { await addSDXLRefiner(state, g, denoise, seamless, posCond, negCond, l2i); } + + const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base); const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base); const _addedRegions = await addRegions( @@ -175,9 +183,6 @@ export const buildGenerationTabSDXLGraph = async ( posCondCollect, negCondCollect ); - const { image_name } = await manager.util.getImageSourceImage({ bbox: state.canvasV2.bbox, preview: true }); - await manager.util.getInpaintMaskImage({ bbox: state.canvasV2.bbox, preview: true }); - i2l.image = { image_name }; if (state.system.shouldUseNSFWChecker) { imageOutput = addNSFWChecker(g, imageOutput); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts new file mode 100644 index 0000000000..bf219cd160 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts @@ -0,0 +1,246 @@ +import type { RootState } from 'app/store/store'; +import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager'; +import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; +import { + CLIP_SKIP, + CONTROL_LAYERS_GRAPH, + DENOISE_LATENTS, + LATENTS_TO_IMAGE, + MAIN_MODEL_LOADER, + NEGATIVE_CONDITIONING, + NEGATIVE_CONDITIONING_COLLECT, + NOISE, + POSITIVE_CONDITIONING, + POSITIVE_CONDITIONING_COLLECT, + VAE_LOADER, +} from 'features/nodes/util/graph/constants'; +import { addControlAdapters } from 'features/nodes/util/graph/generation/addControlAdapters'; +// import { addHRF } from 'features/nodes/util/graph/generation/addHRF'; +import { addIPAdapters } from 'features/nodes/util/graph/generation/addIPAdapters'; +import { addLoRAs } from 'features/nodes/util/graph/generation/addLoRAs'; +import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker'; +import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless'; +import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker'; +import type { GraphType } from 'features/nodes/util/graph/generation/Graph'; +import { Graph } from 'features/nodes/util/graph/generation/Graph'; +import { getBoardField, getSizes } from 'features/nodes/util/graph/graphBuilderUtils'; +import { isEqual, pick } from 'lodash-es'; +import type { Invocation } from 'services/api/types'; +import { isNonRefinerMainModelConfig } from 'services/api/types'; +import { assert } from 'tsafe'; + +import { addRegions } from './addRegions'; + +export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager): Promise => { + const generationMode = manager.util.getGenerationMode(); + + const { bbox, params } = state.canvasV2; + + const { + model, + cfgScale: cfg_scale, + cfgRescaleMultiplier: cfg_rescale_multiplier, + scheduler, + steps, + clipSkip: skipped_layers, + shouldUseCpuNoise, + vaePrecision, + seed, + vae, + positivePrompt, + negativePrompt, + img2imgStrength, + } = params; + + assert(model, 'No model found in state'); + + const { originalSize, scaledSize } = getSizes(bbox); + + const g = new Graph(CONTROL_LAYERS_GRAPH); + const modelLoader = g.addNode({ + type: 'main_model_loader', + id: MAIN_MODEL_LOADER, + model, + }); + const clipSkip = g.addNode({ + type: 'clip_skip', + id: CLIP_SKIP, + skipped_layers, + }); + const posCond = g.addNode({ + type: 'compel', + id: POSITIVE_CONDITIONING, + prompt: positivePrompt, + }); + const posCondCollect = g.addNode({ + type: 'collect', + id: POSITIVE_CONDITIONING_COLLECT, + }); + const negCond = g.addNode({ + type: 'compel', + id: NEGATIVE_CONDITIONING, + prompt: negativePrompt, + }); + const negCondCollect = g.addNode({ + type: 'collect', + id: NEGATIVE_CONDITIONING_COLLECT, + }); + const noise = g.addNode({ + type: 'noise', + id: NOISE, + seed, + width: scaledSize.width, + height: scaledSize.height, + use_cpu: shouldUseCpuNoise, + }); + const denoise = g.addNode({ + type: 'denoise_latents', + id: DENOISE_LATENTS, + cfg_scale, + cfg_rescale_multiplier, + scheduler, + steps, + denoising_start: 0, + denoising_end: 1, + }); + const l2i = g.addNode({ + type: 'l2i', + id: LATENTS_TO_IMAGE, + fp32: vaePrecision === 'fp32', + board: getBoardField(state), + // This is the terminal node and must always save to gallery. + is_intermediate: false, + use_cache: false, + }); + const vaeLoader = + vae?.base === model.base + ? g.addNode({ + type: 'vae_loader', + id: VAE_LOADER, + vae_model: vae, + }) + : null; + + let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> | Invocation<'img_resize'> = + l2i; + + g.addEdge(modelLoader, 'unet', denoise, 'unet'); + g.addEdge(modelLoader, 'clip', clipSkip, 'clip'); + g.addEdge(clipSkip, 'clip', posCond, 'clip'); + g.addEdge(clipSkip, 'clip', negCond, 'clip'); + g.addEdge(posCond, 'conditioning', posCondCollect, 'item'); + g.addEdge(negCond, 'conditioning', negCondCollect, 'item'); + g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning'); + g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning'); + g.addEdge(noise, 'noise', denoise, 'noise'); + g.addEdge(denoise, 'latents', l2i, 'latents'); + + const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig); + assert(modelConfig.base === 'sd-1' || modelConfig.base === 'sd-2'); + + g.upsertMetadata({ + generation_mode: 'txt2img', + cfg_scale, + cfg_rescale_multiplier, + width: scaledSize.width, + height: scaledSize.height, + positive_prompt: positivePrompt, + negative_prompt: negativePrompt, + model: Graph.getModelMetadataField(modelConfig), + seed, + steps, + rand_device: shouldUseCpuNoise ? 'cpu' : 'cuda', + scheduler, + clip_skip: skipped_layers, + vae: vae ?? undefined, + }); + + const seamless = addSeamless(state, g, denoise, modelLoader, vaeLoader); + + addLoRAs(state, g, denoise, modelLoader, seamless, clipSkip, posCond, negCond); + + // We might get the VAE from the main model, custom VAE, or seamless node. + const vaeSource = seamless ?? vaeLoader ?? modelLoader; + g.addEdge(vaeSource, 'vae', l2i, 'vae'); + + if (generationMode === 'txt2img') { + if (!isEqual(scaledSize, originalSize)) { + // We are using scaled bbox and need to resize the output image back to the original size. + imageOutput = g.addNode({ + id: 'img_resize', + type: 'img_resize', + ...originalSize, + is_intermediate: false, + use_cache: false, + }); + g.addEdge(l2i, 'image', imageOutput, 'image'); + } + } else if (generationMode === 'img2img') { + const { image_name } = await manager.util.getImageSourceImage({ + bbox: pick(bbox, ['x', 'y', 'width', 'height']), + preview: true, + }); + + denoise.denoising_start = 1 - img2imgStrength; + + if (!isEqual(scaledSize, originalSize)) { + // We are using scaled bbox and need to resize the output image back to the original size. + const initialImageResize = g.addNode({ + id: 'initial_image_resize', + type: 'img_resize', + ...scaledSize, + image: { image_name }, + }); + const i2l = g.addNode({ id: 'i2l', type: 'i2l' }); + + g.addEdge(vaeSource, 'vae', i2l, 'vae'); + g.addEdge(initialImageResize, 'image', i2l, 'image'); + g.addEdge(i2l, 'latents', denoise, 'latents'); + + imageOutput = g.addNode({ + id: 'img_resize', + type: 'img_resize', + ...originalSize, + is_intermediate: false, + use_cache: false, + }); + g.addEdge(l2i, 'image', imageOutput, 'image'); + } else { + const i2l = g.addNode({ id: 'i2l', type: 'i2l', image: { image_name } }); + g.addEdge(vaeSource, 'vae', i2l, 'vae'); + g.addEdge(i2l, 'latents', denoise, 'latents'); + } + } + + const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base); + const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base); + const _addedRegions = await addRegions( + manager, + state.canvasV2.regions.entities, + g, + state.canvasV2.document, + state.canvasV2.bbox, + modelConfig.base, + denoise, + posCond, + negCond, + posCondCollect, + negCondCollect + ); + + // const isHRFAllowed = !addedLayers.some((l) => isInitialImageLayer(l) || isRegionalGuidanceLayer(l)); + // if (isHRFAllowed && state.hrf.hrfEnabled) { + // imageOutput = addHRF(state, g, denoise, noise, l2i, vaeSource); + // } + + if (state.system.shouldUseNSFWChecker) { + imageOutput = addNSFWChecker(g, imageOutput); + } + + if (state.system.shouldUseWatermarker) { + imageOutput = addWatermarker(g, imageOutput); + } + + g.setMetadataReceivingNode(imageOutput); + return g.getGraph(); +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts new file mode 100644 index 0000000000..5523f77795 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts @@ -0,0 +1,246 @@ +import type { RootState } from 'app/store/store'; +import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager'; +import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; +import { + LATENTS_TO_IMAGE, + NEGATIVE_CONDITIONING, + NEGATIVE_CONDITIONING_COLLECT, + NOISE, + POSITIVE_CONDITIONING, + POSITIVE_CONDITIONING_COLLECT, + SDXL_CONTROL_LAYERS_GRAPH, + SDXL_DENOISE_LATENTS, + SDXL_MODEL_LOADER, + VAE_LOADER, +} from 'features/nodes/util/graph/constants'; +import { addControlAdapters } from 'features/nodes/util/graph/generation/addControlAdapters'; +import { addIPAdapters } from 'features/nodes/util/graph/generation/addIPAdapters'; +import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker'; +import { addSDXLLoRAs } from 'features/nodes/util/graph/generation/addSDXLLoRAs'; +import { addSDXLRefiner } from 'features/nodes/util/graph/generation/addSDXLRefiner'; +import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless'; +import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker'; +import { Graph } from 'features/nodes/util/graph/generation/Graph'; +import { getBoardField, getSDXLStylePrompts, getSizes } from 'features/nodes/util/graph/graphBuilderUtils'; +import { isEqual, pick } from 'lodash-es'; +import type { Invocation, NonNullableGraph } from 'services/api/types'; +import { isNonRefinerMainModelConfig } from 'services/api/types'; +import { assert } from 'tsafe'; + +import { addRegions } from './addRegions'; + +export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager): Promise => { + const generationMode = manager.util.getGenerationMode(); + + const { bbox, params } = state.canvasV2; + + const { + model, + cfgScale: cfg_scale, + cfgRescaleMultiplier: cfg_rescale_multiplier, + scheduler, + seed, + steps, + shouldUseCpuNoise, + vaePrecision, + vae, + positivePrompt, + negativePrompt, + refinerModel, + refinerStart, + img2imgStrength, + } = params; + + assert(model, 'No model found in state'); + + const { originalSize, scaledSize } = getSizes(bbox); + + const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state); + + const g = new Graph(SDXL_CONTROL_LAYERS_GRAPH); + const modelLoader = g.addNode({ + type: 'sdxl_model_loader', + id: SDXL_MODEL_LOADER, + model, + }); + const posCond = g.addNode({ + type: 'sdxl_compel_prompt', + id: POSITIVE_CONDITIONING, + prompt: positivePrompt, + style: positiveStylePrompt, + }); + const posCondCollect = g.addNode({ + type: 'collect', + id: POSITIVE_CONDITIONING_COLLECT, + }); + const negCond = g.addNode({ + type: 'sdxl_compel_prompt', + id: NEGATIVE_CONDITIONING, + prompt: negativePrompt, + style: negativeStylePrompt, + }); + const negCondCollect = g.addNode({ + type: 'collect', + id: NEGATIVE_CONDITIONING_COLLECT, + }); + const noise = g.addNode({ + type: 'noise', + id: NOISE, + seed, + width: scaledSize.width, + height: scaledSize.height, + use_cpu: shouldUseCpuNoise, + }); + const denoise = g.addNode({ + type: 'denoise_latents', + id: SDXL_DENOISE_LATENTS, + cfg_scale, + cfg_rescale_multiplier, + scheduler, + steps, + denoising_start: 0, + denoising_end: refinerModel ? refinerStart : 1, + }); + const l2i = g.addNode({ + type: 'l2i', + id: LATENTS_TO_IMAGE, + fp32: vaePrecision === 'fp32', + board: getBoardField(state), + // This is the terminal node and must always save to gallery. + is_intermediate: false, + use_cache: false, + }); + const vaeLoader = + vae?.base === model.base + ? g.addNode({ + type: 'vae_loader', + id: VAE_LOADER, + vae_model: vae, + }) + : null; + + let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> | Invocation<'img_resize'> = + l2i; + + g.addEdge(modelLoader, 'unet', denoise, 'unet'); + g.addEdge(modelLoader, 'clip', posCond, 'clip'); + g.addEdge(modelLoader, 'clip', negCond, 'clip'); + g.addEdge(modelLoader, 'clip2', posCond, 'clip2'); + g.addEdge(modelLoader, 'clip2', negCond, 'clip2'); + g.addEdge(posCond, 'conditioning', posCondCollect, 'item'); + g.addEdge(negCond, 'conditioning', negCondCollect, 'item'); + g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning'); + g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning'); + g.addEdge(noise, 'noise', denoise, 'noise'); + g.addEdge(denoise, 'latents', l2i, 'latents'); + + const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig); + assert(modelConfig.base === 'sdxl'); + + g.upsertMetadata({ + generation_mode: 'sdxl_txt2img', + cfg_scale, + cfg_rescale_multiplier, + width: scaledSize.width, + height: scaledSize.height, + positive_prompt: positivePrompt, + negative_prompt: negativePrompt, + model: Graph.getModelMetadataField(modelConfig), + seed, + steps, + rand_device: shouldUseCpuNoise ? 'cpu' : 'cuda', + scheduler, + positive_style_prompt: positiveStylePrompt, + negative_style_prompt: negativeStylePrompt, + vae: vae ?? undefined, + }); + + const seamless = addSeamless(state, g, denoise, modelLoader, vaeLoader); + + addSDXLLoRAs(state, g, denoise, modelLoader, seamless, posCond, negCond); + + // We might get the VAE from the main model, custom VAE, or seamless node. + const vaeSource = seamless ?? vaeLoader ?? modelLoader; + g.addEdge(vaeSource, 'vae', l2i, 'vae'); + + // Add Refiner if enabled + if (refinerModel) { + await addSDXLRefiner(state, g, denoise, seamless, posCond, negCond, l2i); + } + + if (generationMode === 'txt2img') { + if (!isEqual(scaledSize, originalSize)) { + // We are using scaled bbox and need to resize the output image back to the original size. + imageOutput = g.addNode({ + id: 'img_resize', + type: 'img_resize', + ...originalSize, + is_intermediate: false, + use_cache: false, + }); + g.addEdge(l2i, 'image', imageOutput, 'image'); + } + } else if (generationMode === 'img2img') { + denoise.denoising_start = refinerModel ? Math.min(refinerStart, 1 - img2imgStrength) : 1 - img2imgStrength; + + const { image_name } = await manager.util.getImageSourceImage({ + bbox: pick(bbox, ['x', 'y', 'width', 'height']), + preview: true, + }); + + if (!isEqual(scaledSize, originalSize)) { + // We are using scaled bbox and need to resize the output image back to the original size. + const initialImageResize = g.addNode({ + id: 'initial_image_resize', + type: 'img_resize', + ...scaledSize, + image: { image_name }, + }); + const i2l = g.addNode({ id: 'i2l', type: 'i2l' }); + + g.addEdge(vaeSource, 'vae', i2l, 'vae'); + g.addEdge(initialImageResize, 'image', i2l, 'image'); + g.addEdge(i2l, 'latents', denoise, 'latents'); + + imageOutput = g.addNode({ + id: 'img_resize', + type: 'img_resize', + ...originalSize, + is_intermediate: false, + use_cache: false, + }); + g.addEdge(l2i, 'image', imageOutput, 'image'); + } else { + const i2l = g.addNode({ id: 'i2l', type: 'i2l', image: { image_name } }); + g.addEdge(vaeSource, 'vae', i2l, 'vae'); + g.addEdge(i2l, 'latents', denoise, 'latents'); + } + } + + const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base); + const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base); + const _addedRegions = await addRegions( + manager, + state.canvasV2.regions.entities, + g, + state.canvasV2.document, + state.canvasV2.bbox, + modelConfig.base, + denoise, + posCond, + negCond, + posCondCollect, + negCondCollect + ); + + if (state.system.shouldUseNSFWChecker) { + imageOutput = addNSFWChecker(g, imageOutput); + } + + if (state.system.shouldUseWatermarker) { + imageOutput = addWatermarker(g, imageOutput); + } + + g.setMetadataReceivingNode(imageOutput); + return g.getGraph(); +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildGenerationTabGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildTextToImageSD1SD2Graph.ts similarity index 85% rename from invokeai/frontend/web/src/features/nodes/util/graph/generation/buildGenerationTabGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/generation/buildTextToImageSD1SD2Graph.ts index cc859780f2..044cca05ce 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildGenerationTabGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildTextToImageSD1SD2Graph.ts @@ -23,14 +23,17 @@ import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless'; import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker'; import type { GraphType } from 'features/nodes/util/graph/generation/Graph'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; -import { getBoardField, getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils'; +import { getBoardField, getPresetModifiedPrompts , getSizes } from 'features/nodes/util/graph/graphBuilderUtils'; +import { isEqual } from 'lodash-es'; import type { Invocation } from 'services/api/types'; import { isNonRefinerMainModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; import { addRegions } from './addRegions'; -export const buildGenerationTabGraph = async (state: RootState, manager: KonvaNodeManager): Promise => { +export const buildTextToImageSD1SD2Graph = async (state: RootState, manager: KonvaNodeManager): Promise => { + const { bbox, params } = state.canvasV2; + const { model, cfgScale: cfg_scale, @@ -42,14 +45,12 @@ export const buildGenerationTabGraph = async (state: RootState, manager: KonvaNo vaePrecision, seed, vae, - positivePrompt, - negativePrompt, - } = state.canvasV2.params; - const { width, height } = state.canvasV2.document; + } = params; assert(model, 'No model found in state'); const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state); + const { originalSize, scaledSize } = getSizes(bbox); const g = new Graph(CONTROL_LAYERS_GRAPH); const modelLoader = g.addNode({ @@ -84,8 +85,8 @@ export const buildGenerationTabGraph = async (state: RootState, manager: KonvaNo type: 'noise', id: NOISE, seed, - width, - height, + width: scaledSize.width, + height: scaledSize.height, use_cpu: shouldUseCpuNoise, }); const denoise = g.addNode({ @@ -116,7 +117,8 @@ export const buildGenerationTabGraph = async (state: RootState, manager: KonvaNo }) : null; - let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> = l2i; + let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> | Invocation<'img_resize'> = + l2i; g.addEdge(modelLoader, 'unet', denoise, 'unet'); g.addEdge(modelLoader, 'clip', clipSkip, 'clip'); @@ -136,8 +138,8 @@ export const buildGenerationTabGraph = async (state: RootState, manager: KonvaNo generation_mode: 'txt2img', cfg_scale, cfg_rescale_multiplier, - height, - width, + width: scaledSize.width, + height: scaledSize.height, positive_prompt: positivePrompt, negative_prompt: negativePrompt, model: Graph.getModelMetadataField(modelConfig), @@ -157,6 +159,18 @@ export const buildGenerationTabGraph = async (state: RootState, manager: KonvaNo const vaeSource = seamless ?? vaeLoader ?? modelLoader; g.addEdge(vaeSource, 'vae', l2i, 'vae'); + if (!isEqual(scaledSize, originalSize)) { + // We are using scaled bbox and need to resize the output image back to the original size. + imageOutput = g.addNode({ + id: 'img_resize', + type: 'img_resize', + ...originalSize, + is_intermediate: false, + use_cache: false, + }); + g.addEdge(l2i, 'image', imageOutput, 'image'); + } + const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base); const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base); const _addedRegions = await addRegions( diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts b/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts index 90151797ea..419c7aac28 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts @@ -1,7 +1,9 @@ import type { RootState } from 'app/store/store'; +import type { CanvasV2State } from 'features/controlLayers/store/types'; import type { BoardField } from 'features/nodes/types/common'; import { buildPresetModifiedPrompt } from 'features/stylePresets/hooks/usePresetModifiedPrompts'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; +import { pick } from 'lodash-es'; import { stylePresetsApi } from 'services/api/endpoints/stylePresets'; /** @@ -22,7 +24,7 @@ export const getPresetModifiedPrompts = ( state: RootState ): { positivePrompt: string; negativePrompt: string; positiveStylePrompt?: string; negativeStylePrompt?: string } => { const { positivePrompt, negativePrompt, positivePrompt2, negativePrompt2, shouldConcatPrompts } = - state.generation; + state.canvasV2.params; const { activeStylePresetId } = state.stylePreset; if (activeStylePresetId) { @@ -68,3 +70,9 @@ export const getIsIntermediate = (state: RootState) => { } return false; }; + +export const getSizes = (bboxState: CanvasV2State['bbox']) => { + const originalSize = pick(bboxState, 'width', 'height'); + const scaledSize = ['auto', 'manual'].includes(bboxState.scaleMethod) ? bboxState.scaledSize : originalSize; + return { originalSize, scaledSize }; +};