From eb320df41d338b6de05b6148d2dd07b16e19e567 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 9 May 2024 19:35:16 +1000 Subject: [PATCH] feat(ui): use new lora loaders, simplify VAE loader, seamless --- .../util/graph/addControlLayersToGraph2.ts | 18 ++--- .../graph/addGenerationTabControlLayers.ts | 15 +++- .../nodes/util/graph/addGenerationTabLoRAs.ts | 81 +++++++------------ .../util/graph/addGenerationTabSeamless.ts | 17 +--- .../util/graph/buildGenerationTabGraph2.ts | 25 ++++-- 5 files changed, 74 insertions(+), 82 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addControlLayersToGraph2.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addControlLayersToGraph2.ts index 16d7d74c27..5721f31313 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addControlLayersToGraph2.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addControlLayersToGraph2.ts @@ -490,29 +490,27 @@ const isValidIPAdapter = (ipa: IPAdapterConfigV2, base: BaseModelType): boolean }; const isValidLayer = (layer: Layer, base: BaseModelType) => { + if (!layer.isEnabled) { + return false; + } if (isControlAdapterLayer(layer)) { - if (!layer.isEnabled) { - return false; - } return isValidControlAdapter(layer.controlAdapter, base); } if (isIPAdapterLayer(layer)) { - if (!layer.isEnabled) { - return false; - } return isValidIPAdapter(layer.ipAdapter, base); } if (isInitialImageLayer(layer)) { - if (!layer.isEnabled) { - return false; - } if (!layer.image) { return false; } return true; } if (isRegionalGuidanceLayer(layer)) { - const hasTextPrompt = Boolean(layer.positivePrompt || layer.negativePrompt); + if (layer.maskObjects.length === 0) { + // Layer has no mask, meaning any guidance would be applied to an empty region. + return false; + } + const hasTextPrompt = Boolean(layer.positivePrompt) || Boolean(layer.negativePrompt); const hasIPAdapter = layer.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base)).length > 0; return hasTextPrompt || hasIPAdapter; } diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabControlLayers.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabControlLayers.ts index d3b2788329..7851d5c19d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabControlLayers.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabControlLayers.ts @@ -45,7 +45,12 @@ export const addGenerationTabControlLayers = async ( negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>, posCondCollect: Invocation<'collect'>, negCondCollect: Invocation<'collect'>, - noise: Invocation<'noise'> + noise: Invocation<'noise'>, + vaeSource: + | Invocation<'seamless'> + | Invocation<'vae_loader'> + | Invocation<'main_model_loader'> + | Invocation<'sdxl_model_loader'> ): Promise => { const mainModel = state.generation.model; assert(mainModel, 'Missing main model when building graph'); @@ -67,7 +72,7 @@ export const addGenerationTabControlLayers = async ( const initialImageLayers = validLayers.filter(isInitialImageLayer); assert(initialImageLayers.length <= 1, 'Only one initial image layer allowed'); if (initialImageLayers[0]) { - addInitialImageLayerToGraph(state, g, denoise, noise, initialImageLayers[0]); + addInitialImageLayerToGraph(state, g, denoise, noise, vaeSource, initialImageLayers[0]); } // TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing // the existing conditioning nodes. @@ -414,6 +419,11 @@ const addInitialImageLayerToGraph = ( g: Graph, denoise: Invocation<'denoise_latents'>, noise: Invocation<'noise'>, + vaeSource: + | Invocation<'seamless'> + | Invocation<'vae_loader'> + | Invocation<'main_model_loader'> + | Invocation<'sdxl_model_loader'>, layer: InitialImageLayer ) => { const { vaePrecision, model } = state.generation; @@ -438,6 +448,7 @@ const addInitialImageLayerToGraph = ( }); g.addEdge(i2l, 'latents', denoise, 'latents'); + g.addEdge(vaeSource, 'vae', i2l, 'vae'); if (layer.image.width !== width || layer.image.height !== height) { // The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS` diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabLoRAs.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabLoRAs.ts index 3cb43fd48d..5a7173f2d5 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabLoRAs.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabLoRAs.ts @@ -1,11 +1,9 @@ import type { RootState } from 'app/store/store'; -import { deepClone } from 'common/util/deepClone'; import { zModelIdentifierField } from 'features/nodes/types/common'; -import { Graph } from 'features/nodes/util/graph/Graph'; +import type { Graph } from 'features/nodes/util/graph/Graph'; import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil'; import { filter, size } from 'lodash-es'; import type { Invocation, S } from 'services/api/types'; -import { assert } from 'tsafe'; import { LORA_LOADER } from './constants'; @@ -13,19 +11,12 @@ export const addGenerationTabLoRAs = ( state: RootState, g: Graph, denoise: Invocation<'denoise_latents'>, - unetSource: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'> | Invocation<'seamless'>, + modelLoader: Invocation<'main_model_loader'>, + seamless: Invocation<'seamless'> | null, clipSkip: Invocation<'clip_skip'>, posCond: Invocation<'compel'>, negCond: Invocation<'compel'> ): void => { - /** - * LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them. - * They then output the UNet and CLIP models references on to either the next LoRA in the chain, - * or to the inference/conditioning nodes. - * - * So we need to inject a LoRA chain into the graph. - */ - const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false); const loraCount = size(enabledLoRAs); @@ -33,30 +24,39 @@ export const addGenerationTabLoRAs = ( return; } - // Remove modelLoaderNodeId unet connection to feed it to LoRAs - console.log(deepClone(g)._graph.edges.map((e) => Graph.edgeToString(e))); - g.deleteEdgesFrom(unetSource, 'unet'); - console.log(deepClone(g)._graph.edges.map((e) => Graph.edgeToString(e))); - if (clipSkip) { - // Remove CLIP_SKIP connections to conditionings to feed it through LoRAs - g.deleteEdgesFrom(clipSkip, 'clip'); - } - console.log(deepClone(g)._graph.edges.map((e) => Graph.edgeToString(e))); - - // we need to remember the last lora so we can chain from it - let lastLoRALoader: Invocation<'lora_loader'> | null = null; - let currentLoraIndex = 0; const loraMetadata: S['LoRAMetadataField'][] = []; + // We will collect LoRAs into a single collection node, then pass them to the LoRA collection loader, which applies + // each LoRA to the UNet and CLIP. + const loraCollector = g.addNode({ + id: `${LORA_LOADER}_collect`, + type: 'collect', + }); + const loraCollectionLoader = g.addNode({ + id: LORA_LOADER, + type: 'lora_collection_loader', + }); + + g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras'); + // Use seamless as UNet input if it exists, otherwise use the model loader + g.addEdge(seamless ?? modelLoader, 'unet', loraCollectionLoader, 'unet'); + g.addEdge(clipSkip, 'clip', loraCollectionLoader, 'clip'); + // Reroute UNet & CLIP connections through the LoRA collection loader + g.deleteEdgesTo(denoise, 'unet'); + g.deleteEdgesTo(posCond, 'clip'); + g.deleteEdgesTo(negCond, 'clip'); + g.addEdge(loraCollectionLoader, 'unet', denoise, 'unet'); + g.addEdge(loraCollectionLoader, 'clip', posCond, 'clip'); + g.addEdge(loraCollectionLoader, 'clip', negCond, 'clip'); + for (const lora of enabledLoRAs) { const { weight } = lora; const { key } = lora.model; - const currentLoraNodeId = `${LORA_LOADER}_${key}`; const parsedModel = zModelIdentifierField.parse(lora.model); - const currentLoRALoader = g.addNode({ - type: 'lora_loader', - id: currentLoraNodeId, + const loraSelector = g.addNode({ + type: 'lora_selector', + id: `${LORA_LOADER}_${key}`, lora: parsedModel, weight, }); @@ -66,28 +66,7 @@ export const addGenerationTabLoRAs = ( weight, }); - // add to graph - if (currentLoraIndex === 0) { - // first lora = start the lora chain, attach directly to model loader - g.addEdge(unetSource, 'unet', currentLoRALoader, 'unet'); - g.addEdge(clipSkip, 'clip', currentLoRALoader, 'clip'); - } else { - assert(lastLoRALoader !== null); - // we are in the middle of the lora chain, instead connect to the previous lora - g.addEdge(lastLoRALoader, 'unet', currentLoRALoader, 'unet'); - g.addEdge(lastLoRALoader, 'clip', currentLoRALoader, 'clip'); - } - - if (currentLoraIndex === loraCount - 1) { - // final lora, end the lora chain - we need to connect up to inference and conditioning nodes - g.addEdge(currentLoRALoader, 'unet', denoise, 'unet'); - g.addEdge(currentLoRALoader, 'clip', posCond, 'clip'); - g.addEdge(currentLoRALoader, 'clip', negCond, 'clip'); - } - - // increment the lora for the next one in the chain - lastLoRALoader = currentLoRALoader; - currentLoraIndex += 1; + g.addEdge(loraSelector, 'lora', loraCollector, 'item'); } MetadataUtil.add(g, { loras: loraMetadata }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSeamless.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSeamless.ts index e56f37916c..ef0b38291f 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSeamless.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addGenerationTabSeamless.ts @@ -3,7 +3,7 @@ import type { Graph } from 'features/nodes/util/graph/Graph'; import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil'; import type { Invocation } from 'services/api/types'; -import { SEAMLESS, VAE_LOADER } from './constants'; +import { SEAMLESS } from './constants'; /** * Adds the seamless node to the graph and connects it to the model loader and denoise node. @@ -19,9 +19,10 @@ export const addGenerationTabSeamless = ( state: RootState, g: Graph, denoise: Invocation<'denoise_latents'>, - modelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'> + modelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'>, + vaeLoader: Invocation<'vae_loader'> | null ): Invocation<'seamless'> | null => { - const { seamlessXAxis: seamless_x, seamlessYAxis: seamless_y, vae } = state.generation; + const { seamlessXAxis: seamless_x, seamlessYAxis: seamless_y } = state.generation; if (!seamless_x && !seamless_y) { return null; @@ -34,16 +35,6 @@ export const addGenerationTabSeamless = ( seamless_y, }); - // The VAE helper also adds the VAE loader - so we need to check if it's already there - const shouldAddVAELoader = !g.hasNode(VAE_LOADER) && vae; - const vaeLoader = shouldAddVAELoader - ? g.addNode({ - type: 'vae_loader', - id: VAE_LOADER, - vae_model: vae, - }) - : null; - MetadataUtil.add(g, { seamless_x: seamless_x || undefined, seamless_y: seamless_y || undefined, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildGenerationTabGraph2.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildGenerationTabGraph2.ts index 38c7ba18d1..2735103215 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildGenerationTabGraph2.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildGenerationTabGraph2.ts @@ -5,7 +5,6 @@ import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetch import { addGenerationTabControlLayers } from 'features/nodes/util/graph/addGenerationTabControlLayers'; import { addGenerationTabLoRAs } from 'features/nodes/util/graph/addGenerationTabLoRAs'; import { addGenerationTabSeamless } from 'features/nodes/util/graph/addGenerationTabSeamless'; -import { addGenerationTabVAE } from 'features/nodes/util/graph/addGenerationTabVAE'; import type { GraphType } from 'features/nodes/util/graph/Graph'; import { Graph } from 'features/nodes/util/graph/Graph'; import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils'; @@ -26,6 +25,7 @@ import { NOISE, POSITIVE_CONDITIONING, POSITIVE_CONDITIONING_COLLECT, + VAE_LOADER, } from './constants'; import { getModelMetadataField } from './metadata'; @@ -41,6 +41,7 @@ export const buildGenerationTabGraph2 = async (state: RootState): Promise