mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): use graph builder for generation tab sdxl
This commit is contained in:
parent
5a4b050e66
commit
5425526d50
@ -2,7 +2,7 @@ import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { buildGenerationTabGraph2 } from 'features/nodes/util/graph/buildGenerationTabGraph2';
|
||||
import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/buildGenerationTabSDXLGraph';
|
||||
import { buildGenerationTabSDXLGraph2 } from 'features/nodes/util/graph/buildGenerationTabSDXLGraph2';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
|
||||
@ -19,7 +19,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
let graph;
|
||||
|
||||
if (model && model.base === 'sdxl') {
|
||||
graph = await buildGenerationTabSDXLGraph(state);
|
||||
graph = await buildGenerationTabSDXLGraph2(state);
|
||||
} else {
|
||||
graph = await buildGenerationTabGraph2(state);
|
||||
}
|
||||
|
@ -118,11 +118,20 @@ export const addGenerationTabControlLayers = async (
|
||||
// Connect the conditioning to the collector
|
||||
g.addEdge(regionalPosCond, 'conditioning', posCondCollect, 'item');
|
||||
// Copy the connections to the "global" positive conditioning node to the regional cond
|
||||
for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) {
|
||||
// Clone the edge, but change the destination node to the regional conditioning node
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalPosCond.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
if (posCond.type === 'compel') {
|
||||
for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) {
|
||||
// Clone the edge, but change the destination node to the regional conditioning node
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalPosCond.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
} else {
|
||||
for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
|
||||
// Clone the edge, but change the destination node to the regional conditioning node
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalPosCond.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -147,11 +156,18 @@ export const addGenerationTabControlLayers = async (
|
||||
// Connect the conditioning to the collector
|
||||
g.addEdge(regionalNegCond, 'conditioning', negCondCollect, 'item');
|
||||
// Copy the connections to the "global" negative conditioning node to the regional cond
|
||||
for (const edge of g.getEdgesTo(negCond, ['clip', 'mask'])) {
|
||||
// Clone the edge, but change the destination node to the regional conditioning node
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalNegCond.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
if (negCond.type === 'compel') {
|
||||
for (const edge of g.getEdgesTo(negCond, ['clip', 'mask'])) {
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalNegCond.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
} else {
|
||||
for (const edge of g.getEdgesTo(negCond, ['clip', 'clip2', 'mask'])) {
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalNegCond.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -184,11 +200,18 @@ export const addGenerationTabControlLayers = async (
|
||||
// Connect the conditioning to the negative collector
|
||||
g.addEdge(regionalPosCondInverted, 'conditioning', negCondCollect, 'item');
|
||||
// Copy the connections to the "global" positive conditioning node to our regional node
|
||||
for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) {
|
||||
// Clone the edge, but change the destination node to the regional conditioning node
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalPosCondInverted.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
if (posCond.type === 'compel') {
|
||||
for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) {
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalPosCondInverted.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
} else {
|
||||
for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
|
||||
const clone = deepClone(edge);
|
||||
clone.destination.node_id = regionalPosCondInverted.id;
|
||||
g.addEdgeFromObj(clone);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,75 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import type { Graph } from 'features/nodes/util/graph/Graph';
|
||||
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
||||
import { filter, size } from 'lodash-es';
|
||||
import type { Invocation, S } from 'services/api/types';
|
||||
|
||||
import { LORA_LOADER } from './constants';
|
||||
|
||||
export const addGenerationTabSDXLLoRAs = (
|
||||
state: RootState,
|
||||
g: Graph,
|
||||
denoise: Invocation<'denoise_latents'>,
|
||||
modelLoader: Invocation<'sdxl_model_loader'>,
|
||||
seamless: Invocation<'seamless'> | null,
|
||||
posCond: Invocation<'sdxl_compel_prompt'>,
|
||||
negCond: Invocation<'sdxl_compel_prompt'>
|
||||
): void => {
|
||||
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
|
||||
const loraCount = size(enabledLoRAs);
|
||||
|
||||
if (loraCount === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const loraMetadata: S['LoRAMetadataField'][] = [];
|
||||
|
||||
// We will collect LoRAs into a single collection node, then pass them to the LoRA collection loader, which applies
|
||||
// each LoRA to the UNet and CLIP.
|
||||
const loraCollector = g.addNode({
|
||||
id: `${LORA_LOADER}_collect`,
|
||||
type: 'collect',
|
||||
});
|
||||
const loraCollectionLoader = g.addNode({
|
||||
id: LORA_LOADER,
|
||||
type: 'sdxl_lora_collection_loader',
|
||||
});
|
||||
|
||||
g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras');
|
||||
// Use seamless as UNet input if it exists, otherwise use the model loader
|
||||
g.addEdge(seamless ?? modelLoader, 'unet', loraCollectionLoader, 'unet');
|
||||
g.addEdge(modelLoader, 'clip', loraCollectionLoader, 'clip');
|
||||
g.addEdge(modelLoader, 'clip2', loraCollectionLoader, 'clip2');
|
||||
// Reroute UNet & CLIP connections through the LoRA collection loader
|
||||
g.deleteEdgesTo(denoise, ['unet']);
|
||||
g.deleteEdgesTo(posCond, ['clip', 'clip2']);
|
||||
g.deleteEdgesTo(negCond, ['clip', 'clip2']);
|
||||
g.addEdge(loraCollectionLoader, 'unet', denoise, 'unet');
|
||||
g.addEdge(loraCollectionLoader, 'clip', posCond, 'clip');
|
||||
g.addEdge(loraCollectionLoader, 'clip', negCond, 'clip');
|
||||
g.addEdge(loraCollectionLoader, 'clip2', posCond, 'clip2');
|
||||
g.addEdge(loraCollectionLoader, 'clip2', negCond, 'clip2');
|
||||
|
||||
for (const lora of enabledLoRAs) {
|
||||
const { weight } = lora;
|
||||
const { key } = lora.model;
|
||||
const parsedModel = zModelIdentifierField.parse(lora.model);
|
||||
|
||||
const loraSelector = g.addNode({
|
||||
type: 'lora_selector',
|
||||
id: `${LORA_LOADER}_${key}`,
|
||||
lora: parsedModel,
|
||||
weight,
|
||||
});
|
||||
|
||||
loraMetadata.push({
|
||||
model: parsedModel,
|
||||
weight,
|
||||
});
|
||||
|
||||
g.addEdge(loraSelector, 'lora', loraCollector, 'item');
|
||||
}
|
||||
|
||||
MetadataUtil.add(g, { loras: loraMetadata });
|
||||
};
|
@ -0,0 +1,104 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import type { Graph } from 'features/nodes/util/graph/Graph';
|
||||
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import { isRefinerMainModelModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import {
|
||||
SDXL_REFINER_DENOISE_LATENTS,
|
||||
SDXL_REFINER_MODEL_LOADER,
|
||||
SDXL_REFINER_NEGATIVE_CONDITIONING,
|
||||
SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||
SDXL_REFINER_SEAMLESS,
|
||||
} from './constants';
|
||||
import { getModelMetadataField } from './metadata';
|
||||
|
||||
export const addGenerationTabSDXLRefiner = async (
|
||||
state: RootState,
|
||||
g: Graph,
|
||||
denoise: Invocation<'denoise_latents'>,
|
||||
modelLoader: Invocation<'sdxl_model_loader'>,
|
||||
seamless: Invocation<'seamless'> | null,
|
||||
posCond: Invocation<'sdxl_compel_prompt'>,
|
||||
negCond: Invocation<'sdxl_compel_prompt'>,
|
||||
l2i: Invocation<'l2i'>
|
||||
): Promise<void> => {
|
||||
const {
|
||||
refinerModel,
|
||||
refinerPositiveAestheticScore,
|
||||
refinerNegativeAestheticScore,
|
||||
refinerSteps,
|
||||
refinerScheduler,
|
||||
refinerCFGScale,
|
||||
refinerStart,
|
||||
} = state.sdxl;
|
||||
|
||||
assert(refinerModel, 'No refiner model found in state');
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(refinerModel.key, isRefinerMainModelModelConfig);
|
||||
|
||||
// We need to re-route latents to the refiner
|
||||
g.deleteEdgesFrom(denoise, ['latents']);
|
||||
// Latents will now come from refiner - delete edges to the l2i VAE decode
|
||||
g.deleteEdgesTo(l2i, ['latents']);
|
||||
|
||||
const refinerModelLoader = g.addNode({
|
||||
type: 'sdxl_refiner_model_loader',
|
||||
id: SDXL_REFINER_MODEL_LOADER,
|
||||
model: refinerModel,
|
||||
});
|
||||
const refinerPosCond = g.addNode({
|
||||
type: 'sdxl_refiner_compel_prompt',
|
||||
id: SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||
style: posCond.style,
|
||||
aesthetic_score: refinerPositiveAestheticScore,
|
||||
});
|
||||
const refinerNegCond = g.addNode({
|
||||
type: 'sdxl_refiner_compel_prompt',
|
||||
id: SDXL_REFINER_NEGATIVE_CONDITIONING,
|
||||
style: negCond.style,
|
||||
aesthetic_score: refinerNegativeAestheticScore,
|
||||
});
|
||||
const refinerDenoise = g.addNode({
|
||||
type: 'denoise_latents',
|
||||
id: SDXL_REFINER_DENOISE_LATENTS,
|
||||
cfg_scale: refinerCFGScale,
|
||||
steps: refinerSteps,
|
||||
scheduler: refinerScheduler,
|
||||
denoising_start: refinerStart,
|
||||
denoising_end: 1,
|
||||
});
|
||||
|
||||
if (seamless) {
|
||||
const refinerSeamless = g.addNode({
|
||||
id: SDXL_REFINER_SEAMLESS,
|
||||
type: 'seamless',
|
||||
seamless_x: seamless.seamless_x,
|
||||
seamless_y: seamless.seamless_y,
|
||||
});
|
||||
g.addEdge(refinerModelLoader, 'unet', refinerSeamless, 'unet');
|
||||
g.addEdge(refinerModelLoader, 'vae', refinerSeamless, 'vae');
|
||||
g.addEdge(refinerSeamless, 'unet', refinerDenoise, 'unet');
|
||||
} else {
|
||||
g.addEdge(refinerModelLoader, 'unet', refinerDenoise, 'unet');
|
||||
}
|
||||
|
||||
g.addEdge(refinerModelLoader, 'clip2', refinerPosCond, 'clip2');
|
||||
g.addEdge(refinerModelLoader, 'clip2', refinerNegCond, 'clip2');
|
||||
g.addEdge(refinerPosCond, 'conditioning', refinerDenoise, 'positive_conditioning');
|
||||
g.addEdge(refinerNegCond, 'conditioning', refinerDenoise, 'negative_conditioning');
|
||||
g.addEdge(denoise, 'latents', refinerDenoise, 'latents');
|
||||
g.addEdge(refinerDenoise, 'latents', l2i, 'latents');
|
||||
|
||||
MetadataUtil.add(g, {
|
||||
refiner_model: getModelMetadataField(modelConfig),
|
||||
refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
|
||||
refiner_negative_aesthetic_score: refinerNegativeAestheticScore,
|
||||
refiner_cfg_scale: refinerCFGScale,
|
||||
refiner_scheduler: refinerScheduler,
|
||||
refiner_start: refinerStart,
|
||||
refiner_steps: refinerSteps,
|
||||
});
|
||||
};
|
@ -0,0 +1,178 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { addGenerationTabControlLayers } from 'features/nodes/util/graph/addGenerationTabControlLayers';
|
||||
import { addGenerationTabNSFWChecker } from 'features/nodes/util/graph/addGenerationTabNSFWChecker';
|
||||
import { addGenerationTabSDXLLoRAs } from 'features/nodes/util/graph/addGenerationTabSDXLLoRAs';
|
||||
import { addGenerationTabSDXLRefiner } from 'features/nodes/util/graph/addGenerationTabSDXLRefiner';
|
||||
import { addGenerationTabSeamless } from 'features/nodes/util/graph/addGenerationTabSeamless';
|
||||
import { addGenerationTabWatermarker } from 'features/nodes/util/graph/addGenerationTabWatermarker';
|
||||
import { Graph } from 'features/nodes/util/graph/Graph';
|
||||
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
||||
import type { Invocation, NonNullableGraph } from 'services/api/types';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import {
|
||||
LATENTS_TO_IMAGE,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NEGATIVE_CONDITIONING_COLLECT,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
POSITIVE_CONDITIONING_COLLECT,
|
||||
SDXL_CONTROL_LAYERS_GRAPH,
|
||||
SDXL_DENOISE_LATENTS,
|
||||
SDXL_MODEL_LOADER,
|
||||
VAE_LOADER,
|
||||
} from './constants';
|
||||
import { getBoardField, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||
import { getModelMetadataField } from './metadata';
|
||||
|
||||
export const buildGenerationTabSDXLGraph2 = async (state: RootState): Promise<NonNullableGraph> => {
|
||||
const {
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
seed,
|
||||
steps,
|
||||
shouldUseCpuNoise,
|
||||
vaePrecision,
|
||||
vae,
|
||||
} = state.generation;
|
||||
const { positivePrompt, negativePrompt } = state.controlLayers.present;
|
||||
const { width, height } = state.controlLayers.present.size;
|
||||
|
||||
const { refinerModel, refinerStart } = state.sdxl;
|
||||
|
||||
assert(model, 'No model found in state');
|
||||
|
||||
const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);
|
||||
|
||||
const g = new Graph(SDXL_CONTROL_LAYERS_GRAPH);
|
||||
const modelLoader = g.addNode({
|
||||
type: 'sdxl_model_loader',
|
||||
id: SDXL_MODEL_LOADER,
|
||||
model,
|
||||
});
|
||||
const posCond = g.addNode({
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: POSITIVE_CONDITIONING,
|
||||
prompt: positivePrompt,
|
||||
style: positiveStylePrompt,
|
||||
});
|
||||
const posCondCollect = g.addNode({
|
||||
type: 'collect',
|
||||
id: POSITIVE_CONDITIONING_COLLECT,
|
||||
});
|
||||
const negCond = g.addNode({
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: NEGATIVE_CONDITIONING,
|
||||
prompt: negativePrompt,
|
||||
style: negativeStylePrompt,
|
||||
});
|
||||
const negCondCollect = g.addNode({
|
||||
type: 'collect',
|
||||
id: NEGATIVE_CONDITIONING_COLLECT,
|
||||
});
|
||||
const noise = g.addNode({ type: 'noise', id: NOISE, seed, width, height, use_cpu: shouldUseCpuNoise });
|
||||
const denoise = g.addNode({
|
||||
type: 'denoise_latents',
|
||||
id: SDXL_DENOISE_LATENTS,
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
scheduler,
|
||||
steps,
|
||||
denoising_start: 0,
|
||||
denoising_end: refinerModel ? refinerStart : 1,
|
||||
});
|
||||
const l2i = g.addNode({
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
fp32: vaePrecision === 'fp32',
|
||||
board: getBoardField(state),
|
||||
// This is the terminal node and must always save to gallery.
|
||||
is_intermediate: false,
|
||||
use_cache: false,
|
||||
});
|
||||
const vaeLoader =
|
||||
vae?.base === model.base
|
||||
? g.addNode({
|
||||
type: 'vae_loader',
|
||||
id: VAE_LOADER,
|
||||
vae_model: vae,
|
||||
})
|
||||
: null;
|
||||
|
||||
let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> = l2i;
|
||||
|
||||
g.addEdge(modelLoader, 'unet', denoise, 'unet');
|
||||
g.addEdge(modelLoader, 'clip', posCond, 'clip');
|
||||
g.addEdge(modelLoader, 'clip', negCond, 'clip');
|
||||
g.addEdge(modelLoader, 'clip2', posCond, 'clip2');
|
||||
g.addEdge(modelLoader, 'clip2', negCond, 'clip2');
|
||||
g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
|
||||
g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
|
||||
g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');
|
||||
g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning');
|
||||
g.addEdge(noise, 'noise', denoise, 'noise');
|
||||
g.addEdge(denoise, 'latents', l2i, 'latents');
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
|
||||
MetadataUtil.add(g, {
|
||||
generation_mode: 'txt2img',
|
||||
cfg_scale,
|
||||
cfg_rescale_multiplier,
|
||||
height,
|
||||
width,
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
model: getModelMetadataField(modelConfig),
|
||||
seed,
|
||||
steps,
|
||||
rand_device: shouldUseCpuNoise ? 'cpu' : 'cuda',
|
||||
scheduler,
|
||||
positive_style_prompt: positiveStylePrompt,
|
||||
negative_style_prompt: negativeStylePrompt,
|
||||
vae: vae ?? undefined,
|
||||
});
|
||||
g.validate();
|
||||
|
||||
const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader, vaeLoader);
|
||||
g.validate();
|
||||
|
||||
addGenerationTabSDXLLoRAs(state, g, denoise, modelLoader, seamless, posCond, negCond);
|
||||
g.validate();
|
||||
|
||||
// We might get the VAE from the main model, custom VAE, or seamless node.
|
||||
const vaeSource = seamless ?? vaeLoader ?? modelLoader;
|
||||
g.addEdge(vaeSource, 'vae', l2i, 'vae');
|
||||
|
||||
// Add Refiner if enabled
|
||||
if (refinerModel) {
|
||||
await addGenerationTabSDXLRefiner(state, g, denoise, modelLoader, seamless, posCond, negCond, l2i);
|
||||
}
|
||||
|
||||
await addGenerationTabControlLayers(
|
||||
state,
|
||||
g,
|
||||
denoise,
|
||||
posCond,
|
||||
negCond,
|
||||
posCondCollect,
|
||||
negCondCollect,
|
||||
noise,
|
||||
vaeSource
|
||||
);
|
||||
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
imageOutput = addGenerationTabNSFWChecker(g, imageOutput);
|
||||
}
|
||||
|
||||
if (state.system.shouldUseWatermarker) {
|
||||
imageOutput = addGenerationTabWatermarker(g, imageOutput);
|
||||
}
|
||||
|
||||
MetadataUtil.setMetadataReceivingNode(g, imageOutput);
|
||||
return g.getGraph();
|
||||
};
|
Loading…
Reference in New Issue
Block a user