mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): use graph builder for generation tab sdxl
This commit is contained in:
parent
5a4b050e66
commit
5425526d50
@ -2,7 +2,7 @@ import { enqueueRequested } from 'app/store/actions';
|
|||||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||||
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||||
import { buildGenerationTabGraph2 } from 'features/nodes/util/graph/buildGenerationTabGraph2';
|
import { buildGenerationTabGraph2 } from 'features/nodes/util/graph/buildGenerationTabGraph2';
|
||||||
import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/buildGenerationTabSDXLGraph';
|
import { buildGenerationTabSDXLGraph2 } from 'features/nodes/util/graph/buildGenerationTabSDXLGraph2';
|
||||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||||
import { queueApi } from 'services/api/endpoints/queue';
|
import { queueApi } from 'services/api/endpoints/queue';
|
||||||
|
|
||||||
@ -19,7 +19,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
|||||||
let graph;
|
let graph;
|
||||||
|
|
||||||
if (model && model.base === 'sdxl') {
|
if (model && model.base === 'sdxl') {
|
||||||
graph = await buildGenerationTabSDXLGraph(state);
|
graph = await buildGenerationTabSDXLGraph2(state);
|
||||||
} else {
|
} else {
|
||||||
graph = await buildGenerationTabGraph2(state);
|
graph = await buildGenerationTabGraph2(state);
|
||||||
}
|
}
|
||||||
|
@ -118,11 +118,20 @@ export const addGenerationTabControlLayers = async (
|
|||||||
// Connect the conditioning to the collector
|
// Connect the conditioning to the collector
|
||||||
g.addEdge(regionalPosCond, 'conditioning', posCondCollect, 'item');
|
g.addEdge(regionalPosCond, 'conditioning', posCondCollect, 'item');
|
||||||
// Copy the connections to the "global" positive conditioning node to the regional cond
|
// Copy the connections to the "global" positive conditioning node to the regional cond
|
||||||
for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) {
|
if (posCond.type === 'compel') {
|
||||||
// Clone the edge, but change the destination node to the regional conditioning node
|
for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) {
|
||||||
const clone = deepClone(edge);
|
// Clone the edge, but change the destination node to the regional conditioning node
|
||||||
clone.destination.node_id = regionalPosCond.id;
|
const clone = deepClone(edge);
|
||||||
g.addEdgeFromObj(clone);
|
clone.destination.node_id = regionalPosCond.id;
|
||||||
|
g.addEdgeFromObj(clone);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
|
||||||
|
// Clone the edge, but change the destination node to the regional conditioning node
|
||||||
|
const clone = deepClone(edge);
|
||||||
|
clone.destination.node_id = regionalPosCond.id;
|
||||||
|
g.addEdgeFromObj(clone);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -147,11 +156,18 @@ export const addGenerationTabControlLayers = async (
|
|||||||
// Connect the conditioning to the collector
|
// Connect the conditioning to the collector
|
||||||
g.addEdge(regionalNegCond, 'conditioning', negCondCollect, 'item');
|
g.addEdge(regionalNegCond, 'conditioning', negCondCollect, 'item');
|
||||||
// Copy the connections to the "global" negative conditioning node to the regional cond
|
// Copy the connections to the "global" negative conditioning node to the regional cond
|
||||||
for (const edge of g.getEdgesTo(negCond, ['clip', 'mask'])) {
|
if (negCond.type === 'compel') {
|
||||||
// Clone the edge, but change the destination node to the regional conditioning node
|
for (const edge of g.getEdgesTo(negCond, ['clip', 'mask'])) {
|
||||||
const clone = deepClone(edge);
|
const clone = deepClone(edge);
|
||||||
clone.destination.node_id = regionalNegCond.id;
|
clone.destination.node_id = regionalNegCond.id;
|
||||||
g.addEdgeFromObj(clone);
|
g.addEdgeFromObj(clone);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (const edge of g.getEdgesTo(negCond, ['clip', 'clip2', 'mask'])) {
|
||||||
|
const clone = deepClone(edge);
|
||||||
|
clone.destination.node_id = regionalNegCond.id;
|
||||||
|
g.addEdgeFromObj(clone);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -184,11 +200,18 @@ export const addGenerationTabControlLayers = async (
|
|||||||
// Connect the conditioning to the negative collector
|
// Connect the conditioning to the negative collector
|
||||||
g.addEdge(regionalPosCondInverted, 'conditioning', negCondCollect, 'item');
|
g.addEdge(regionalPosCondInverted, 'conditioning', negCondCollect, 'item');
|
||||||
// Copy the connections to the "global" positive conditioning node to our regional node
|
// Copy the connections to the "global" positive conditioning node to our regional node
|
||||||
for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) {
|
if (posCond.type === 'compel') {
|
||||||
// Clone the edge, but change the destination node to the regional conditioning node
|
for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) {
|
||||||
const clone = deepClone(edge);
|
const clone = deepClone(edge);
|
||||||
clone.destination.node_id = regionalPosCondInverted.id;
|
clone.destination.node_id = regionalPosCondInverted.id;
|
||||||
g.addEdgeFromObj(clone);
|
g.addEdgeFromObj(clone);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
|
||||||
|
const clone = deepClone(edge);
|
||||||
|
clone.destination.node_id = regionalPosCondInverted.id;
|
||||||
|
g.addEdgeFromObj(clone);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,75 @@
|
|||||||
|
import type { RootState } from 'app/store/store';
|
||||||
|
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||||
|
import type { Graph } from 'features/nodes/util/graph/Graph';
|
||||||
|
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
||||||
|
import { filter, size } from 'lodash-es';
|
||||||
|
import type { Invocation, S } from 'services/api/types';
|
||||||
|
|
||||||
|
import { LORA_LOADER } from './constants';
|
||||||
|
|
||||||
|
export const addGenerationTabSDXLLoRAs = (
|
||||||
|
state: RootState,
|
||||||
|
g: Graph,
|
||||||
|
denoise: Invocation<'denoise_latents'>,
|
||||||
|
modelLoader: Invocation<'sdxl_model_loader'>,
|
||||||
|
seamless: Invocation<'seamless'> | null,
|
||||||
|
posCond: Invocation<'sdxl_compel_prompt'>,
|
||||||
|
negCond: Invocation<'sdxl_compel_prompt'>
|
||||||
|
): void => {
|
||||||
|
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
|
||||||
|
const loraCount = size(enabledLoRAs);
|
||||||
|
|
||||||
|
if (loraCount === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const loraMetadata: S['LoRAMetadataField'][] = [];
|
||||||
|
|
||||||
|
// We will collect LoRAs into a single collection node, then pass them to the LoRA collection loader, which applies
|
||||||
|
// each LoRA to the UNet and CLIP.
|
||||||
|
const loraCollector = g.addNode({
|
||||||
|
id: `${LORA_LOADER}_collect`,
|
||||||
|
type: 'collect',
|
||||||
|
});
|
||||||
|
const loraCollectionLoader = g.addNode({
|
||||||
|
id: LORA_LOADER,
|
||||||
|
type: 'sdxl_lora_collection_loader',
|
||||||
|
});
|
||||||
|
|
||||||
|
g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras');
|
||||||
|
// Use seamless as UNet input if it exists, otherwise use the model loader
|
||||||
|
g.addEdge(seamless ?? modelLoader, 'unet', loraCollectionLoader, 'unet');
|
||||||
|
g.addEdge(modelLoader, 'clip', loraCollectionLoader, 'clip');
|
||||||
|
g.addEdge(modelLoader, 'clip2', loraCollectionLoader, 'clip2');
|
||||||
|
// Reroute UNet & CLIP connections through the LoRA collection loader
|
||||||
|
g.deleteEdgesTo(denoise, ['unet']);
|
||||||
|
g.deleteEdgesTo(posCond, ['clip', 'clip2']);
|
||||||
|
g.deleteEdgesTo(negCond, ['clip', 'clip2']);
|
||||||
|
g.addEdge(loraCollectionLoader, 'unet', denoise, 'unet');
|
||||||
|
g.addEdge(loraCollectionLoader, 'clip', posCond, 'clip');
|
||||||
|
g.addEdge(loraCollectionLoader, 'clip', negCond, 'clip');
|
||||||
|
g.addEdge(loraCollectionLoader, 'clip2', posCond, 'clip2');
|
||||||
|
g.addEdge(loraCollectionLoader, 'clip2', negCond, 'clip2');
|
||||||
|
|
||||||
|
for (const lora of enabledLoRAs) {
|
||||||
|
const { weight } = lora;
|
||||||
|
const { key } = lora.model;
|
||||||
|
const parsedModel = zModelIdentifierField.parse(lora.model);
|
||||||
|
|
||||||
|
const loraSelector = g.addNode({
|
||||||
|
type: 'lora_selector',
|
||||||
|
id: `${LORA_LOADER}_${key}`,
|
||||||
|
lora: parsedModel,
|
||||||
|
weight,
|
||||||
|
});
|
||||||
|
|
||||||
|
loraMetadata.push({
|
||||||
|
model: parsedModel,
|
||||||
|
weight,
|
||||||
|
});
|
||||||
|
|
||||||
|
g.addEdge(loraSelector, 'lora', loraCollector, 'item');
|
||||||
|
}
|
||||||
|
|
||||||
|
MetadataUtil.add(g, { loras: loraMetadata });
|
||||||
|
};
|
@ -0,0 +1,104 @@
|
|||||||
|
import type { RootState } from 'app/store/store';
|
||||||
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
|
import type { Graph } from 'features/nodes/util/graph/Graph';
|
||||||
|
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
||||||
|
import type { Invocation } from 'services/api/types';
|
||||||
|
import { isRefinerMainModelModelConfig } from 'services/api/types';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
|
import {
|
||||||
|
SDXL_REFINER_DENOISE_LATENTS,
|
||||||
|
SDXL_REFINER_MODEL_LOADER,
|
||||||
|
SDXL_REFINER_NEGATIVE_CONDITIONING,
|
||||||
|
SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||||
|
SDXL_REFINER_SEAMLESS,
|
||||||
|
} from './constants';
|
||||||
|
import { getModelMetadataField } from './metadata';
|
||||||
|
|
||||||
|
export const addGenerationTabSDXLRefiner = async (
|
||||||
|
state: RootState,
|
||||||
|
g: Graph,
|
||||||
|
denoise: Invocation<'denoise_latents'>,
|
||||||
|
modelLoader: Invocation<'sdxl_model_loader'>,
|
||||||
|
seamless: Invocation<'seamless'> | null,
|
||||||
|
posCond: Invocation<'sdxl_compel_prompt'>,
|
||||||
|
negCond: Invocation<'sdxl_compel_prompt'>,
|
||||||
|
l2i: Invocation<'l2i'>
|
||||||
|
): Promise<void> => {
|
||||||
|
const {
|
||||||
|
refinerModel,
|
||||||
|
refinerPositiveAestheticScore,
|
||||||
|
refinerNegativeAestheticScore,
|
||||||
|
refinerSteps,
|
||||||
|
refinerScheduler,
|
||||||
|
refinerCFGScale,
|
||||||
|
refinerStart,
|
||||||
|
} = state.sdxl;
|
||||||
|
|
||||||
|
assert(refinerModel, 'No refiner model found in state');
|
||||||
|
|
||||||
|
const modelConfig = await fetchModelConfigWithTypeGuard(refinerModel.key, isRefinerMainModelModelConfig);
|
||||||
|
|
||||||
|
// We need to re-route latents to the refiner
|
||||||
|
g.deleteEdgesFrom(denoise, ['latents']);
|
||||||
|
// Latents will now come from refiner - delete edges to the l2i VAE decode
|
||||||
|
g.deleteEdgesTo(l2i, ['latents']);
|
||||||
|
|
||||||
|
const refinerModelLoader = g.addNode({
|
||||||
|
type: 'sdxl_refiner_model_loader',
|
||||||
|
id: SDXL_REFINER_MODEL_LOADER,
|
||||||
|
model: refinerModel,
|
||||||
|
});
|
||||||
|
const refinerPosCond = g.addNode({
|
||||||
|
type: 'sdxl_refiner_compel_prompt',
|
||||||
|
id: SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||||
|
style: posCond.style,
|
||||||
|
aesthetic_score: refinerPositiveAestheticScore,
|
||||||
|
});
|
||||||
|
const refinerNegCond = g.addNode({
|
||||||
|
type: 'sdxl_refiner_compel_prompt',
|
||||||
|
id: SDXL_REFINER_NEGATIVE_CONDITIONING,
|
||||||
|
style: negCond.style,
|
||||||
|
aesthetic_score: refinerNegativeAestheticScore,
|
||||||
|
});
|
||||||
|
const refinerDenoise = g.addNode({
|
||||||
|
type: 'denoise_latents',
|
||||||
|
id: SDXL_REFINER_DENOISE_LATENTS,
|
||||||
|
cfg_scale: refinerCFGScale,
|
||||||
|
steps: refinerSteps,
|
||||||
|
scheduler: refinerScheduler,
|
||||||
|
denoising_start: refinerStart,
|
||||||
|
denoising_end: 1,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (seamless) {
|
||||||
|
const refinerSeamless = g.addNode({
|
||||||
|
id: SDXL_REFINER_SEAMLESS,
|
||||||
|
type: 'seamless',
|
||||||
|
seamless_x: seamless.seamless_x,
|
||||||
|
seamless_y: seamless.seamless_y,
|
||||||
|
});
|
||||||
|
g.addEdge(refinerModelLoader, 'unet', refinerSeamless, 'unet');
|
||||||
|
g.addEdge(refinerModelLoader, 'vae', refinerSeamless, 'vae');
|
||||||
|
g.addEdge(refinerSeamless, 'unet', refinerDenoise, 'unet');
|
||||||
|
} else {
|
||||||
|
g.addEdge(refinerModelLoader, 'unet', refinerDenoise, 'unet');
|
||||||
|
}
|
||||||
|
|
||||||
|
g.addEdge(refinerModelLoader, 'clip2', refinerPosCond, 'clip2');
|
||||||
|
g.addEdge(refinerModelLoader, 'clip2', refinerNegCond, 'clip2');
|
||||||
|
g.addEdge(refinerPosCond, 'conditioning', refinerDenoise, 'positive_conditioning');
|
||||||
|
g.addEdge(refinerNegCond, 'conditioning', refinerDenoise, 'negative_conditioning');
|
||||||
|
g.addEdge(denoise, 'latents', refinerDenoise, 'latents');
|
||||||
|
g.addEdge(refinerDenoise, 'latents', l2i, 'latents');
|
||||||
|
|
||||||
|
MetadataUtil.add(g, {
|
||||||
|
refiner_model: getModelMetadataField(modelConfig),
|
||||||
|
refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
|
||||||
|
refiner_negative_aesthetic_score: refinerNegativeAestheticScore,
|
||||||
|
refiner_cfg_scale: refinerCFGScale,
|
||||||
|
refiner_scheduler: refinerScheduler,
|
||||||
|
refiner_start: refinerStart,
|
||||||
|
refiner_steps: refinerSteps,
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,178 @@
|
|||||||
|
import type { RootState } from 'app/store/store';
|
||||||
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
|
import { addGenerationTabControlLayers } from 'features/nodes/util/graph/addGenerationTabControlLayers';
|
||||||
|
import { addGenerationTabNSFWChecker } from 'features/nodes/util/graph/addGenerationTabNSFWChecker';
|
||||||
|
import { addGenerationTabSDXLLoRAs } from 'features/nodes/util/graph/addGenerationTabSDXLLoRAs';
|
||||||
|
import { addGenerationTabSDXLRefiner } from 'features/nodes/util/graph/addGenerationTabSDXLRefiner';
|
||||||
|
import { addGenerationTabSeamless } from 'features/nodes/util/graph/addGenerationTabSeamless';
|
||||||
|
import { addGenerationTabWatermarker } from 'features/nodes/util/graph/addGenerationTabWatermarker';
|
||||||
|
import { Graph } from 'features/nodes/util/graph/Graph';
|
||||||
|
import { MetadataUtil } from 'features/nodes/util/graph/MetadataUtil';
|
||||||
|
import type { Invocation, NonNullableGraph } from 'services/api/types';
|
||||||
|
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
|
import {
|
||||||
|
LATENTS_TO_IMAGE,
|
||||||
|
NEGATIVE_CONDITIONING,
|
||||||
|
NEGATIVE_CONDITIONING_COLLECT,
|
||||||
|
NOISE,
|
||||||
|
POSITIVE_CONDITIONING,
|
||||||
|
POSITIVE_CONDITIONING_COLLECT,
|
||||||
|
SDXL_CONTROL_LAYERS_GRAPH,
|
||||||
|
SDXL_DENOISE_LATENTS,
|
||||||
|
SDXL_MODEL_LOADER,
|
||||||
|
VAE_LOADER,
|
||||||
|
} from './constants';
|
||||||
|
import { getBoardField, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||||
|
import { getModelMetadataField } from './metadata';
|
||||||
|
|
||||||
|
export const buildGenerationTabSDXLGraph2 = async (state: RootState): Promise<NonNullableGraph> => {
|
||||||
|
const {
|
||||||
|
model,
|
||||||
|
cfgScale: cfg_scale,
|
||||||
|
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||||
|
scheduler,
|
||||||
|
seed,
|
||||||
|
steps,
|
||||||
|
shouldUseCpuNoise,
|
||||||
|
vaePrecision,
|
||||||
|
vae,
|
||||||
|
} = state.generation;
|
||||||
|
const { positivePrompt, negativePrompt } = state.controlLayers.present;
|
||||||
|
const { width, height } = state.controlLayers.present.size;
|
||||||
|
|
||||||
|
const { refinerModel, refinerStart } = state.sdxl;
|
||||||
|
|
||||||
|
assert(model, 'No model found in state');
|
||||||
|
|
||||||
|
const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);
|
||||||
|
|
||||||
|
const g = new Graph(SDXL_CONTROL_LAYERS_GRAPH);
|
||||||
|
const modelLoader = g.addNode({
|
||||||
|
type: 'sdxl_model_loader',
|
||||||
|
id: SDXL_MODEL_LOADER,
|
||||||
|
model,
|
||||||
|
});
|
||||||
|
const posCond = g.addNode({
|
||||||
|
type: 'sdxl_compel_prompt',
|
||||||
|
id: POSITIVE_CONDITIONING,
|
||||||
|
prompt: positivePrompt,
|
||||||
|
style: positiveStylePrompt,
|
||||||
|
});
|
||||||
|
const posCondCollect = g.addNode({
|
||||||
|
type: 'collect',
|
||||||
|
id: POSITIVE_CONDITIONING_COLLECT,
|
||||||
|
});
|
||||||
|
const negCond = g.addNode({
|
||||||
|
type: 'sdxl_compel_prompt',
|
||||||
|
id: NEGATIVE_CONDITIONING,
|
||||||
|
prompt: negativePrompt,
|
||||||
|
style: negativeStylePrompt,
|
||||||
|
});
|
||||||
|
const negCondCollect = g.addNode({
|
||||||
|
type: 'collect',
|
||||||
|
id: NEGATIVE_CONDITIONING_COLLECT,
|
||||||
|
});
|
||||||
|
const noise = g.addNode({ type: 'noise', id: NOISE, seed, width, height, use_cpu: shouldUseCpuNoise });
|
||||||
|
const denoise = g.addNode({
|
||||||
|
type: 'denoise_latents',
|
||||||
|
id: SDXL_DENOISE_LATENTS,
|
||||||
|
cfg_scale,
|
||||||
|
cfg_rescale_multiplier,
|
||||||
|
scheduler,
|
||||||
|
steps,
|
||||||
|
denoising_start: 0,
|
||||||
|
denoising_end: refinerModel ? refinerStart : 1,
|
||||||
|
});
|
||||||
|
const l2i = g.addNode({
|
||||||
|
type: 'l2i',
|
||||||
|
id: LATENTS_TO_IMAGE,
|
||||||
|
fp32: vaePrecision === 'fp32',
|
||||||
|
board: getBoardField(state),
|
||||||
|
// This is the terminal node and must always save to gallery.
|
||||||
|
is_intermediate: false,
|
||||||
|
use_cache: false,
|
||||||
|
});
|
||||||
|
const vaeLoader =
|
||||||
|
vae?.base === model.base
|
||||||
|
? g.addNode({
|
||||||
|
type: 'vae_loader',
|
||||||
|
id: VAE_LOADER,
|
||||||
|
vae_model: vae,
|
||||||
|
})
|
||||||
|
: null;
|
||||||
|
|
||||||
|
let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> = l2i;
|
||||||
|
|
||||||
|
g.addEdge(modelLoader, 'unet', denoise, 'unet');
|
||||||
|
g.addEdge(modelLoader, 'clip', posCond, 'clip');
|
||||||
|
g.addEdge(modelLoader, 'clip', negCond, 'clip');
|
||||||
|
g.addEdge(modelLoader, 'clip2', posCond, 'clip2');
|
||||||
|
g.addEdge(modelLoader, 'clip2', negCond, 'clip2');
|
||||||
|
g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
|
||||||
|
g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
|
||||||
|
g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');
|
||||||
|
g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning');
|
||||||
|
g.addEdge(noise, 'noise', denoise, 'noise');
|
||||||
|
g.addEdge(denoise, 'latents', l2i, 'latents');
|
||||||
|
|
||||||
|
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||||
|
|
||||||
|
MetadataUtil.add(g, {
|
||||||
|
generation_mode: 'txt2img',
|
||||||
|
cfg_scale,
|
||||||
|
cfg_rescale_multiplier,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
positive_prompt: positivePrompt,
|
||||||
|
negative_prompt: negativePrompt,
|
||||||
|
model: getModelMetadataField(modelConfig),
|
||||||
|
seed,
|
||||||
|
steps,
|
||||||
|
rand_device: shouldUseCpuNoise ? 'cpu' : 'cuda',
|
||||||
|
scheduler,
|
||||||
|
positive_style_prompt: positiveStylePrompt,
|
||||||
|
negative_style_prompt: negativeStylePrompt,
|
||||||
|
vae: vae ?? undefined,
|
||||||
|
});
|
||||||
|
g.validate();
|
||||||
|
|
||||||
|
const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader, vaeLoader);
|
||||||
|
g.validate();
|
||||||
|
|
||||||
|
addGenerationTabSDXLLoRAs(state, g, denoise, modelLoader, seamless, posCond, negCond);
|
||||||
|
g.validate();
|
||||||
|
|
||||||
|
// We might get the VAE from the main model, custom VAE, or seamless node.
|
||||||
|
const vaeSource = seamless ?? vaeLoader ?? modelLoader;
|
||||||
|
g.addEdge(vaeSource, 'vae', l2i, 'vae');
|
||||||
|
|
||||||
|
// Add Refiner if enabled
|
||||||
|
if (refinerModel) {
|
||||||
|
await addGenerationTabSDXLRefiner(state, g, denoise, modelLoader, seamless, posCond, negCond, l2i);
|
||||||
|
}
|
||||||
|
|
||||||
|
await addGenerationTabControlLayers(
|
||||||
|
state,
|
||||||
|
g,
|
||||||
|
denoise,
|
||||||
|
posCond,
|
||||||
|
negCond,
|
||||||
|
posCondCollect,
|
||||||
|
negCondCollect,
|
||||||
|
noise,
|
||||||
|
vaeSource
|
||||||
|
);
|
||||||
|
|
||||||
|
if (state.system.shouldUseNSFWChecker) {
|
||||||
|
imageOutput = addGenerationTabNSFWChecker(g, imageOutput);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (state.system.shouldUseWatermarker) {
|
||||||
|
imageOutput = addGenerationTabWatermarker(g, imageOutput);
|
||||||
|
}
|
||||||
|
|
||||||
|
MetadataUtil.setMetadataReceivingNode(g, imageOutput);
|
||||||
|
return g.getGraph();
|
||||||
|
};
|
Loading…
Reference in New Issue
Block a user