feat(ui): txt2img & img2img graphs

This commit is contained in:
psychedelicious 2024-06-24 15:19:51 +10:00
parent 6a4a5ece74
commit a6ca17b19d
9 changed files with 562 additions and 40 deletions

View File

@ -3,9 +3,10 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { getNodeManager } from 'features/controlLayers/konva/nodeManager'; import { getNodeManager } from 'features/controlLayers/konva/nodeManager';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice'; import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildGenerationTabGraph } from 'features/nodes/util/graph/generation/buildGenerationTabGraph'; import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/generation/buildGenerationTabSDXLGraph'; import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import { queueApi } from 'services/api/endpoints/queue'; import { queueApi } from 'services/api/endpoints/queue';
import { assert } from 'tsafe';
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => { export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
startAppListening({ startAppListening({
@ -20,13 +21,15 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
let graph; let graph;
const manager = getNodeManager(); const manager = getNodeManager();
assert(model, 'No model found in state');
const base = model.base;
console.log('generation mode', manager.util.getGenerationMode()); if (base === 'sdxl') {
graph = await buildSDXLGraph(state, manager);
if (model?.base === 'sdxl') { } else if (base === 'sd-1' || base === 'sd-2') {
graph = await buildGenerationTabSDXLGraph(state, manager); graph = await buildSD1Graph(state, manager);
} else { } else {
graph = await buildGenerationTabGraph(state, manager); assert(false, `No graph builders for base ${base}`);
} }
const batchConfig = prepareLinearUIBatch(state, graph, prepend); const batchConfig = prepareLinearUIBatch(state, graph, prepend);

View File

@ -10,7 +10,7 @@ import type { Invocation } from 'services/api/types';
*/ */
export const addNSFWChecker = ( export const addNSFWChecker = (
g: Graph, g: Graph,
imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> | Invocation<'img_resize'>
): Invocation<'img_nsfw'> => { ): Invocation<'img_nsfw'> => {
const nsfw = g.addNode({ const nsfw = g.addNode({
id: NSFW_CHECKER, id: NSFW_CHECKER,

View File

@ -4,7 +4,7 @@ import { LORA_LOADER } from 'features/nodes/util/graph/constants';
import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { Invocation, S } from 'services/api/types'; import type { Invocation, S } from 'services/api/types';
export const addSDXLLoRas = ( export const addSDXLLoRAs = (
state: RootState, state: RootState,
g: Graph, g: Graph,
denoise: Invocation<'denoise_latents'> | Invocation<'tiled_multi_diffusion_denoise_latents'>, denoise: Invocation<'denoise_latents'> | Invocation<'tiled_multi_diffusion_denoise_latents'>,

View File

@ -10,7 +10,7 @@ import type { Invocation } from 'services/api/types';
*/ */
export const addWatermarker = ( export const addWatermarker = (
g: Graph, g: Graph,
imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> | Invocation<'img_resize'>
): Invocation<'img_watermark'> => { ): Invocation<'img_watermark'> => {
const watermark = g.addNode({ const watermark = g.addNode({
id: WATERMARKER, id: WATERMARKER,

View File

@ -16,22 +16,23 @@ import {
import { addControlAdapters } from 'features/nodes/util/graph/generation/addControlAdapters'; import { addControlAdapters } from 'features/nodes/util/graph/generation/addControlAdapters';
import { addIPAdapters } from 'features/nodes/util/graph/generation/addIPAdapters'; import { addIPAdapters } from 'features/nodes/util/graph/generation/addIPAdapters';
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker'; import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
import { addSDXLLoRas } from 'features/nodes/util/graph/generation/addSDXLLoRAs'; import { addSDXLLoRAs } from 'features/nodes/util/graph/generation/addSDXLLoRAs';
import { addSDXLRefiner } from 'features/nodes/util/graph/generation/addSDXLRefiner'; import { addSDXLRefiner } from 'features/nodes/util/graph/generation/addSDXLRefiner';
import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless'; import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker'; import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { getBoardField, getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils'; import { getBoardField, getPresetModifiedPrompts , getSizes } from 'features/nodes/util/graph/graphBuilderUtils';
import type { Invocation, NonNullableGraph } from 'services/api/types'; import type { Invocation, NonNullableGraph } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types'; import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
import { addRegions } from './addRegions'; import { addRegions } from './addRegions';
export const buildGenerationTabSDXLGraph = async ( export const buildImageToImageSDXLGraph = async (
state: RootState, state: RootState,
manager: KonvaNodeManager manager: KonvaNodeManager
): Promise<NonNullableGraph> => { ): Promise<NonNullableGraph> => {
const { bbox, params } = state.canvasV2;
const { const {
model, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
@ -42,17 +43,17 @@ export const buildGenerationTabSDXLGraph = async (
shouldUseCpuNoise, shouldUseCpuNoise,
vaePrecision, vaePrecision,
vae, vae,
positivePrompt,
negativePrompt,
refinerModel, refinerModel,
refinerStart, refinerStart,
img2imgStrength, img2imgStrength,
} = state.canvasV2.params; } = params;
const { width, height } = state.canvasV2.bbox;
assert(model, 'No model found in state'); assert(model, 'No model found in state');
const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } = getPresetModifiedPrompts(state); const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } = getPresetModifiedPrompts(state);
const { originalSize, scaledSize } = getSizes(bbox);
const g = new Graph(SDXL_CONTROL_LAYERS_GRAPH); const g = new Graph(SDXL_CONTROL_LAYERS_GRAPH);
const modelLoader = g.addNode({ const modelLoader = g.addNode({
@ -80,8 +81,14 @@ export const buildGenerationTabSDXLGraph = async (
type: 'collect', type: 'collect',
id: NEGATIVE_CONDITIONING_COLLECT, id: NEGATIVE_CONDITIONING_COLLECT,
}); });
const noise = g.addNode({ type: 'noise', id: NOISE, seed, width, height, use_cpu: shouldUseCpuNoise }); const noise = g.addNode({
const i2l = g.addNode({ type: 'i2l', id: 'i2l' }); type: 'noise',
id: NOISE,
seed,
width: scaledSize.width,
height: scaledSize.height,
use_cpu: shouldUseCpuNoise,
});
const denoise = g.addNode({ const denoise = g.addNode({
type: 'denoise_latents', type: 'denoise_latents',
id: SDXL_DENOISE_LATENTS, id: SDXL_DENOISE_LATENTS,
@ -110,7 +117,8 @@ export const buildGenerationTabSDXLGraph = async (
}) })
: null; : null;
let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> = l2i; let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> | Invocation<'img_resize'> =
l2i;
g.addEdge(modelLoader, 'unet', denoise, 'unet'); g.addEdge(modelLoader, 'unet', denoise, 'unet');
g.addEdge(modelLoader, 'clip', posCond, 'clip'); g.addEdge(modelLoader, 'clip', posCond, 'clip');
@ -122,7 +130,6 @@ export const buildGenerationTabSDXLGraph = async (
g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning'); g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');
g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning'); g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning');
g.addEdge(noise, 'noise', denoise, 'noise'); g.addEdge(noise, 'noise', denoise, 'noise');
g.addEdge(i2l, 'latents', denoise, 'latents');
g.addEdge(denoise, 'latents', l2i, 'latents'); g.addEdge(denoise, 'latents', l2i, 'latents');
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig); const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
@ -132,8 +139,8 @@ export const buildGenerationTabSDXLGraph = async (
generation_mode: 'sdxl_txt2img', generation_mode: 'sdxl_txt2img',
cfg_scale, cfg_scale,
cfg_rescale_multiplier, cfg_rescale_multiplier,
height, width: scaledSize.width,
width, height: scaledSize.height,
positive_prompt: positivePrompt, positive_prompt: positivePrompt,
negative_prompt: negativePrompt, negative_prompt: negativePrompt,
model: Graph.getModelMetadataField(modelConfig), model: Graph.getModelMetadataField(modelConfig),
@ -148,18 +155,19 @@ export const buildGenerationTabSDXLGraph = async (
const seamless = addSeamless(state, g, denoise, modelLoader, vaeLoader); const seamless = addSeamless(state, g, denoise, modelLoader, vaeLoader);
addSDXLLoRas(state, g, denoise, modelLoader, seamless, posCond, negCond); addSDXLLoRAs(state, g, denoise, modelLoader, seamless, posCond, negCond);
// We might get the VAE from the main model, custom VAE, or seamless node. // We might get the VAE from the main model, custom VAE, or seamless node.
const vaeSource = seamless ?? vaeLoader ?? modelLoader; const vaeSource = seamless ?? vaeLoader ?? modelLoader;
g.addEdge(vaeSource, 'vae', l2i, 'vae'); g.addEdge(vaeSource, 'vae', l2i, 'vae');
g.addEdge(vaeSource, 'vae', i2l, 'vae');
// Add Refiner if enabled // Add Refiner if enabled
if (refinerModel) { if (refinerModel) {
await addSDXLRefiner(state, g, denoise, seamless, posCond, negCond, l2i); await addSDXLRefiner(state, g, denoise, seamless, posCond, negCond, l2i);
} }
const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base); const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base);
const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base); const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base);
const _addedRegions = await addRegions( const _addedRegions = await addRegions(
@ -175,9 +183,6 @@ export const buildGenerationTabSDXLGraph = async (
posCondCollect, posCondCollect,
negCondCollect negCondCollect
); );
const { image_name } = await manager.util.getImageSourceImage({ bbox: state.canvasV2.bbox, preview: true });
await manager.util.getInpaintMaskImage({ bbox: state.canvasV2.bbox, preview: true });
i2l.image = { image_name };
if (state.system.shouldUseNSFWChecker) { if (state.system.shouldUseNSFWChecker) {
imageOutput = addNSFWChecker(g, imageOutput); imageOutput = addNSFWChecker(g, imageOutput);

View File

@ -0,0 +1,246 @@
import type { RootState } from 'app/store/store';
import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
CLIP_SKIP,
CONTROL_LAYERS_GRAPH,
DENOISE_LATENTS,
LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING,
NEGATIVE_CONDITIONING_COLLECT,
NOISE,
POSITIVE_CONDITIONING,
POSITIVE_CONDITIONING_COLLECT,
VAE_LOADER,
} from 'features/nodes/util/graph/constants';
import { addControlAdapters } from 'features/nodes/util/graph/generation/addControlAdapters';
// import { addHRF } from 'features/nodes/util/graph/generation/addHRF';
import { addIPAdapters } from 'features/nodes/util/graph/generation/addIPAdapters';
import { addLoRAs } from 'features/nodes/util/graph/generation/addLoRAs';
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import type { GraphType } from 'features/nodes/util/graph/generation/Graph';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { getBoardField, getSizes } from 'features/nodes/util/graph/graphBuilderUtils';
import { isEqual, pick } from 'lodash-es';
import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { addRegions } from './addRegions';
export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager): Promise<GraphType> => {
const generationMode = manager.util.getGenerationMode();
const { bbox, params } = state.canvasV2;
const {
model,
cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler,
steps,
clipSkip: skipped_layers,
shouldUseCpuNoise,
vaePrecision,
seed,
vae,
positivePrompt,
negativePrompt,
img2imgStrength,
} = params;
assert(model, 'No model found in state');
const { originalSize, scaledSize } = getSizes(bbox);
const g = new Graph(CONTROL_LAYERS_GRAPH);
const modelLoader = g.addNode({
type: 'main_model_loader',
id: MAIN_MODEL_LOADER,
model,
});
const clipSkip = g.addNode({
type: 'clip_skip',
id: CLIP_SKIP,
skipped_layers,
});
const posCond = g.addNode({
type: 'compel',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
});
const posCondCollect = g.addNode({
type: 'collect',
id: POSITIVE_CONDITIONING_COLLECT,
});
const negCond = g.addNode({
type: 'compel',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
});
const negCondCollect = g.addNode({
type: 'collect',
id: NEGATIVE_CONDITIONING_COLLECT,
});
const noise = g.addNode({
type: 'noise',
id: NOISE,
seed,
width: scaledSize.width,
height: scaledSize.height,
use_cpu: shouldUseCpuNoise,
});
const denoise = g.addNode({
type: 'denoise_latents',
id: DENOISE_LATENTS,
cfg_scale,
cfg_rescale_multiplier,
scheduler,
steps,
denoising_start: 0,
denoising_end: 1,
});
const l2i = g.addNode({
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32: vaePrecision === 'fp32',
board: getBoardField(state),
// This is the terminal node and must always save to gallery.
is_intermediate: false,
use_cache: false,
});
const vaeLoader =
vae?.base === model.base
? g.addNode({
type: 'vae_loader',
id: VAE_LOADER,
vae_model: vae,
})
: null;
let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> | Invocation<'img_resize'> =
l2i;
g.addEdge(modelLoader, 'unet', denoise, 'unet');
g.addEdge(modelLoader, 'clip', clipSkip, 'clip');
g.addEdge(clipSkip, 'clip', posCond, 'clip');
g.addEdge(clipSkip, 'clip', negCond, 'clip');
g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');
g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning');
g.addEdge(noise, 'noise', denoise, 'noise');
g.addEdge(denoise, 'latents', l2i, 'latents');
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
assert(modelConfig.base === 'sd-1' || modelConfig.base === 'sd-2');
g.upsertMetadata({
generation_mode: 'txt2img',
cfg_scale,
cfg_rescale_multiplier,
width: scaledSize.width,
height: scaledSize.height,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: Graph.getModelMetadataField(modelConfig),
seed,
steps,
rand_device: shouldUseCpuNoise ? 'cpu' : 'cuda',
scheduler,
clip_skip: skipped_layers,
vae: vae ?? undefined,
});
const seamless = addSeamless(state, g, denoise, modelLoader, vaeLoader);
addLoRAs(state, g, denoise, modelLoader, seamless, clipSkip, posCond, negCond);
// We might get the VAE from the main model, custom VAE, or seamless node.
const vaeSource = seamless ?? vaeLoader ?? modelLoader;
g.addEdge(vaeSource, 'vae', l2i, 'vae');
if (generationMode === 'txt2img') {
if (!isEqual(scaledSize, originalSize)) {
// We are using scaled bbox and need to resize the output image back to the original size.
imageOutput = g.addNode({
id: 'img_resize',
type: 'img_resize',
...originalSize,
is_intermediate: false,
use_cache: false,
});
g.addEdge(l2i, 'image', imageOutput, 'image');
}
} else if (generationMode === 'img2img') {
const { image_name } = await manager.util.getImageSourceImage({
bbox: pick(bbox, ['x', 'y', 'width', 'height']),
preview: true,
});
denoise.denoising_start = 1 - img2imgStrength;
if (!isEqual(scaledSize, originalSize)) {
// We are using scaled bbox and need to resize the output image back to the original size.
const initialImageResize = g.addNode({
id: 'initial_image_resize',
type: 'img_resize',
...scaledSize,
image: { image_name },
});
const i2l = g.addNode({ id: 'i2l', type: 'i2l' });
g.addEdge(vaeSource, 'vae', i2l, 'vae');
g.addEdge(initialImageResize, 'image', i2l, 'image');
g.addEdge(i2l, 'latents', denoise, 'latents');
imageOutput = g.addNode({
id: 'img_resize',
type: 'img_resize',
...originalSize,
is_intermediate: false,
use_cache: false,
});
g.addEdge(l2i, 'image', imageOutput, 'image');
} else {
const i2l = g.addNode({ id: 'i2l', type: 'i2l', image: { image_name } });
g.addEdge(vaeSource, 'vae', i2l, 'vae');
g.addEdge(i2l, 'latents', denoise, 'latents');
}
}
const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base);
const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base);
const _addedRegions = await addRegions(
manager,
state.canvasV2.regions.entities,
g,
state.canvasV2.document,
state.canvasV2.bbox,
modelConfig.base,
denoise,
posCond,
negCond,
posCondCollect,
negCondCollect
);
// const isHRFAllowed = !addedLayers.some((l) => isInitialImageLayer(l) || isRegionalGuidanceLayer(l));
// if (isHRFAllowed && state.hrf.hrfEnabled) {
// imageOutput = addHRF(state, g, denoise, noise, l2i, vaeSource);
// }
if (state.system.shouldUseNSFWChecker) {
imageOutput = addNSFWChecker(g, imageOutput);
}
if (state.system.shouldUseWatermarker) {
imageOutput = addWatermarker(g, imageOutput);
}
g.setMetadataReceivingNode(imageOutput);
return g.getGraph();
};

View File

@ -0,0 +1,246 @@
import type { RootState } from 'app/store/store';
import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
LATENTS_TO_IMAGE,
NEGATIVE_CONDITIONING,
NEGATIVE_CONDITIONING_COLLECT,
NOISE,
POSITIVE_CONDITIONING,
POSITIVE_CONDITIONING_COLLECT,
SDXL_CONTROL_LAYERS_GRAPH,
SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER,
VAE_LOADER,
} from 'features/nodes/util/graph/constants';
import { addControlAdapters } from 'features/nodes/util/graph/generation/addControlAdapters';
import { addIPAdapters } from 'features/nodes/util/graph/generation/addIPAdapters';
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
import { addSDXLLoRAs } from 'features/nodes/util/graph/generation/addSDXLLoRAs';
import { addSDXLRefiner } from 'features/nodes/util/graph/generation/addSDXLRefiner';
import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { getBoardField, getSDXLStylePrompts, getSizes } from 'features/nodes/util/graph/graphBuilderUtils';
import { isEqual, pick } from 'lodash-es';
import type { Invocation, NonNullableGraph } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { addRegions } from './addRegions';
export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager): Promise<NonNullableGraph> => {
const generationMode = manager.util.getGenerationMode();
const { bbox, params } = state.canvasV2;
const {
model,
cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler,
seed,
steps,
shouldUseCpuNoise,
vaePrecision,
vae,
positivePrompt,
negativePrompt,
refinerModel,
refinerStart,
img2imgStrength,
} = params;
assert(model, 'No model found in state');
const { originalSize, scaledSize } = getSizes(bbox);
const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);
const g = new Graph(SDXL_CONTROL_LAYERS_GRAPH);
const modelLoader = g.addNode({
type: 'sdxl_model_loader',
id: SDXL_MODEL_LOADER,
model,
});
const posCond = g.addNode({
type: 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
style: positiveStylePrompt,
});
const posCondCollect = g.addNode({
type: 'collect',
id: POSITIVE_CONDITIONING_COLLECT,
});
const negCond = g.addNode({
type: 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
style: negativeStylePrompt,
});
const negCondCollect = g.addNode({
type: 'collect',
id: NEGATIVE_CONDITIONING_COLLECT,
});
const noise = g.addNode({
type: 'noise',
id: NOISE,
seed,
width: scaledSize.width,
height: scaledSize.height,
use_cpu: shouldUseCpuNoise,
});
const denoise = g.addNode({
type: 'denoise_latents',
id: SDXL_DENOISE_LATENTS,
cfg_scale,
cfg_rescale_multiplier,
scheduler,
steps,
denoising_start: 0,
denoising_end: refinerModel ? refinerStart : 1,
});
const l2i = g.addNode({
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32: vaePrecision === 'fp32',
board: getBoardField(state),
// This is the terminal node and must always save to gallery.
is_intermediate: false,
use_cache: false,
});
const vaeLoader =
vae?.base === model.base
? g.addNode({
type: 'vae_loader',
id: VAE_LOADER,
vae_model: vae,
})
: null;
let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> | Invocation<'img_resize'> =
l2i;
g.addEdge(modelLoader, 'unet', denoise, 'unet');
g.addEdge(modelLoader, 'clip', posCond, 'clip');
g.addEdge(modelLoader, 'clip', negCond, 'clip');
g.addEdge(modelLoader, 'clip2', posCond, 'clip2');
g.addEdge(modelLoader, 'clip2', negCond, 'clip2');
g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');
g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning');
g.addEdge(noise, 'noise', denoise, 'noise');
g.addEdge(denoise, 'latents', l2i, 'latents');
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
assert(modelConfig.base === 'sdxl');
g.upsertMetadata({
generation_mode: 'sdxl_txt2img',
cfg_scale,
cfg_rescale_multiplier,
width: scaledSize.width,
height: scaledSize.height,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: Graph.getModelMetadataField(modelConfig),
seed,
steps,
rand_device: shouldUseCpuNoise ? 'cpu' : 'cuda',
scheduler,
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
vae: vae ?? undefined,
});
const seamless = addSeamless(state, g, denoise, modelLoader, vaeLoader);
addSDXLLoRAs(state, g, denoise, modelLoader, seamless, posCond, negCond);
// We might get the VAE from the main model, custom VAE, or seamless node.
const vaeSource = seamless ?? vaeLoader ?? modelLoader;
g.addEdge(vaeSource, 'vae', l2i, 'vae');
// Add Refiner if enabled
if (refinerModel) {
await addSDXLRefiner(state, g, denoise, seamless, posCond, negCond, l2i);
}
if (generationMode === 'txt2img') {
if (!isEqual(scaledSize, originalSize)) {
// We are using scaled bbox and need to resize the output image back to the original size.
imageOutput = g.addNode({
id: 'img_resize',
type: 'img_resize',
...originalSize,
is_intermediate: false,
use_cache: false,
});
g.addEdge(l2i, 'image', imageOutput, 'image');
}
} else if (generationMode === 'img2img') {
denoise.denoising_start = refinerModel ? Math.min(refinerStart, 1 - img2imgStrength) : 1 - img2imgStrength;
const { image_name } = await manager.util.getImageSourceImage({
bbox: pick(bbox, ['x', 'y', 'width', 'height']),
preview: true,
});
if (!isEqual(scaledSize, originalSize)) {
// We are using scaled bbox and need to resize the output image back to the original size.
const initialImageResize = g.addNode({
id: 'initial_image_resize',
type: 'img_resize',
...scaledSize,
image: { image_name },
});
const i2l = g.addNode({ id: 'i2l', type: 'i2l' });
g.addEdge(vaeSource, 'vae', i2l, 'vae');
g.addEdge(initialImageResize, 'image', i2l, 'image');
g.addEdge(i2l, 'latents', denoise, 'latents');
imageOutput = g.addNode({
id: 'img_resize',
type: 'img_resize',
...originalSize,
is_intermediate: false,
use_cache: false,
});
g.addEdge(l2i, 'image', imageOutput, 'image');
} else {
const i2l = g.addNode({ id: 'i2l', type: 'i2l', image: { image_name } });
g.addEdge(vaeSource, 'vae', i2l, 'vae');
g.addEdge(i2l, 'latents', denoise, 'latents');
}
}
const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base);
const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base);
const _addedRegions = await addRegions(
manager,
state.canvasV2.regions.entities,
g,
state.canvasV2.document,
state.canvasV2.bbox,
modelConfig.base,
denoise,
posCond,
negCond,
posCondCollect,
negCondCollect
);
if (state.system.shouldUseNSFWChecker) {
imageOutput = addNSFWChecker(g, imageOutput);
}
if (state.system.shouldUseWatermarker) {
imageOutput = addWatermarker(g, imageOutput);
}
g.setMetadataReceivingNode(imageOutput);
return g.getGraph();
};

View File

@ -23,14 +23,17 @@ import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker'; import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import type { GraphType } from 'features/nodes/util/graph/generation/Graph'; import type { GraphType } from 'features/nodes/util/graph/generation/Graph';
import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { getBoardField, getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils'; import { getBoardField, getPresetModifiedPrompts , getSizes } from 'features/nodes/util/graph/graphBuilderUtils';
import { isEqual } from 'lodash-es';
import type { Invocation } from 'services/api/types'; import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types'; import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
import { addRegions } from './addRegions'; import { addRegions } from './addRegions';
export const buildGenerationTabGraph = async (state: RootState, manager: KonvaNodeManager): Promise<GraphType> => { export const buildTextToImageSD1SD2Graph = async (state: RootState, manager: KonvaNodeManager): Promise<GraphType> => {
const { bbox, params } = state.canvasV2;
const { const {
model, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
@ -42,14 +45,12 @@ export const buildGenerationTabGraph = async (state: RootState, manager: KonvaNo
vaePrecision, vaePrecision,
seed, seed,
vae, vae,
positivePrompt, } = params;
negativePrompt,
} = state.canvasV2.params;
const { width, height } = state.canvasV2.document;
assert(model, 'No model found in state'); assert(model, 'No model found in state');
const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state); const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
const { originalSize, scaledSize } = getSizes(bbox);
const g = new Graph(CONTROL_LAYERS_GRAPH); const g = new Graph(CONTROL_LAYERS_GRAPH);
const modelLoader = g.addNode({ const modelLoader = g.addNode({
@ -84,8 +85,8 @@ export const buildGenerationTabGraph = async (state: RootState, manager: KonvaNo
type: 'noise', type: 'noise',
id: NOISE, id: NOISE,
seed, seed,
width, width: scaledSize.width,
height, height: scaledSize.height,
use_cpu: shouldUseCpuNoise, use_cpu: shouldUseCpuNoise,
}); });
const denoise = g.addNode({ const denoise = g.addNode({
@ -116,7 +117,8 @@ export const buildGenerationTabGraph = async (state: RootState, manager: KonvaNo
}) })
: null; : null;
let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> = l2i; let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> | Invocation<'img_resize'> =
l2i;
g.addEdge(modelLoader, 'unet', denoise, 'unet'); g.addEdge(modelLoader, 'unet', denoise, 'unet');
g.addEdge(modelLoader, 'clip', clipSkip, 'clip'); g.addEdge(modelLoader, 'clip', clipSkip, 'clip');
@ -136,8 +138,8 @@ export const buildGenerationTabGraph = async (state: RootState, manager: KonvaNo
generation_mode: 'txt2img', generation_mode: 'txt2img',
cfg_scale, cfg_scale,
cfg_rescale_multiplier, cfg_rescale_multiplier,
height, width: scaledSize.width,
width, height: scaledSize.height,
positive_prompt: positivePrompt, positive_prompt: positivePrompt,
negative_prompt: negativePrompt, negative_prompt: negativePrompt,
model: Graph.getModelMetadataField(modelConfig), model: Graph.getModelMetadataField(modelConfig),
@ -157,6 +159,18 @@ export const buildGenerationTabGraph = async (state: RootState, manager: KonvaNo
const vaeSource = seamless ?? vaeLoader ?? modelLoader; const vaeSource = seamless ?? vaeLoader ?? modelLoader;
g.addEdge(vaeSource, 'vae', l2i, 'vae'); g.addEdge(vaeSource, 'vae', l2i, 'vae');
if (!isEqual(scaledSize, originalSize)) {
// We are using scaled bbox and need to resize the output image back to the original size.
imageOutput = g.addNode({
id: 'img_resize',
type: 'img_resize',
...originalSize,
is_intermediate: false,
use_cache: false,
});
g.addEdge(l2i, 'image', imageOutput, 'image');
}
const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base); const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base);
const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base); const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base);
const _addedRegions = await addRegions( const _addedRegions = await addRegions(

View File

@ -1,7 +1,9 @@
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import type { CanvasV2State } from 'features/controlLayers/store/types';
import type { BoardField } from 'features/nodes/types/common'; import type { BoardField } from 'features/nodes/types/common';
import { buildPresetModifiedPrompt } from 'features/stylePresets/hooks/usePresetModifiedPrompts'; import { buildPresetModifiedPrompt } from 'features/stylePresets/hooks/usePresetModifiedPrompts';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { pick } from 'lodash-es';
import { stylePresetsApi } from 'services/api/endpoints/stylePresets'; import { stylePresetsApi } from 'services/api/endpoints/stylePresets';
/** /**
@ -22,7 +24,7 @@ export const getPresetModifiedPrompts = (
state: RootState state: RootState
): { positivePrompt: string; negativePrompt: string; positiveStylePrompt?: string; negativeStylePrompt?: string } => { ): { positivePrompt: string; negativePrompt: string; positiveStylePrompt?: string; negativeStylePrompt?: string } => {
const { positivePrompt, negativePrompt, positivePrompt2, negativePrompt2, shouldConcatPrompts } = const { positivePrompt, negativePrompt, positivePrompt2, negativePrompt2, shouldConcatPrompts } =
state.generation; state.canvasV2.params;
const { activeStylePresetId } = state.stylePreset; const { activeStylePresetId } = state.stylePreset;
if (activeStylePresetId) { if (activeStylePresetId) {
@ -68,3 +70,9 @@ export const getIsIntermediate = (state: RootState) => {
} }
return false; return false;
}; };
export const getSizes = (bboxState: CanvasV2State['bbox']) => {
const originalSize = pick(bboxState, 'width', 'height');
const scaledSize = ['auto', 'manual'].includes(bboxState.scaleMethod) ? bboxState.scaledSize : originalSize;
return { originalSize, scaledSize };
};