feat(ui): use new lora loaders, simplify VAE loader, seamless

This commit is contained in:
psychedelicious 2024-05-09 19:35:16 +10:00
parent de1869773f
commit eb320df41d
5 changed files with 74 additions and 82 deletions

View File

@ -490,29 +490,27 @@ const isValidIPAdapter = (ipa: IPAdapterConfigV2, base: BaseModelType): boolean
};
const isValidLayer = (layer: Layer, base: BaseModelType) => {
if (!layer.isEnabled) {
return false;
}
if (isControlAdapterLayer(layer)) {
if (!layer.isEnabled) {
return false;
}
return isValidControlAdapter(layer.controlAdapter, base);
}
if (isIPAdapterLayer(layer)) {
if (!layer.isEnabled) {
return false;
}
return isValidIPAdapter(layer.ipAdapter, base);
}
if (isInitialImageLayer(layer)) {
if (!layer.isEnabled) {
return false;
}
if (!layer.image) {
return false;
}
return true;
}
if (isRegionalGuidanceLayer(layer)) {
const hasTextPrompt = Boolean(layer.positivePrompt || layer.negativePrompt);
if (layer.maskObjects.length === 0) {
// Layer has no mask, meaning any guidance would be applied to an empty region.
return false;
}
const hasTextPrompt = Boolean(layer.positivePrompt) || Boolean(layer.negativePrompt);
const hasIPAdapter = layer.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base)).length > 0;
return hasTextPrompt || hasIPAdapter;
}

View File

@ -45,7 +45,12 @@ export const addGenerationTabControlLayers = async (
negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
posCondCollect: Invocation<'collect'>,
negCondCollect: Invocation<'collect'>,
noise: Invocation<'noise'>
noise: Invocation<'noise'>,
vaeSource:
| Invocation<'seamless'>
| Invocation<'vae_loader'>
| Invocation<'main_model_loader'>
| Invocation<'sdxl_model_loader'>
): Promise<Layer[]> => {
const mainModel = state.generation.model;
assert(mainModel, 'Missing main model when building graph');
@ -67,7 +72,7 @@ export const addGenerationTabControlLayers = async (
const initialImageLayers = validLayers.filter(isInitialImageLayer);
assert(initialImageLayers.length <= 1, 'Only one initial image layer allowed');
if (initialImageLayers[0]) {
addInitialImageLayerToGraph(state, g, denoise, noise, initialImageLayers[0]);
addInitialImageLayerToGraph(state, g, denoise, noise, vaeSource, initialImageLayers[0]);
}
// TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
// the existing conditioning nodes.
@ -414,6 +419,11 @@ const addInitialImageLayerToGraph = (
g: Graph,
denoise: Invocation<'denoise_latents'>,
noise: Invocation<'noise'>,
vaeSource:
| Invocation<'seamless'>
| Invocation<'vae_loader'>
| Invocation<'main_model_loader'>
| Invocation<'sdxl_model_loader'>,
layer: InitialImageLayer
) => {
const { vaePrecision, model } = state.generation;
@ -438,6 +448,7 @@ const addInitialImageLayerToGraph = (
});
g.addEdge(i2l, 'latents', denoise, 'latents');
g.addEdge(vaeSource, 'vae', i2l, 'vae');
if (layer.image.width !== width || layer.image.height !== height) {
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`

View File

@ -1,11 +1,9 @@
import type { RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/Graph';
import type { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import { filter, size } from 'lodash-es';
import type { Invocation, S } from 'services/api/types';
import { assert } from 'tsafe';
import { LORA_LOADER } from './constants';
@ -13,19 +11,12 @@ export const addGenerationTabLoRAs = (
state: RootState,
g: Graph,
denoise: Invocation<'denoise_latents'>,
unetSource: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'> | Invocation<'seamless'>,
modelLoader: Invocation<'main_model_loader'>,
seamless: Invocation<'seamless'> | null,
clipSkip: Invocation<'clip_skip'>,
posCond: Invocation<'compel'>,
negCond: Invocation<'compel'>
): void => {
/**
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
* or to the inference/conditioning nodes.
*
* So we need to inject a LoRA chain into the graph.
*/
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
const loraCount = size(enabledLoRAs);
@ -33,30 +24,39 @@ export const addGenerationTabLoRAs = (
return;
}
// Remove modelLoaderNodeId unet connection to feed it to LoRAs
console.log(deepClone(g)._graph.edges.map((e) => Graph.edgeToString(e)));
g.deleteEdgesFrom(unetSource, 'unet');
console.log(deepClone(g)._graph.edges.map((e) => Graph.edgeToString(e)));
if (clipSkip) {
// Remove CLIP_SKIP connections to conditionings to feed it through LoRAs
g.deleteEdgesFrom(clipSkip, 'clip');
}
console.log(deepClone(g)._graph.edges.map((e) => Graph.edgeToString(e)));
// we need to remember the last lora so we can chain from it
let lastLoRALoader: Invocation<'lora_loader'> | null = null;
let currentLoraIndex = 0;
const loraMetadata: S['LoRAMetadataField'][] = [];
// We will collect LoRAs into a single collection node, then pass them to the LoRA collection loader, which applies
// each LoRA to the UNet and CLIP.
const loraCollector = g.addNode({
id: `${LORA_LOADER}_collect`,
type: 'collect',
});
const loraCollectionLoader = g.addNode({
id: LORA_LOADER,
type: 'lora_collection_loader',
});
g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras');
// Use seamless as UNet input if it exists, otherwise use the model loader
g.addEdge(seamless ?? modelLoader, 'unet', loraCollectionLoader, 'unet');
g.addEdge(clipSkip, 'clip', loraCollectionLoader, 'clip');
// Reroute UNet & CLIP connections through the LoRA collection loader
g.deleteEdgesTo(denoise, 'unet');
g.deleteEdgesTo(posCond, 'clip');
g.deleteEdgesTo(negCond, 'clip');
g.addEdge(loraCollectionLoader, 'unet', denoise, 'unet');
g.addEdge(loraCollectionLoader, 'clip', posCond, 'clip');
g.addEdge(loraCollectionLoader, 'clip', negCond, 'clip');
for (const lora of enabledLoRAs) {
const { weight } = lora;
const { key } = lora.model;
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
const parsedModel = zModelIdentifierField.parse(lora.model);
const currentLoRALoader = g.addNode({
type: 'lora_loader',
id: currentLoraNodeId,
const loraSelector = g.addNode({
type: 'lora_selector',
id: `${LORA_LOADER}_${key}`,
lora: parsedModel,
weight,
});
@ -66,28 +66,7 @@ export const addGenerationTabLoRAs = (
weight,
});
// add to graph
if (currentLoraIndex === 0) {
// first lora = start the lora chain, attach directly to model loader
g.addEdge(unetSource, 'unet', currentLoRALoader, 'unet');
g.addEdge(clipSkip, 'clip', currentLoRALoader, 'clip');
} else {
assert(lastLoRALoader !== null);
// we are in the middle of the lora chain, instead connect to the previous lora
g.addEdge(lastLoRALoader, 'unet', currentLoRALoader, 'unet');
g.addEdge(lastLoRALoader, 'clip', currentLoRALoader, 'clip');
}
if (currentLoraIndex === loraCount - 1) {
// final lora, end the lora chain - we need to connect up to inference and conditioning nodes
g.addEdge(currentLoRALoader, 'unet', denoise, 'unet');
g.addEdge(currentLoRALoader, 'clip', posCond, 'clip');
g.addEdge(currentLoRALoader, 'clip', negCond, 'clip');
}
// increment the lora for the next one in the chain
lastLoRALoader = currentLoRALoader;
currentLoraIndex += 1;
g.addEdge(loraSelector, 'lora', loraCollector, 'item');
}
MetadataUtil.add(g, { loras: loraMetadata });

View File

@ -3,7 +3,7 @@ import type { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import type { Invocation } from 'services/api/types';
import { SEAMLESS, VAE_LOADER } from './constants';
import { SEAMLESS } from './constants';
/**
* Adds the seamless node to the graph and connects it to the model loader and denoise node.
@ -19,9 +19,10 @@ export const addGenerationTabSeamless = (
state: RootState,
g: Graph,
denoise: Invocation<'denoise_latents'>,
modelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'>
modelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'>,
vaeLoader: Invocation<'vae_loader'> | null
): Invocation<'seamless'> | null => {
const { seamlessXAxis: seamless_x, seamlessYAxis: seamless_y, vae } = state.generation;
const { seamlessXAxis: seamless_x, seamlessYAxis: seamless_y } = state.generation;
if (!seamless_x && !seamless_y) {
return null;
@ -34,16 +35,6 @@ export const addGenerationTabSeamless = (
seamless_y,
});
// The VAE helper also adds the VAE loader - so we need to check if it's already there
const shouldAddVAELoader = !g.hasNode(VAE_LOADER) && vae;
const vaeLoader = shouldAddVAELoader
? g.addNode({
type: 'vae_loader',
id: VAE_LOADER,
vae_model: vae,
})
: null;
MetadataUtil.add(g, {
seamless_x: seamless_x || undefined,
seamless_y: seamless_y || undefined,

View File

@ -5,7 +5,6 @@ import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetch
import { addGenerationTabControlLayers } from 'features/nodes/util/graph/addGenerationTabControlLayers';
import { addGenerationTabLoRAs } from 'features/nodes/util/graph/addGenerationTabLoRAs';
import { addGenerationTabSeamless } from 'features/nodes/util/graph/addGenerationTabSeamless';
import { addGenerationTabVAE } from 'features/nodes/util/graph/addGenerationTabVAE';
import type { GraphType } from 'features/nodes/util/graph/Graph';
import { Graph } from 'features/nodes/util/graph/Graph';
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
@ -26,6 +25,7 @@ import {
NOISE,
POSITIVE_CONDITIONING,
POSITIVE_CONDITIONING_COLLECT,
VAE_LOADER,
} from './constants';
import { getModelMetadataField } from './metadata';
@ -41,6 +41,7 @@ export const buildGenerationTabGraph2 = async (state: RootState): Promise<GraphT
shouldUseCpuNoise,
vaePrecision,
seed,
vae,
} = state.generation;
const { positivePrompt, negativePrompt } = state.controlLayers.present;
const { width, height } = state.controlLayers.present.size;
@ -106,6 +107,14 @@ export const buildGenerationTabGraph2 = async (state: RootState): Promise<GraphT
is_intermediate: false,
use_cache: false,
});
const vaeLoader =
vae?.base === model.base
? g.addNode({
type: 'vae_loader',
id: VAE_LOADER,
vae_model: vae,
})
: null;
g.addEdge(modelLoader, 'unet', denoise, 'unet');
g.addEdge(modelLoader, 'clip', clipSkip, 'clip');
@ -134,17 +143,20 @@ export const buildGenerationTabGraph2 = async (state: RootState): Promise<GraphT
rand_device: shouldUseCpuNoise ? 'cpu' : 'cuda',
scheduler,
clip_skip: skipped_layers,
vae: vae ?? undefined,
});
MetadataUtil.setMetadataReceivingNode(g, l2i);
g.validate();
const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader);
const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader, vaeLoader);
g.validate();
addGenerationTabVAE(state, g, modelLoader, l2i, i2l, seamless);
g.validate();
addGenerationTabLoRAs(state, g, denoise, seamless ?? modelLoader, clipSkip, posCond, negCond);
addGenerationTabLoRAs(state, g, denoise, modelLoader, seamless, clipSkip, posCond, negCond);
g.validate();
// We might get the VAE from the main model, custom VAE, or seamless node.
const vaeSource = seamless ?? vaeLoader ?? modelLoader;
g.addEdge(vaeSource, 'vae', l2i, 'vae');
const addedLayers = await addGenerationTabControlLayers(
state,
g,
@ -153,7 +165,8 @@ export const buildGenerationTabGraph2 = async (state: RootState): Promise<GraphT
negCond,
posCondCollect,
negCondCollect,
noise
noise,
vaeSource
);
g.validate();