mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

commit eb320df41d (parent de1869773f)
feat(ui): use new lora loaders, simplify VAE loader, seamless
features/nodes/util/graph/addGenerationTabControlLayers.ts (filename inferred from the import paths below)

@@ -490,29 +490,27 @@ const isValidIPAdapter = (ipa: IPAdapterConfigV2, base: BaseModelType): boolean
 };
 
 const isValidLayer = (layer: Layer, base: BaseModelType) => {
-  if (isControlAdapterLayer(layer)) {
+  if (!layer.isEnabled) {
+    return false;
+  }
+  if (isControlAdapterLayer(layer)) {
     return isValidControlAdapter(layer.controlAdapter, base);
   }
   if (isIPAdapterLayer(layer)) {
-    if (!layer.isEnabled) {
-      return false;
-    }
     return isValidIPAdapter(layer.ipAdapter, base);
   }
   if (isInitialImageLayer(layer)) {
-    if (!layer.isEnabled) {
-      return false;
-    }
     if (!layer.image) {
       return false;
     }
     return true;
   }
   if (isRegionalGuidanceLayer(layer)) {
-    const hasTextPrompt = Boolean(layer.positivePrompt || layer.negativePrompt);
     if (layer.maskObjects.length === 0) {
       // Layer has no mask, meaning any guidance would be applied to an empty region.
       return false;
     }
+    const hasTextPrompt = Boolean(layer.positivePrompt) || Boolean(layer.negativePrompt);
     const hasIPAdapter = layer.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base)).length > 0;
     return hasTextPrompt || hasIPAdapter;
   }
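Note on the hunk above: the `isEnabled` check now runs once at the top of `isValidLayer`, so a disabled layer short-circuits before any per-kind validation, and the regional-guidance branch computes `hasTextPrompt` only after the empty-mask early return. A minimal standalone sketch of that shape (stand-in types, not the repo's `Layer` union):

// Stand-in layer type for illustration only.
type SketchLayer =
  | { kind: 'control'; isEnabled: boolean; hasValidAdapter: boolean }
  | { kind: 'initial_image'; isEnabled: boolean; image: string | null };

const isValidSketchLayer = (layer: SketchLayer): boolean => {
  if (!layer.isEnabled) {
    return false; // hoisted: one early return instead of one per branch
  }
  if (layer.kind === 'control') {
    return layer.hasValidAdapter;
  }
  return layer.image !== null;
};

console.log(isValidSketchLayer({ kind: 'control', isEnabled: false, hasValidAdapter: true })); // false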
@@ -45,7 +45,12 @@ export const addGenerationTabControlLayers = async (
   negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
   posCondCollect: Invocation<'collect'>,
   negCondCollect: Invocation<'collect'>,
-  noise: Invocation<'noise'>
+  noise: Invocation<'noise'>,
+  vaeSource:
+    | Invocation<'seamless'>
+    | Invocation<'vae_loader'>
+    | Invocation<'main_model_loader'>
+    | Invocation<'sdxl_model_loader'>
 ): Promise<Layer[]> => {
   const mainModel = state.generation.model;
   assert(mainModel, 'Missing main model when building graph');
@@ -67,7 +72,7 @@ export const addGenerationTabControlLayers = async (
   const initialImageLayers = validLayers.filter(isInitialImageLayer);
   assert(initialImageLayers.length <= 1, 'Only one initial image layer allowed');
   if (initialImageLayers[0]) {
-    addInitialImageLayerToGraph(state, g, denoise, noise, initialImageLayers[0]);
+    addInitialImageLayerToGraph(state, g, denoise, noise, vaeSource, initialImageLayers[0]);
   }
   // TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
   // the existing conditioning nodes.
@@ -414,6 +419,11 @@ const addInitialImageLayerToGraph = (
   g: Graph,
   denoise: Invocation<'denoise_latents'>,
   noise: Invocation<'noise'>,
+  vaeSource:
+    | Invocation<'seamless'>
+    | Invocation<'vae_loader'>
+    | Invocation<'main_model_loader'>
+    | Invocation<'sdxl_model_loader'>,
   layer: InitialImageLayer
 ) => {
   const { vaePrecision, model } = state.generation;
@@ -438,6 +448,7 @@ const addInitialImageLayerToGraph = (
   });
 
   g.addEdge(i2l, 'latents', denoise, 'latents');
+  g.addEdge(vaeSource, 'vae', i2l, 'vae');
 
   if (layer.image.width !== width || layer.image.height !== height) {
     // The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
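Each member of the new `vaeSource` union (`seamless`, `vae_loader`, `main_model_loader`, `sdxl_model_loader`) exposes a `vae` output, so the image-to-latents node can be fed from whichever node currently owns the VAE. A sketch of that wiring with hypothetical record types (not the repo's `Graph`/`Invocation` machinery):

type SketchVaeSource = {
  id: string;
  type: 'seamless' | 'vae_loader' | 'main_model_loader' | 'sdxl_model_loader';
};
type SketchEdge = { from: [nodeId: string, field: string]; to: [nodeId: string, field: string] };

// Mirrors g.addEdge(vaeSource, 'vae', i2l, 'vae') from the hunk above.
const connectVae = (source: SketchVaeSource, i2lId: string): SketchEdge => ({
  from: [source.id, 'vae'],
  to: [i2lId, 'vae'],
});

console.log(connectVae({ id: 'seamless_1', type: 'seamless' }, 'i2l_1'));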
features/nodes/util/graph/addGenerationTabLoRAs.ts (filename inferred from the import paths below)

@@ -1,11 +1,9 @@
 import type { RootState } from 'app/store/store';
-import { deepClone } from 'common/util/deepClone';
 import { zModelIdentifierField } from 'features/nodes/types/common';
-import { Graph } from 'features/nodes/util/graph/Graph';
+import type { Graph } from 'features/nodes/util/graph/Graph';
 import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
 import { filter, size } from 'lodash-es';
 import type { Invocation, S } from 'services/api/types';
-import { assert } from 'tsafe';
 
 import { LORA_LOADER } from './constants';
@@ -13,19 +11,12 @@ export const addGenerationTabLoRAs = (
   state: RootState,
   g: Graph,
   denoise: Invocation<'denoise_latents'>,
-  unetSource: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'> | Invocation<'seamless'>,
+  modelLoader: Invocation<'main_model_loader'>,
+  seamless: Invocation<'seamless'> | null,
   clipSkip: Invocation<'clip_skip'>,
   posCond: Invocation<'compel'>,
   negCond: Invocation<'compel'>
 ): void => {
-  /**
-   * LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
-   * They then output the UNet and CLIP models references on to either the next LoRA in the chain,
-   * or to the inference/conditioning nodes.
-   *
-   * So we need to inject a LoRA chain into the graph.
-   */
-
   const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
   const loraCount = size(enabledLoRAs);
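The old `unetSource` parameter conflated "model loader or seamless"; the new signature passes both and defers the choice to the single call site that draws the UNet edge (the `seamless ?? modelLoader` in the next hunk). The resolution in isolation, with hypothetical stand-ins:

// Hypothetical stand-ins for the two possible UNet sources.
const modelLoader = { id: 'main_model_loader' };
const seamless: { id: string } | null = null;

// Nullish coalescing resolves the source exactly where the edge is drawn.
const unetSource = seamless ?? modelLoader;
console.log(unetSource.id); // 'main_model_loader'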
@@ -33,30 +24,39 @@ export const addGenerationTabLoRAs = (
     return;
   }
 
-  // Remove modelLoaderNodeId unet connection to feed it to LoRAs
-  console.log(deepClone(g)._graph.edges.map((e) => Graph.edgeToString(e)));
-  g.deleteEdgesFrom(unetSource, 'unet');
-  console.log(deepClone(g)._graph.edges.map((e) => Graph.edgeToString(e)));
-  if (clipSkip) {
-    // Remove CLIP_SKIP connections to conditionings to feed it through LoRAs
-    g.deleteEdgesFrom(clipSkip, 'clip');
-  }
-  console.log(deepClone(g)._graph.edges.map((e) => Graph.edgeToString(e)));
-
-  // we need to remember the last lora so we can chain from it
-  let lastLoRALoader: Invocation<'lora_loader'> | null = null;
-  let currentLoraIndex = 0;
   const loraMetadata: S['LoRAMetadataField'][] = [];
 
+  // We will collect LoRAs into a single collection node, then pass them to the LoRA collection loader, which applies
+  // each LoRA to the UNet and CLIP.
+  const loraCollector = g.addNode({
+    id: `${LORA_LOADER}_collect`,
+    type: 'collect',
+  });
+  const loraCollectionLoader = g.addNode({
+    id: LORA_LOADER,
+    type: 'lora_collection_loader',
+  });
+
+  g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras');
+  // Use seamless as UNet input if it exists, otherwise use the model loader
+  g.addEdge(seamless ?? modelLoader, 'unet', loraCollectionLoader, 'unet');
+  g.addEdge(clipSkip, 'clip', loraCollectionLoader, 'clip');
+  // Reroute UNet & CLIP connections through the LoRA collection loader
+  g.deleteEdgesTo(denoise, 'unet');
+  g.deleteEdgesTo(posCond, 'clip');
+  g.deleteEdgesTo(negCond, 'clip');
+  g.addEdge(loraCollectionLoader, 'unet', denoise, 'unet');
+  g.addEdge(loraCollectionLoader, 'clip', posCond, 'clip');
+  g.addEdge(loraCollectionLoader, 'clip', negCond, 'clip');
+
   for (const lora of enabledLoRAs) {
     const { weight } = lora;
     const { key } = lora.model;
-    const currentLoraNodeId = `${LORA_LOADER}_${key}`;
     const parsedModel = zModelIdentifierField.parse(lora.model);
 
-    const currentLoRALoader = g.addNode({
-      type: 'lora_loader',
-      id: currentLoraNodeId,
+    const loraSelector = g.addNode({
+      type: 'lora_selector',
+      id: `${LORA_LOADER}_${key}`,
+      lora: parsedModel,
       weight,
     });
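This replaces the per-LoRA chain with a flat fan-in: one `lora_selector` node per LoRA feeds a single `collect` node, and one `lora_collection_loader` applies the whole collection to the UNet and CLIP. A minimal sketch of the resulting topology (plain records, not the repo's `Graph` class):

type SketchNode = { id: string; type: string };
type SketchEdge = { from: [string, string]; to: [string, string] };

const nodes: SketchNode[] = [
  { id: 'lora_collect', type: 'collect' },
  { id: 'lora_collection_loader', type: 'lora_collection_loader' },
];
const edges: SketchEdge[] = [
  // Established once: collector output -> collection loader input.
  { from: ['lora_collect', 'collection'], to: ['lora_collection_loader', 'loras'] },
];

// Per enabled LoRA: one selector node and one edge into the collector.
for (const key of ['lora-a', 'lora-b', 'lora-c']) {
  nodes.push({ id: `lora_selector_${key}`, type: 'lora_selector' });
  edges.push({ from: [`lora_selector_${key}`, 'lora'], to: ['lora_collect', 'item'] });
}

// n LoRAs => n + 2 nodes and n + 1 edges, with no rewiring as n grows.
console.log(nodes.length, edges.length); // 5 4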
@@ -66,28 +66,7 @@ export const addGenerationTabLoRAs = (
       weight,
     });
 
-    // add to graph
-    if (currentLoraIndex === 0) {
-      // first lora = start the lora chain, attach directly to model loader
-      g.addEdge(unetSource, 'unet', currentLoRALoader, 'unet');
-      g.addEdge(clipSkip, 'clip', currentLoRALoader, 'clip');
-    } else {
-      assert(lastLoRALoader !== null);
-      // we are in the middle of the lora chain, instead connect to the previous lora
-      g.addEdge(lastLoRALoader, 'unet', currentLoRALoader, 'unet');
-      g.addEdge(lastLoRALoader, 'clip', currentLoRALoader, 'clip');
-    }
-
-    if (currentLoraIndex === loraCount - 1) {
-      // final lora, end the lora chain - we need to connect up to inference and conditioning nodes
-      g.addEdge(currentLoRALoader, 'unet', denoise, 'unet');
-      g.addEdge(currentLoRALoader, 'clip', posCond, 'clip');
-      g.addEdge(currentLoRALoader, 'clip', negCond, 'clip');
-    }
-
-    // increment the lora for the next one in the chain
-    lastLoRALoader = currentLoRALoader;
-    currentLoraIndex += 1;
+    g.addEdge(loraSelector, 'lora', loraCollector, 'item');
  }
 
   MetadataUtil.add(g, { loras: loraMetadata });
features/nodes/util/graph/addGenerationTabSeamless.ts (filename inferred from the import paths below)

@@ -3,7 +3,7 @@ import type { Graph } from 'features/nodes/util/graph/Graph';
 import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
 import type { Invocation } from 'services/api/types';
 
-import { SEAMLESS, VAE_LOADER } from './constants';
+import { SEAMLESS } from './constants';
 
 /**
  * Adds the seamless node to the graph and connects it to the model loader and denoise node.
@@ -19,9 +19,10 @@ export const addGenerationTabSeamless = (
   state: RootState,
   g: Graph,
   denoise: Invocation<'denoise_latents'>,
-  modelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'>
+  modelLoader: Invocation<'main_model_loader'> | Invocation<'sdxl_model_loader'>,
+  vaeLoader: Invocation<'vae_loader'> | null
 ): Invocation<'seamless'> | null => {
-  const { seamlessXAxis: seamless_x, seamlessYAxis: seamless_y, vae } = state.generation;
+  const { seamlessXAxis: seamless_x, seamlessYAxis: seamless_y } = state.generation;
 
   if (!seamless_x && !seamless_y) {
     return null;
@@ -34,16 +35,6 @@ export const addGenerationTabSeamless = (
     seamless_y,
   });
 
-  // The VAE helper also adds the VAE loader - so we need to check if it's already there
-  const shouldAddVAELoader = !g.hasNode(VAE_LOADER) && vae;
-  const vaeLoader = shouldAddVAELoader
-    ? g.addNode({
-        type: 'vae_loader',
-        id: VAE_LOADER,
-        vae_model: vae,
-      })
-    : null;
-
   MetadataUtil.add(g, {
     seamless_x: seamless_x || undefined,
     seamless_y: seamless_y || undefined,
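After this removal the seamless helper no longer creates (or checks for) the VAE loader; the caller builds it once and passes it in, which removes the `hasNode(VAE_LOADER)` ordering dependency between the two helpers. A hedged sketch of the new contract with simplified signatures (not the repo's real ones):

type SketchNode = { id: string; type: string };

// The caller owns VAE-loader creation (or passes null)...
const buildVaeLoader = (hasCompatibleCustomVae: boolean): SketchNode | null =>
  hasCompatibleCustomVae ? { id: 'vae_loader', type: 'vae_loader' } : null;

// ...and the seamless helper only consumes it, returning null when seamless is off.
const addSeamlessSketch = (
  seamlessEnabled: boolean,
  _vaeLoader: SketchNode | null // received, not created, by the helper
): SketchNode | null => (seamlessEnabled ? { id: 'seamless', type: 'seamless' } : null);

const vaeLoader = buildVaeLoader(true);
console.log(addSeamlessSketch(false, vaeLoader)); // null — no seamless node added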
features/nodes/util/graph/buildGenerationTabGraph2.ts (filename inferred from the exported builder)

@@ -5,7 +5,6 @@ import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetch
 import { addGenerationTabControlLayers } from 'features/nodes/util/graph/addGenerationTabControlLayers';
 import { addGenerationTabLoRAs } from 'features/nodes/util/graph/addGenerationTabLoRAs';
 import { addGenerationTabSeamless } from 'features/nodes/util/graph/addGenerationTabSeamless';
-import { addGenerationTabVAE } from 'features/nodes/util/graph/addGenerationTabVAE';
 import type { GraphType } from 'features/nodes/util/graph/Graph';
 import { Graph } from 'features/nodes/util/graph/Graph';
 import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
@@ -26,6 +25,7 @@ import {
   NOISE,
   POSITIVE_CONDITIONING,
   POSITIVE_CONDITIONING_COLLECT,
+  VAE_LOADER,
 } from './constants';
 import { getModelMetadataField } from './metadata';
 
@@ -41,6 +41,7 @@ export const buildGenerationTabGraph2 = async (state: RootState): Promise<GraphT
     shouldUseCpuNoise,
     vaePrecision,
     seed,
+    vae,
   } = state.generation;
   const { positivePrompt, negativePrompt } = state.controlLayers.present;
   const { width, height } = state.controlLayers.present.size;
@@ -106,6 +107,14 @@ export const buildGenerationTabGraph2 = async (state: RootState): Promise<GraphT
     is_intermediate: false,
     use_cache: false,
   });
+  const vaeLoader =
+    vae?.base === model.base
+      ? g.addNode({
+          type: 'vae_loader',
+          id: VAE_LOADER,
+          vae_model: vae,
+        })
+      : null;
 
   g.addEdge(modelLoader, 'unet', denoise, 'unet');
   g.addEdge(modelLoader, 'clip', clipSkip, 'clip');
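The guard `vae?.base === model.base` means a custom VAE loader is added only when the selected VAE shares the main model's base; on a mismatch (or with no custom VAE selected) `vaeLoader` stays null and the graph falls back to the model's bundled VAE. The predicate in isolation, with stand-in base values:

type SketchBase = 'sd-1' | 'sd-2' | 'sdxl';

const shouldAddVaeLoader = (vae: { base: SketchBase } | null, model: { base: SketchBase }): boolean =>
  vae?.base === model.base;

console.log(shouldAddVaeLoader({ base: 'sd-1' }, { base: 'sdxl' })); // false — mismatched base
console.log(shouldAddVaeLoader({ base: 'sdxl' }, { base: 'sdxl' })); // true
console.log(shouldAddVaeLoader(null, { base: 'sdxl' })); // false — no custom VAE selected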
@@ -134,17 +143,20 @@ export const buildGenerationTabGraph2 = async (state: RootState): Promise<GraphT
     rand_device: shouldUseCpuNoise ? 'cpu' : 'cuda',
     scheduler,
     clip_skip: skipped_layers,
-    vae: vae ?? undefined,
   });
   MetadataUtil.setMetadataReceivingNode(g, l2i);
   g.validate();
 
-  const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader);
+  const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader, vaeLoader);
   g.validate();
-  addGenerationTabVAE(state, g, modelLoader, l2i, i2l, seamless);
-  g.validate();
-  addGenerationTabLoRAs(state, g, denoise, seamless ?? modelLoader, clipSkip, posCond, negCond);
+  addGenerationTabLoRAs(state, g, denoise, modelLoader, seamless, clipSkip, posCond, negCond);
   g.validate();
 
+  // We might get the VAE from the main model, custom VAE, or seamless node.
+  const vaeSource = seamless ?? vaeLoader ?? modelLoader;
+  g.addEdge(vaeSource, 'vae', l2i, 'vae');
+
   const addedLayers = await addGenerationTabControlLayers(
     state,
     g,
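The `seamless ?? vaeLoader ?? modelLoader` fallback encodes the VAE precedence in one expression: the seamless node wins when present, then the custom VAE loader, then the main model's own VAE. The nullish-coalescing chain in isolation:

const pickVaeSource = <T>(
  seamless: T | null,
  vaeLoader: T | null,
  modelLoader: T
): T => seamless ?? vaeLoader ?? modelLoader;

console.log(pickVaeSource(null, null, 'main_model_loader')); // 'main_model_loader'
console.log(pickVaeSource(null, 'vae_loader', 'main_model_loader')); // 'vae_loader'
console.log(pickVaeSource('seamless', 'vae_loader', 'main_model_loader')); // 'seamless'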
@@ -153,7 +165,8 @@ export const buildGenerationTabGraph2 = async (state: RootState): Promise<GraphT
     negCond,
     posCondCollect,
     negCondCollect,
-    noise
+    noise,
+    vaeSource
   );
   g.validate();