From 5425526d5027f1d59768633eeaacb12d47cc6a1f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 13 May 2024 20:58:51 +1000 Subject: [PATCH] feat(ui): use graph builder for generation tab sdxl --- .../listeners/enqueueRequestedLinear.ts | 4 +- .../graph/addGenerationTabControlLayers.ts | 53 ++++-- .../util/graph/addGenerationTabSDXLLoRAs.ts | 75 ++++++++ .../util/graph/addGenerationTabSDXLRefiner.ts | 104 ++++++++++ .../graph/buildGenerationTabSDXLGraph2.ts | 178 ++++++++++++++++++ 5 files changed, 397 insertions(+), 17 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSDXLLoRAs.ts create mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSDXLRefiner.ts create mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/buildGenerationTabSDXLGraph2.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts index bbb77c9ac5..a2d9f253a1 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts @@ -2,7 +2,7 @@ import { enqueueRequested } from 'app/store/actions'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice'; import { buildGenerationTabGraph2 } from 'features/nodes/util/graph/buildGenerationTabGraph2'; -import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/buildGenerationTabSDXLGraph'; +import { buildGenerationTabSDXLGraph2 } from 'features/nodes/util/graph/buildGenerationTabSDXLGraph2'; import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; import { queueApi } from 'services/api/endpoints/queue'; @@ -19,7 +19,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) let graph; if (model && model.base === 'sdxl') { - graph = await buildGenerationTabSDXLGraph(state); + graph = await buildGenerationTabSDXLGraph2(state); } else { graph = await buildGenerationTabGraph2(state); } diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabControlLayers.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabControlLayers.ts index 88a9c0859b..3c7c0c9c66 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabControlLayers.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabControlLayers.ts @@ -118,11 +118,20 @@ export const addGenerationTabControlLayers = async ( // Connect the conditioning to the collector g.addEdge(regionalPosCond, 'conditioning', posCondCollect, 'item'); // Copy the connections to the "global" positive conditioning node to the regional cond - for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) { - // Clone the edge, but change the destination node to the regional conditioning node - const clone = deepClone(edge); - clone.destination.node_id = regionalPosCond.id; - g.addEdgeFromObj(clone); + if (posCond.type === 'compel') { + for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) { + // Clone the edge, but change the destination node to the regional conditioning node + const clone = deepClone(edge); + clone.destination.node_id = regionalPosCond.id; + g.addEdgeFromObj(clone); + } + } else { + for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) { + // Clone the edge, but change the destination node to the regional conditioning node + const clone = deepClone(edge); + clone.destination.node_id = regionalPosCond.id; + g.addEdgeFromObj(clone); + } } } @@ -147,11 +156,18 @@ export const addGenerationTabControlLayers = async ( // Connect the conditioning to the collector g.addEdge(regionalNegCond, 'conditioning', negCondCollect, 'item'); // Copy the connections to the "global" negative conditioning node to the regional cond - for (const edge of g.getEdgesTo(negCond, ['clip', 'mask'])) { - // Clone the edge, but change the destination node to the regional conditioning node - const clone = deepClone(edge); - clone.destination.node_id = regionalNegCond.id; - g.addEdgeFromObj(clone); + if (negCond.type === 'compel') { + for (const edge of g.getEdgesTo(negCond, ['clip', 'mask'])) { + const clone = deepClone(edge); + clone.destination.node_id = regionalNegCond.id; + g.addEdgeFromObj(clone); + } + } else { + for (const edge of g.getEdgesTo(negCond, ['clip', 'clip2', 'mask'])) { + const clone = deepClone(edge); + clone.destination.node_id = regionalNegCond.id; + g.addEdgeFromObj(clone); + } } } @@ -184,11 +200,18 @@ export const addGenerationTabControlLayers = async ( // Connect the conditioning to the negative collector g.addEdge(regionalPosCondInverted, 'conditioning', negCondCollect, 'item'); // Copy the connections to the "global" positive conditioning node to our regional node - for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) { - // Clone the edge, but change the destination node to the regional conditioning node - const clone = deepClone(edge); - clone.destination.node_id = regionalPosCondInverted.id; - g.addEdgeFromObj(clone); + if (posCond.type === 'compel') { + for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) { + const clone = deepClone(edge); + clone.destination.node_id = regionalPosCondInverted.id; + g.addEdgeFromObj(clone); + } + } else { + for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) { + const clone = deepClone(edge); + clone.destination.node_id = regionalPosCondInverted.id; + g.addEdgeFromObj(clone); + } } } diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSDXLLoRAs.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSDXLLoRAs.ts new file mode 100644 index 0000000000..89f1f8f18e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSDXLLoRAs.ts @@ -0,0 +1,75 @@ +import type { RootState } from 'app/store/store'; +import { zModelIdentifierField } from 'features/nodes/types/common'; +import type { Graph } from 'features/nodes/util/graph/Graph'; +import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil'; +import { filter, size } from 'lodash-es'; +import type { Invocation, S } from 'services/api/types'; + +import { LORA_LOADER } from './constants'; + +export const addGenerationTabSDXLLoRAs = ( + state: RootState, + g: Graph, + denoise: Invocation<'denoise_latents'>, + modelLoader: Invocation<'sdxl_model_loader'>, + seamless: Invocation<'seamless'> | null, + posCond: Invocation<'sdxl_compel_prompt'>, + negCond: Invocation<'sdxl_compel_prompt'> +): void => { + const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false); + const loraCount = size(enabledLoRAs); + + if (loraCount === 0) { + return; + } + + const loraMetadata: S['LoRAMetadataField'][] = []; + + // We will collect LoRAs into a single collection node, then pass them to the LoRA collection loader, which applies + // each LoRA to the UNet and CLIP. + const loraCollector = g.addNode({ + id: `${LORA_LOADER}_collect`, + type: 'collect', + }); + const loraCollectionLoader = g.addNode({ + id: LORA_LOADER, + type: 'sdxl_lora_collection_loader', + }); + + g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras'); + // Use seamless as UNet input if it exists, otherwise use the model loader + g.addEdge(seamless ?? modelLoader, 'unet', loraCollectionLoader, 'unet'); + g.addEdge(modelLoader, 'clip', loraCollectionLoader, 'clip'); + g.addEdge(modelLoader, 'clip2', loraCollectionLoader, 'clip2'); + // Reroute UNet & CLIP connections through the LoRA collection loader + g.deleteEdgesTo(denoise, ['unet']); + g.deleteEdgesTo(posCond, ['clip', 'clip2']); + g.deleteEdgesTo(negCond, ['clip', 'clip2']); + g.addEdge(loraCollectionLoader, 'unet', denoise, 'unet'); + g.addEdge(loraCollectionLoader, 'clip', posCond, 'clip'); + g.addEdge(loraCollectionLoader, 'clip', negCond, 'clip'); + g.addEdge(loraCollectionLoader, 'clip2', posCond, 'clip2'); + g.addEdge(loraCollectionLoader, 'clip2', negCond, 'clip2'); + + for (const lora of enabledLoRAs) { + const { weight } = lora; + const { key } = lora.model; + const parsedModel = zModelIdentifierField.parse(lora.model); + + const loraSelector = g.addNode({ + type: 'lora_selector', + id: `${LORA_LOADER}_${key}`, + lora: parsedModel, + weight, + }); + + loraMetadata.push({ + model: parsedModel, + weight, + }); + + g.addEdge(loraSelector, 'lora', loraCollector, 'item'); + } + + MetadataUtil.add(g, { loras: loraMetadata }); +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSDXLRefiner.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSDXLRefiner.ts new file mode 100644 index 0000000000..7c207b75bb --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSDXLRefiner.ts @@ -0,0 +1,104 @@ +import type { RootState } from 'app/store/store'; +import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; +import type { Graph } from 'features/nodes/util/graph/Graph'; +import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil'; +import type { Invocation } from 'services/api/types'; +import { isRefinerMainModelModelConfig } from 'services/api/types'; +import { assert } from 'tsafe'; + +import { + SDXL_REFINER_DENOISE_LATENTS, + SDXL_REFINER_MODEL_LOADER, + SDXL_REFINER_NEGATIVE_CONDITIONING, + SDXL_REFINER_POSITIVE_CONDITIONING, + SDXL_REFINER_SEAMLESS, +} from './constants'; +import { getModelMetadataField } from './metadata'; + +export const addGenerationTabSDXLRefiner = async ( + state: RootState, + g: Graph, + denoise: Invocation<'denoise_latents'>, + modelLoader: Invocation<'sdxl_model_loader'>, + seamless: Invocation<'seamless'> | null, + posCond: Invocation<'sdxl_compel_prompt'>, + negCond: Invocation<'sdxl_compel_prompt'>, + l2i: Invocation<'l2i'> +): Promise => { + const { + refinerModel, + refinerPositiveAestheticScore, + refinerNegativeAestheticScore, + refinerSteps, + refinerScheduler, + refinerCFGScale, + refinerStart, + } = state.sdxl; + + assert(refinerModel, 'No refiner model found in state'); + + const modelConfig = await fetchModelConfigWithTypeGuard(refinerModel.key, isRefinerMainModelModelConfig); + + // We need to re-route latents to the refiner + g.deleteEdgesFrom(denoise, ['latents']); + // Latents will now come from refiner - delete edges to the l2i VAE decode + g.deleteEdgesTo(l2i, ['latents']); + + const refinerModelLoader = g.addNode({ + type: 'sdxl_refiner_model_loader', + id: SDXL_REFINER_MODEL_LOADER, + model: refinerModel, + }); + const refinerPosCond = g.addNode({ + type: 'sdxl_refiner_compel_prompt', + id: SDXL_REFINER_POSITIVE_CONDITIONING, + style: posCond.style, + aesthetic_score: refinerPositiveAestheticScore, + }); + const refinerNegCond = g.addNode({ + type: 'sdxl_refiner_compel_prompt', + id: SDXL_REFINER_NEGATIVE_CONDITIONING, + style: negCond.style, + aesthetic_score: refinerNegativeAestheticScore, + }); + const refinerDenoise = g.addNode({ + type: 'denoise_latents', + id: SDXL_REFINER_DENOISE_LATENTS, + cfg_scale: refinerCFGScale, + steps: refinerSteps, + scheduler: refinerScheduler, + denoising_start: refinerStart, + denoising_end: 1, + }); + + if (seamless) { + const refinerSeamless = g.addNode({ + id: SDXL_REFINER_SEAMLESS, + type: 'seamless', + seamless_x: seamless.seamless_x, + seamless_y: seamless.seamless_y, + }); + g.addEdge(refinerModelLoader, 'unet', refinerSeamless, 'unet'); + g.addEdge(refinerModelLoader, 'vae', refinerSeamless, 'vae'); + g.addEdge(refinerSeamless, 'unet', refinerDenoise, 'unet'); + } else { + g.addEdge(refinerModelLoader, 'unet', refinerDenoise, 'unet'); + } + + g.addEdge(refinerModelLoader, 'clip2', refinerPosCond, 'clip2'); + g.addEdge(refinerModelLoader, 'clip2', refinerNegCond, 'clip2'); + g.addEdge(refinerPosCond, 'conditioning', refinerDenoise, 'positive_conditioning'); + g.addEdge(refinerNegCond, 'conditioning', refinerDenoise, 'negative_conditioning'); + g.addEdge(denoise, 'latents', refinerDenoise, 'latents'); + g.addEdge(refinerDenoise, 'latents', l2i, 'latents'); + + MetadataUtil.add(g, { + refiner_model: getModelMetadataField(modelConfig), + refiner_positive_aesthetic_score: refinerPositiveAestheticScore, + refiner_negative_aesthetic_score: refinerNegativeAestheticScore, + refiner_cfg_scale: refinerCFGScale, + refiner_scheduler: refinerScheduler, + refiner_start: refinerStart, + refiner_steps: refinerSteps, + }); +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildGenerationTabSDXLGraph2.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildGenerationTabSDXLGraph2.ts new file mode 100644 index 0000000000..05fe3d1565 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildGenerationTabSDXLGraph2.ts @@ -0,0 +1,178 @@ +import type { RootState } from 'app/store/store'; +import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; +import { addGenerationTabControlLayers } from 'features/nodes/util/graph/addGenerationTabControlLayers'; +import { addGenerationTabNSFWChecker } from 'features/nodes/util/graph/addGenerationTabNSFWChecker'; +import { addGenerationTabSDXLLoRAs } from 'features/nodes/util/graph/addGenerationTabSDXLLoRAs'; +import { addGenerationTabSDXLRefiner } from 'features/nodes/util/graph/addGenerationTabSDXLRefiner'; +import { addGenerationTabSeamless } from 'features/nodes/util/graph/addGenerationTabSeamless'; +import { addGenerationTabWatermarker } from 'features/nodes/util/graph/addGenerationTabWatermarker'; +import { Graph } from 'features/nodes/util/graph/Graph'; +import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil'; +import type { Invocation, NonNullableGraph } from 'services/api/types'; +import { isNonRefinerMainModelConfig } from 'services/api/types'; +import { assert } from 'tsafe'; + +import { + LATENTS_TO_IMAGE, + NEGATIVE_CONDITIONING, + NEGATIVE_CONDITIONING_COLLECT, + NOISE, + POSITIVE_CONDITIONING, + POSITIVE_CONDITIONING_COLLECT, + SDXL_CONTROL_LAYERS_GRAPH, + SDXL_DENOISE_LATENTS, + SDXL_MODEL_LOADER, + VAE_LOADER, +} from './constants'; +import { getBoardField, getSDXLStylePrompts } from './graphBuilderUtils'; +import { getModelMetadataField } from './metadata'; + +export const buildGenerationTabSDXLGraph2 = async (state: RootState): Promise => { + const { + model, + cfgScale: cfg_scale, + cfgRescaleMultiplier: cfg_rescale_multiplier, + scheduler, + seed, + steps, + shouldUseCpuNoise, + vaePrecision, + vae, + } = state.generation; + const { positivePrompt, negativePrompt } = state.controlLayers.present; + const { width, height } = state.controlLayers.present.size; + + const { refinerModel, refinerStart } = state.sdxl; + + assert(model, 'No model found in state'); + + const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state); + + const g = new Graph(SDXL_CONTROL_LAYERS_GRAPH); + const modelLoader = g.addNode({ + type: 'sdxl_model_loader', + id: SDXL_MODEL_LOADER, + model, + }); + const posCond = g.addNode({ + type: 'sdxl_compel_prompt', + id: POSITIVE_CONDITIONING, + prompt: positivePrompt, + style: positiveStylePrompt, + }); + const posCondCollect = g.addNode({ + type: 'collect', + id: POSITIVE_CONDITIONING_COLLECT, + }); + const negCond = g.addNode({ + type: 'sdxl_compel_prompt', + id: NEGATIVE_CONDITIONING, + prompt: negativePrompt, + style: negativeStylePrompt, + }); + const negCondCollect = g.addNode({ + type: 'collect', + id: NEGATIVE_CONDITIONING_COLLECT, + }); + const noise = g.addNode({ type: 'noise', id: NOISE, seed, width, height, use_cpu: shouldUseCpuNoise }); + const denoise = g.addNode({ + type: 'denoise_latents', + id: SDXL_DENOISE_LATENTS, + cfg_scale, + cfg_rescale_multiplier, + scheduler, + steps, + denoising_start: 0, + denoising_end: refinerModel ? refinerStart : 1, + }); + const l2i = g.addNode({ + type: 'l2i', + id: LATENTS_TO_IMAGE, + fp32: vaePrecision === 'fp32', + board: getBoardField(state), + // This is the terminal node and must always save to gallery. + is_intermediate: false, + use_cache: false, + }); + const vaeLoader = + vae?.base === model.base + ? g.addNode({ + type: 'vae_loader', + id: VAE_LOADER, + vae_model: vae, + }) + : null; + + let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> = l2i; + + g.addEdge(modelLoader, 'unet', denoise, 'unet'); + g.addEdge(modelLoader, 'clip', posCond, 'clip'); + g.addEdge(modelLoader, 'clip', negCond, 'clip'); + g.addEdge(modelLoader, 'clip2', posCond, 'clip2'); + g.addEdge(modelLoader, 'clip2', negCond, 'clip2'); + g.addEdge(posCond, 'conditioning', posCondCollect, 'item'); + g.addEdge(negCond, 'conditioning', negCondCollect, 'item'); + g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning'); + g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning'); + g.addEdge(noise, 'noise', denoise, 'noise'); + g.addEdge(denoise, 'latents', l2i, 'latents'); + + const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig); + + MetadataUtil.add(g, { + generation_mode: 'txt2img', + cfg_scale, + cfg_rescale_multiplier, + height, + width, + positive_prompt: positivePrompt, + negative_prompt: negativePrompt, + model: getModelMetadataField(modelConfig), + seed, + steps, + rand_device: shouldUseCpuNoise ? 'cpu' : 'cuda', + scheduler, + positive_style_prompt: positiveStylePrompt, + negative_style_prompt: negativeStylePrompt, + vae: vae ?? undefined, + }); + g.validate(); + + const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader, vaeLoader); + g.validate(); + + addGenerationTabSDXLLoRAs(state, g, denoise, modelLoader, seamless, posCond, negCond); + g.validate(); + + // We might get the VAE from the main model, custom VAE, or seamless node. + const vaeSource = seamless ?? vaeLoader ?? modelLoader; + g.addEdge(vaeSource, 'vae', l2i, 'vae'); + + // Add Refiner if enabled + if (refinerModel) { + await addGenerationTabSDXLRefiner(state, g, denoise, modelLoader, seamless, posCond, negCond, l2i); + } + + await addGenerationTabControlLayers( + state, + g, + denoise, + posCond, + negCond, + posCondCollect, + negCondCollect, + noise, + vaeSource + ); + + if (state.system.shouldUseNSFWChecker) { + imageOutput = addGenerationTabNSFWChecker(g, imageOutput); + } + + if (state.system.shouldUseWatermarker) { + imageOutput = addGenerationTabWatermarker(g, imageOutput); + } + + MetadataUtil.setMetadataReceivingNode(g, imageOutput); + return g.getGraph(); +};