feat(ui): use graph builder for generation tab sdxl

psychedelicious 2024-05-13 20:58:51 +10:00
parent 5a4b050e66
commit 5425526d50
5 changed files with 397 additions and 17 deletions

File: addEnqueueRequestedLinear (enqueue-requested listener)

@@ -2,7 +2,7 @@ import { enqueueRequested } from 'app/store/actions';
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
 import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
 import { buildGenerationTabGraph2 } from 'features/nodes/util/graph/buildGenerationTabGraph2';
-import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/buildGenerationTabSDXLGraph';
+import { buildGenerationTabSDXLGraph2 } from 'features/nodes/util/graph/buildGenerationTabSDXLGraph2';
 import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
 import { queueApi } from 'services/api/endpoints/queue';
@@ -19,7 +19,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
       let graph;

       if (model && model.base === 'sdxl') {
-        graph = await buildGenerationTabSDXLGraph(state);
+        graph = await buildGenerationTabSDXLGraph2(state);
       } else {
         graph = await buildGenerationTabGraph2(state);
       }

File: features/nodes/util/graph/addGenerationTabControlLayers.ts

@@ -118,11 +118,20 @@ export const addGenerationTabControlLayers = async (
       // Connect the conditioning to the collector
       g.addEdge(regionalPosCond, 'conditioning', posCondCollect, 'item');
       // Copy the connections to the "global" positive conditioning node to the regional cond
-      for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) {
-        // Clone the edge, but change the destination node to the regional conditioning node
-        const clone = deepClone(edge);
-        clone.destination.node_id = regionalPosCond.id;
-        g.addEdgeFromObj(clone);
+      if (posCond.type === 'compel') {
+        for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) {
+          // Clone the edge, but change the destination node to the regional conditioning node
+          const clone = deepClone(edge);
+          clone.destination.node_id = regionalPosCond.id;
+          g.addEdgeFromObj(clone);
+        }
+      } else {
+        for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
+          // Clone the edge, but change the destination node to the regional conditioning node
+          const clone = deepClone(edge);
+          clone.destination.node_id = regionalPosCond.id;
+          g.addEdgeFromObj(clone);
+        }
       }
     }
@@ -147,11 +156,18 @@ export const addGenerationTabControlLayers = async (
       // Connect the conditioning to the collector
       g.addEdge(regionalNegCond, 'conditioning', negCondCollect, 'item');
       // Copy the connections to the "global" negative conditioning node to the regional cond
-      for (const edge of g.getEdgesTo(negCond, ['clip', 'mask'])) {
-        // Clone the edge, but change the destination node to the regional conditioning node
-        const clone = deepClone(edge);
-        clone.destination.node_id = regionalNegCond.id;
-        g.addEdgeFromObj(clone);
+      if (negCond.type === 'compel') {
+        for (const edge of g.getEdgesTo(negCond, ['clip', 'mask'])) {
+          const clone = deepClone(edge);
+          clone.destination.node_id = regionalNegCond.id;
+          g.addEdgeFromObj(clone);
+        }
+      } else {
+        for (const edge of g.getEdgesTo(negCond, ['clip', 'clip2', 'mask'])) {
+          const clone = deepClone(edge);
+          clone.destination.node_id = regionalNegCond.id;
+          g.addEdgeFromObj(clone);
+        }
       }
     }
@@ -184,11 +200,18 @@ export const addGenerationTabControlLayers = async (
       // Connect the conditioning to the negative collector
       g.addEdge(regionalPosCondInverted, 'conditioning', negCondCollect, 'item');
       // Copy the connections to the "global" positive conditioning node to our regional node
-      for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) {
-        // Clone the edge, but change the destination node to the regional conditioning node
-        const clone = deepClone(edge);
-        clone.destination.node_id = regionalPosCondInverted.id;
-        g.addEdgeFromObj(clone);
+      if (posCond.type === 'compel') {
+        for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) {
+          const clone = deepClone(edge);
+          clone.destination.node_id = regionalPosCondInverted.id;
+          g.addEdgeFromObj(clone);
+        }
+      } else {
+        for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
+          const clone = deepClone(edge);
+          clone.destination.node_id = regionalPosCondInverted.id;
+          g.addEdgeFromObj(clone);
+        }
       }
     }
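
All three hunks above add the same branch: SDXL prompt nodes carry a 'clip2' connection that must be copied alongside 'clip' and 'mask', while plain 'compel' nodes do not. A hypothetical helper could factor this out. The sketch below is not part of this commit (the name cloneEdgesToRegionalCond is invented) and assumes Graph, Invocation, and deepClone come from the same modules this file already imports, with getEdgesTo and addEdgeFromObj behaving exactly as they are used above.

const cloneEdgesToRegionalCond = (
  g: Graph,
  cond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
  regionalCondId: string
): void => {
  // SDXL prompt nodes have an extra 'clip2' field to copy; non-SDXL ('compel') nodes do not.
  const fields = cond.type === 'compel' ? ['clip', 'mask'] : ['clip', 'clip2', 'mask'];
  for (const edge of g.getEdgesTo(cond, fields)) {
    // Clone the edge, but change the destination node to the regional conditioning node.
    const clone = deepClone(edge);
    clone.destination.node_id = regionalCondId;
    g.addEdgeFromObj(clone);
  }
};

Each call site would then pass the regional node's id (regionalPosCond.id, regionalNegCond.id, or regionalPosCondInverted.id).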

File: features/nodes/util/graph/addGenerationTabSDXLLoRAs.ts (new file)

@@ -0,0 +1,75 @@
import type { RootState } from 'app/store/store';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import { filter, size } from 'lodash-es';
import type { Invocation, S } from 'services/api/types';
import { LORA_LOADER } from './constants';

export const addGenerationTabSDXLLoRAs = (
  state: RootState,
  g: Graph,
  denoise: Invocation<'denoise_latents'>,
  modelLoader: Invocation<'sdxl_model_loader'>,
  seamless: Invocation<'seamless'> | null,
  posCond: Invocation<'sdxl_compel_prompt'>,
  negCond: Invocation<'sdxl_compel_prompt'>
): void => {
  const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
  const loraCount = size(enabledLoRAs);

  if (loraCount === 0) {
    return;
  }

  const loraMetadata: S['LoRAMetadataField'][] = [];

  // We will collect LoRAs into a single collection node, then pass them to the LoRA collection loader, which applies
  // each LoRA to the UNet and CLIP.
  const loraCollector = g.addNode({
    id: `${LORA_LOADER}_collect`,
    type: 'collect',
  });
  const loraCollectionLoader = g.addNode({
    id: LORA_LOADER,
    type: 'sdxl_lora_collection_loader',
  });

  g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras');
  // Use seamless as UNet input if it exists, otherwise use the model loader
  g.addEdge(seamless ?? modelLoader, 'unet', loraCollectionLoader, 'unet');
  g.addEdge(modelLoader, 'clip', loraCollectionLoader, 'clip');
  g.addEdge(modelLoader, 'clip2', loraCollectionLoader, 'clip2');
  // Reroute UNet & CLIP connections through the LoRA collection loader
  g.deleteEdgesTo(denoise, ['unet']);
  g.deleteEdgesTo(posCond, ['clip', 'clip2']);
  g.deleteEdgesTo(negCond, ['clip', 'clip2']);
  g.addEdge(loraCollectionLoader, 'unet', denoise, 'unet');
  g.addEdge(loraCollectionLoader, 'clip', posCond, 'clip');
  g.addEdge(loraCollectionLoader, 'clip', negCond, 'clip');
  g.addEdge(loraCollectionLoader, 'clip2', posCond, 'clip2');
  g.addEdge(loraCollectionLoader, 'clip2', negCond, 'clip2');

  for (const lora of enabledLoRAs) {
    const { weight } = lora;
    const { key } = lora.model;
    const parsedModel = zModelIdentifierField.parse(lora.model);

    const loraSelector = g.addNode({
      type: 'lora_selector',
      id: `${LORA_LOADER}_${key}`,
      lora: parsedModel,
      weight,
    });

    loraMetadata.push({
      model: parsedModel,
      weight,
    });

    g.addEdge(loraSelector, 'lora', loraCollector, 'item');
  }

  MetadataUtil.add(g, { loras: loraMetadata });
};
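
For orientation, a sketch of the sub-graph this helper builds, derived from the code above (edge labels are the field names passed to addEdge; an illustration, not an authoritative diagram):

// lora_selector_<key> --lora--> loraCollector (collect) --collection--> loraCollectionLoader
// (seamless ?? modelLoader) --unet--> loraCollectionLoader --unet--> denoise
// modelLoader --clip, clip2--> loraCollectionLoader --clip, clip2--> posCond and negCond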

File: features/nodes/util/graph/addGenerationTabSDXLRefiner.ts (new file)

@@ -0,0 +1,104 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import type { Invocation } from 'services/api/types';
import { isRefinerMainModelModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import {
  SDXL_REFINER_DENOISE_LATENTS,
  SDXL_REFINER_MODEL_LOADER,
  SDXL_REFINER_NEGATIVE_CONDITIONING,
  SDXL_REFINER_POSITIVE_CONDITIONING,
  SDXL_REFINER_SEAMLESS,
} from './constants';
import { getModelMetadataField } from './metadata';

export const addGenerationTabSDXLRefiner = async (
  state: RootState,
  g: Graph,
  denoise: Invocation<'denoise_latents'>,
  modelLoader: Invocation<'sdxl_model_loader'>,
  seamless: Invocation<'seamless'> | null,
  posCond: Invocation<'sdxl_compel_prompt'>,
  negCond: Invocation<'sdxl_compel_prompt'>,
  l2i: Invocation<'l2i'>
): Promise<void> => {
  const {
    refinerModel,
    refinerPositiveAestheticScore,
    refinerNegativeAestheticScore,
    refinerSteps,
    refinerScheduler,
    refinerCFGScale,
    refinerStart,
  } = state.sdxl;

  assert(refinerModel, 'No refiner model found in state');

  const modelConfig = await fetchModelConfigWithTypeGuard(refinerModel.key, isRefinerMainModelModelConfig);

  // We need to re-route latents to the refiner
  g.deleteEdgesFrom(denoise, ['latents']);
  // Latents will now come from refiner - delete edges to the l2i VAE decode
  g.deleteEdgesTo(l2i, ['latents']);

  const refinerModelLoader = g.addNode({
    type: 'sdxl_refiner_model_loader',
    id: SDXL_REFINER_MODEL_LOADER,
    model: refinerModel,
  });
  const refinerPosCond = g.addNode({
    type: 'sdxl_refiner_compel_prompt',
    id: SDXL_REFINER_POSITIVE_CONDITIONING,
    style: posCond.style,
    aesthetic_score: refinerPositiveAestheticScore,
  });
  const refinerNegCond = g.addNode({
    type: 'sdxl_refiner_compel_prompt',
    id: SDXL_REFINER_NEGATIVE_CONDITIONING,
    style: negCond.style,
    aesthetic_score: refinerNegativeAestheticScore,
  });
  const refinerDenoise = g.addNode({
    type: 'denoise_latents',
    id: SDXL_REFINER_DENOISE_LATENTS,
    cfg_scale: refinerCFGScale,
    steps: refinerSteps,
    scheduler: refinerScheduler,
    denoising_start: refinerStart,
    denoising_end: 1,
  });

  if (seamless) {
    const refinerSeamless = g.addNode({
      id: SDXL_REFINER_SEAMLESS,
      type: 'seamless',
      seamless_x: seamless.seamless_x,
      seamless_y: seamless.seamless_y,
    });
    g.addEdge(refinerModelLoader, 'unet', refinerSeamless, 'unet');
    g.addEdge(refinerModelLoader, 'vae', refinerSeamless, 'vae');
    g.addEdge(refinerSeamless, 'unet', refinerDenoise, 'unet');
  } else {
    g.addEdge(refinerModelLoader, 'unet', refinerDenoise, 'unet');
  }

  g.addEdge(refinerModelLoader, 'clip2', refinerPosCond, 'clip2');
  g.addEdge(refinerModelLoader, 'clip2', refinerNegCond, 'clip2');
  g.addEdge(refinerPosCond, 'conditioning', refinerDenoise, 'positive_conditioning');
  g.addEdge(refinerNegCond, 'conditioning', refinerDenoise, 'negative_conditioning');
  g.addEdge(denoise, 'latents', refinerDenoise, 'latents');
  g.addEdge(refinerDenoise, 'latents', l2i, 'latents');

  MetadataUtil.add(g, {
    refiner_model: getModelMetadataField(modelConfig),
    refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
    refiner_negative_aesthetic_score: refinerNegativeAestheticScore,
    refiner_cfg_scale: refinerCFGScale,
    refiner_scheduler: refinerScheduler,
    refiner_start: refinerStart,
    refiner_steps: refinerSteps,
  });
};
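
A sketch of how the refiner pass splits the denoising schedule with the base pass, derived from the values set here and in buildGenerationTabSDXLGraph2 below (illustration only, not part of this commit):

// base denoise:    denoising_start: 0,            denoising_end: refinerStart
// refiner denoise: denoising_start: refinerStart, denoising_end: 1
// latents: denoise --> refinerDenoise --> l2i  (replacing the denoise --> l2i edge deleted above)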

File: features/nodes/util/graph/buildGenerationTabSDXLGraph2.ts (new file)

@@ -0,0 +1,178 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addGenerationTabControlLayers } from 'features/nodes/util/graph/addGenerationTabControlLayers';
import { addGenerationTabNSFWChecker } from 'features/nodes/util/graph/addGenerationTabNSFWChecker';
import { addGenerationTabSDXLLoRAs } from 'features/nodes/util/graph/addGenerationTabSDXLLoRAs';
import { addGenerationTabSDXLRefiner } from 'features/nodes/util/graph/addGenerationTabSDXLRefiner';
import { addGenerationTabSeamless } from 'features/nodes/util/graph/addGenerationTabSeamless';
import { addGenerationTabWatermarker } from 'features/nodes/util/graph/addGenerationTabWatermarker';
import { Graph } from 'features/nodes/util/graph/Graph';
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
import type { Invocation, NonNullableGraph } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import {
  LATENTS_TO_IMAGE,
  NEGATIVE_CONDITIONING,
  NEGATIVE_CONDITIONING_COLLECT,
  NOISE,
  POSITIVE_CONDITIONING,
  POSITIVE_CONDITIONING_COLLECT,
  SDXL_CONTROL_LAYERS_GRAPH,
  SDXL_DENOISE_LATENTS,
  SDXL_MODEL_LOADER,
  VAE_LOADER,
} from './constants';
import { getBoardField, getSDXLStylePrompts } from './graphBuilderUtils';
import { getModelMetadataField } from './metadata';

export const buildGenerationTabSDXLGraph2 = async (state: RootState): Promise<NonNullableGraph> => {
  const {
    model,
    cfgScale: cfg_scale,
    cfgRescaleMultiplier: cfg_rescale_multiplier,
    scheduler,
    seed,
    steps,
    shouldUseCpuNoise,
    vaePrecision,
    vae,
  } = state.generation;
  const { positivePrompt, negativePrompt } = state.controlLayers.present;
  const { width, height } = state.controlLayers.present.size;
  const { refinerModel, refinerStart } = state.sdxl;

  assert(model, 'No model found in state');

  const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);

  const g = new Graph(SDXL_CONTROL_LAYERS_GRAPH);
  const modelLoader = g.addNode({
    type: 'sdxl_model_loader',
    id: SDXL_MODEL_LOADER,
    model,
  });
  const posCond = g.addNode({
    type: 'sdxl_compel_prompt',
    id: POSITIVE_CONDITIONING,
    prompt: positivePrompt,
    style: positiveStylePrompt,
  });
  const posCondCollect = g.addNode({
    type: 'collect',
    id: POSITIVE_CONDITIONING_COLLECT,
  });
  const negCond = g.addNode({
    type: 'sdxl_compel_prompt',
    id: NEGATIVE_CONDITIONING,
    prompt: negativePrompt,
    style: negativeStylePrompt,
  });
  const negCondCollect = g.addNode({
    type: 'collect',
    id: NEGATIVE_CONDITIONING_COLLECT,
  });
  const noise = g.addNode({ type: 'noise', id: NOISE, seed, width, height, use_cpu: shouldUseCpuNoise });
  const denoise = g.addNode({
    type: 'denoise_latents',
    id: SDXL_DENOISE_LATENTS,
    cfg_scale,
    cfg_rescale_multiplier,
    scheduler,
    steps,
    denoising_start: 0,
    denoising_end: refinerModel ? refinerStart : 1,
  });
  const l2i = g.addNode({
    type: 'l2i',
    id: LATENTS_TO_IMAGE,
    fp32: vaePrecision === 'fp32',
    board: getBoardField(state),
    // This is the terminal node and must always save to gallery.
    is_intermediate: false,
    use_cache: false,
  });
  const vaeLoader =
    vae?.base === model.base
      ? g.addNode({
          type: 'vae_loader',
          id: VAE_LOADER,
          vae_model: vae,
        })
      : null;

  let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> = l2i;

  g.addEdge(modelLoader, 'unet', denoise, 'unet');
  g.addEdge(modelLoader, 'clip', posCond, 'clip');
  g.addEdge(modelLoader, 'clip', negCond, 'clip');
  g.addEdge(modelLoader, 'clip2', posCond, 'clip2');
  g.addEdge(modelLoader, 'clip2', negCond, 'clip2');
  g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
  g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
  g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');
  g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning');
  g.addEdge(noise, 'noise', denoise, 'noise');
  g.addEdge(denoise, 'latents', l2i, 'latents');

  const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);

  MetadataUtil.add(g, {
    generation_mode: 'txt2img',
    cfg_scale,
    cfg_rescale_multiplier,
    height,
    width,
    positive_prompt: positivePrompt,
    negative_prompt: negativePrompt,
    model: getModelMetadataField(modelConfig),
    seed,
    steps,
    rand_device: shouldUseCpuNoise ? 'cpu' : 'cuda',
    scheduler,
    positive_style_prompt: positiveStylePrompt,
    negative_style_prompt: negativeStylePrompt,
    vae: vae ?? undefined,
  });
  g.validate();

  const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader, vaeLoader);
  g.validate();

  addGenerationTabSDXLLoRAs(state, g, denoise, modelLoader, seamless, posCond, negCond);
  g.validate();

  // We might get the VAE from the main model, custom VAE, or seamless node.
  const vaeSource = seamless ?? vaeLoader ?? modelLoader;
  g.addEdge(vaeSource, 'vae', l2i, 'vae');

  // Add Refiner if enabled
  if (refinerModel) {
    await addGenerationTabSDXLRefiner(state, g, denoise, modelLoader, seamless, posCond, negCond, l2i);
  }

  await addGenerationTabControlLayers(
    state,
    g,
    denoise,
    posCond,
    negCond,
    posCondCollect,
    negCondCollect,
    noise,
    vaeSource
  );

  if (state.system.shouldUseNSFWChecker) {
    imageOutput = addGenerationTabNSFWChecker(g, imageOutput);
  }

  if (state.system.shouldUseWatermarker) {
    imageOutput = addGenerationTabWatermarker(g, imageOutput);
  }

  MetadataUtil.setMetadataReceivingNode(g, imageOutput);
  return g.getGraph();
};
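
A sketch of the image-output chain assembled above, derived from this file (illustration only):

// denoise --latents--> l2i
//   l2i --> img_nsfw        (only if state.system.shouldUseNSFWChecker)
//       --> img_watermark   (only if state.system.shouldUseWatermarker)
// MetadataUtil.setMetadataReceivingNode(g, imageOutput) targets whichever node ends this chain,
// so metadata is attached to the final image-producing node.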