feat(wip): Add SDXL To Canvas

This commit is contained in:
blessedcoolant 2023-08-12 08:16:05 +12:00
parent f343ab0302
commit 7293a6036a
8 changed files with 1236 additions and 11 deletions

View File

@ -12,7 +12,10 @@ export const addTabChangedListener = () => {
if (activeTabName === 'unifiedCanvas') { if (activeTabName === 'unifiedCanvas') {
const currentBaseModel = getState().generation.model?.base_model; const currentBaseModel = getState().generation.model?.base_model;
if (currentBaseModel && ['sd-1', 'sd-2'].includes(currentBaseModel)) { if (
currentBaseModel &&
['sd-1', 'sd-2', 'sdxl'].includes(currentBaseModel)
) {
// if we're already on a valid model, no change needed // if we're already on a valid model, no change needed
return; return;
} }
@ -36,7 +39,9 @@ export const addTabChangedListener = () => {
const validCanvasModels = mainModelsAdapter const validCanvasModels = mainModelsAdapter
.getSelectors() .getSelectors()
.selectAll(models) .selectAll(models)
.filter((model) => ['sd-1', 'sd-2'].includes(model.base_model)); .filter((model) =>
['sd-1', 'sd-2', 'sxdl'].includes(model.base_model)
);
const firstValidCanvasModel = validCanvasModels[0]; const firstValidCanvasModel = validCanvasModels[0];

View File

@ -3,6 +3,9 @@ import { NonNullableGraph } from 'features/nodes/types/types';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import { buildCanvasImageToImageGraph } from './buildCanvasImageToImageGraph'; import { buildCanvasImageToImageGraph } from './buildCanvasImageToImageGraph';
import { buildCanvasInpaintGraph } from './buildCanvasInpaintGraph'; import { buildCanvasInpaintGraph } from './buildCanvasInpaintGraph';
import { buildCanvasSDXLImageToImageGraph } from './buildCanvasSDXLImageToImageGraph';
import { buildCanvasSDXLInpaintGraph } from './buildCanvasSDXLInpaintGraph';
import { buildCanvasSDXLTextToImageGraph } from './buildCanvasSDXLTextToImageGraph';
import { buildCanvasTextToImageGraph } from './buildCanvasTextToImageGraph'; import { buildCanvasTextToImageGraph } from './buildCanvasTextToImageGraph';
export const buildCanvasGraph = ( export const buildCanvasGraph = (
@ -14,18 +17,44 @@ export const buildCanvasGraph = (
let graph: NonNullableGraph; let graph: NonNullableGraph;
if (generationMode === 'txt2img') { if (generationMode === 'txt2img') {
if (
state.generation.model &&
state.generation.model.base_model === 'sdxl'
) {
graph = buildCanvasSDXLTextToImageGraph(state);
} else {
graph = buildCanvasTextToImageGraph(state); graph = buildCanvasTextToImageGraph(state);
}
} else if (generationMode === 'img2img') { } else if (generationMode === 'img2img') {
if (!canvasInitImage) { if (!canvasInitImage) {
throw new Error('Missing canvas init image'); throw new Error('Missing canvas init image');
} }
if (
state.generation.model &&
state.generation.model.base_model === 'sdxl'
) {
graph = buildCanvasSDXLImageToImageGraph(state, canvasInitImage);
} else {
graph = buildCanvasImageToImageGraph(state, canvasInitImage); graph = buildCanvasImageToImageGraph(state, canvasInitImage);
}
} else { } else {
if (!canvasInitImage || !canvasMaskImage) { if (!canvasInitImage || !canvasMaskImage) {
throw new Error('Missing canvas init and mask images'); throw new Error('Missing canvas init and mask images');
} }
if (
state.generation.model &&
state.generation.model.base_model === 'sdxl'
) {
graph = buildCanvasSDXLInpaintGraph(
state,
canvasInitImage,
canvasMaskImage
);
} else {
graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage); graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
} }
}
return graph; return graph;
}; };

View File

@ -0,0 +1,373 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import {
ImageDTO,
ImageResizeInvocation,
ImageToLatentsInvocation,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
DENOISE_LATENTS,
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
LATENTS_TO_IMAGE,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
RESIZE,
SDXL_MODEL_LOADER,
} from './constants';
/**
* Builds the Canvas tab's Image to Image graph.
*/
export const buildCanvasSDXLImageToImageGraph = (
state: RootState,
initialImage: ImageDTO
): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
negativePrompt,
model,
cfgScale: cfg_scale,
scheduler,
steps,
clipSkip,
shouldUseCpuNoise,
shouldUseNoiseSettings,
} = state.generation;
const {
positiveStylePrompt,
negativeStylePrompt,
shouldConcatSDXLStylePrompt,
shouldUseSDXLRefiner,
refinerStart,
sdxlImg2ImgDenoisingStrength: strength,
} = state.sdxl;
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
const { shouldAutoSave } = state.canvas;
if (!model) {
log.error('No model found in state');
throw new Error('No model found in state');
}
const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise;
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
* ids.
*
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
* the `fit` param. These are added to the graph at the end.
*/
// copy-pasted graph from node editor, filled in with state values & friendly node ids
const graph: NonNullableGraph = {
id: IMAGE_TO_IMAGE_GRAPH,
nodes: {
[SDXL_MODEL_LOADER]: {
type: 'sdxl_model_loader',
id: SDXL_MODEL_LOADER,
model,
},
[POSITIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
style: shouldConcatSDXLStylePrompt
? `${positivePrompt} ${positiveStylePrompt}`
: positiveStylePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
style: shouldConcatSDXLStylePrompt
? `${negativePrompt} ${negativeStylePrompt}`
: negativeStylePrompt,
},
[NOISE]: {
type: 'noise',
id: NOISE,
is_intermediate: true,
use_cpu,
},
[DENOISE_LATENTS]: {
type: 'denoise_latents',
id: DENOISE_LATENTS,
is_intermediate: true,
cfg_scale,
scheduler,
steps,
denoising_start: shouldUseSDXLRefiner
? Math.min(refinerStart, 1 - strength)
: 1 - strength,
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
},
[IMAGE_TO_LATENTS]: {
type: 'i2l',
id: IMAGE_TO_LATENTS,
is_intermediate: true,
// must be set manually later, bc `fit` parameter may require a resize node inserted
// image: {
// image_name: initialImage.image_name,
// },
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
is_intermediate: !shouldAutoSave,
},
},
edges: [
{
source: {
node_id: DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
},
},
{
source: {
node_id: IMAGE_TO_LATENTS,
field: 'latents',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'latents',
},
},
{
source: {
node_id: NOISE,
field: 'noise',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'noise',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'negative_conditioning',
},
},
{
source: {
node_id: POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'positive_conditioning',
},
},
],
};
// handle `fit`
if (initialImage.width !== width || initialImage.height !== height) {
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
// Create a resize node, explicitly setting its image
const resizeNode: ImageResizeInvocation = {
id: RESIZE,
type: 'img_resize',
image: {
image_name: initialImage.image_name,
},
is_intermediate: true,
width,
height,
};
graph.nodes[RESIZE] = resizeNode;
// The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
graph.edges.push({
source: { node_id: RESIZE, field: 'image' },
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'image',
},
});
// The `RESIZE` node also passes its width and height to `NOISE`
graph.edges.push({
source: { node_id: RESIZE, field: 'width' },
destination: {
node_id: NOISE,
field: 'width',
},
});
graph.edges.push({
source: { node_id: RESIZE, field: 'height' },
destination: {
node_id: NOISE,
field: 'height',
},
});
} else {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = {
image_name: initialImage.image_name,
};
// Pass the image's dimensions to the `NOISE` node
graph.edges.push({
source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
destination: {
node_id: NOISE,
field: 'width',
},
});
graph.edges.push({
source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
destination: {
node_id: NOISE,
field: 'height',
},
});
}
// add metadata accumulator, which is only mostly populated - some fields are added later
graph.nodes[METADATA_ACCUMULATOR] = {
id: METADATA_ACCUMULATOR,
type: 'metadata_accumulator',
generation_mode: 'img2img',
cfg_scale,
height,
width,
positive_prompt: '', // set in addDynamicPromptsToGraph
negative_prompt: negativePrompt,
model,
seed: 0, // set in addDynamicPromptsToGraph
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
scheduler,
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
clip_skip: clipSkip,
strength,
init_image: initialImage.image_name,
};
graph.edges.push({
source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'metadata',
},
});
// add LoRA support
addLoRAsToGraph(state, graph, DENOISE_LATENTS);
// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, DENOISE_LATENTS);
}
// optionally add custom VAE
addVAEToGraph(state, graph, SDXL_MODEL_LOADER);
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph);
}
return graph;
};

View File

@ -0,0 +1,480 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
ImageDTO,
InfillPatchmatchInvocation,
InfillTileInvocation,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
COLOR_CORRECT,
INPAINT,
INPAINT_FINAL_IMAGE,
INPAINT_GRAPH,
INPAINT_IMAGE,
INPAINT_INFILL,
ITERATE,
LATENTS_TO_IMAGE,
MASK_BLUR,
MASK_COMBINE,
MASK_FROM_ALPHA,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
SDXL_MODEL_LOADER,
} from './constants';
/**
* Builds the Canvas tab's Inpaint graph.
*/
export const buildCanvasSDXLInpaintGraph = (
state: RootState,
canvasInitImage: ImageDTO,
canvasMaskImage: ImageDTO
): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
negativePrompt,
model,
cfgScale: cfg_scale,
scheduler,
steps,
img2imgStrength: strength,
shouldFitToWidthHeight,
iterations,
seed,
shouldRandomizeSeed,
vaePrecision,
shouldUseNoiseSettings,
shouldUseCpuNoise,
maskBlur,
maskBlurMethod,
tileSize,
infillMethod,
} = state.generation;
const {
positiveStylePrompt,
negativeStylePrompt,
shouldConcatSDXLStylePrompt,
shouldUseSDXLRefiner,
refinerStart,
} = state.sdxl;
if (!model) {
log.error('No model found in state');
throw new Error('No model found in state');
}
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
// We may need to set the inpaint width and height to scale the image
const {
scaledBoundingBoxDimensions,
boundingBoxScaleMethod,
shouldAutoSave,
} = state.canvas;
const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise
: shouldUseCpuNoise;
let infillNode: InfillTileInvocation | InfillPatchmatchInvocation = {
type: 'infill_tile',
id: INPAINT_INFILL,
is_intermediate: true,
image: canvasInitImage,
tile_size: tileSize,
};
if (infillMethod === 'patchmatch') {
infillNode = {
type: 'infill_patchmatch',
id: INPAINT_INFILL,
is_intermediate: true,
image: canvasInitImage,
};
}
const graph: NonNullableGraph = {
id: INPAINT_GRAPH,
nodes: {
[INPAINT]: {
type: 'denoise_latents',
id: INPAINT,
is_intermediate: true,
steps: steps,
cfg_scale: cfg_scale,
scheduler: scheduler,
denoising_start: 1 - strength,
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
},
[infillNode.id]: infillNode,
[MASK_FROM_ALPHA]: {
type: 'tomask',
id: MASK_FROM_ALPHA,
is_intermediate: true,
image: canvasInitImage,
},
[MASK_COMBINE]: {
type: 'mask_combine',
id: MASK_COMBINE,
is_intermediate: true,
mask2: canvasMaskImage,
},
[MASK_BLUR]: {
type: 'img_blur',
id: MASK_BLUR,
is_intermediate: true,
radius: maskBlur,
blur_type: maskBlurMethod,
},
[INPAINT_IMAGE]: {
type: 'i2l',
id: INPAINT_IMAGE,
is_intermediate: true,
fp32: vaePrecision === 'fp32' ? true : false,
},
[NOISE]: {
type: 'noise',
id: NOISE,
width,
height,
use_cpu,
is_intermediate: true,
},
[POSITIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
style: shouldConcatSDXLStylePrompt
? `${positivePrompt} ${positiveStylePrompt}`
: positiveStylePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
style: shouldConcatSDXLStylePrompt
? `${negativePrompt} ${negativeStylePrompt}`
: negativeStylePrompt,
},
[SDXL_MODEL_LOADER]: {
type: 'sdxl_model_loader',
id: SDXL_MODEL_LOADER,
model,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
is_intermediate: true,
fp32: vaePrecision === 'fp32' ? true : false,
},
[COLOR_CORRECT]: {
type: 'color_correct',
id: COLOR_CORRECT,
is_intermediate: true,
},
[INPAINT_FINAL_IMAGE]: {
type: 'img_paste',
id: INPAINT_FINAL_IMAGE,
is_intermediate: true,
},
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
is_intermediate: true,
// seed - must be connected manually
// start: 0,
size: iterations,
step: 1,
},
[ITERATE]: {
type: 'iterate',
id: ITERATE,
is_intermediate: true,
},
},
edges: [
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: INPAINT,
field: 'unet',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: INPAINT,
field: 'negative_conditioning',
},
},
{
source: {
node_id: POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: INPAINT,
field: 'positive_conditioning',
},
},
{
source: {
node_id: NOISE,
field: 'noise',
},
destination: {
node_id: INPAINT,
field: 'noise',
},
},
{
source: {
node_id: INPAINT_INFILL,
field: 'image',
},
destination: {
node_id: INPAINT_IMAGE,
field: 'image',
},
},
{
source: {
node_id: INPAINT_IMAGE,
field: 'latents',
},
destination: {
node_id: INPAINT,
field: 'latents',
},
},
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'mask',
},
destination: {
node_id: MASK_COMBINE,
field: 'mask1',
},
},
{
source: {
node_id: MASK_COMBINE,
field: 'image',
},
destination: {
node_id: MASK_BLUR,
field: 'image',
},
},
{
source: {
node_id: MASK_BLUR,
field: 'image',
},
destination: {
node_id: INPAINT,
field: 'mask',
},
},
{
source: {
node_id: RANGE_OF_SIZE,
field: 'collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
},
{
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
},
{
source: {
node_id: INPAINT,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
},
},
{
source: {
node_id: INPAINT_INFILL,
field: 'image',
},
destination: {
node_id: COLOR_CORRECT,
field: 'reference',
},
},
{
source: {
node_id: MASK_BLUR,
field: 'image',
},
destination: {
node_id: COLOR_CORRECT,
field: 'mask',
},
},
{
source: {
node_id: LATENTS_TO_IMAGE,
field: 'image',
},
destination: {
node_id: COLOR_CORRECT,
field: 'image',
},
},
{
source: {
node_id: INPAINT_INFILL,
field: 'image',
},
destination: {
node_id: INPAINT_FINAL_IMAGE,
field: 'base_image',
},
},
{
source: {
node_id: MASK_BLUR,
field: 'image',
},
destination: {
node_id: INPAINT_FINAL_IMAGE,
field: 'mask',
},
},
{
source: {
node_id: COLOR_CORRECT,
field: 'image',
},
destination: {
node_id: INPAINT_FINAL_IMAGE,
field: 'image',
},
},
],
};
// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, INPAINT);
}
// Add VAE
addVAEToGraph(state, graph, SDXL_MODEL_LOADER);
// handle seed
if (shouldRandomizeSeed) {
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode;
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
}
// add LoRA support
addSDXLLoRAsToGraph(state, graph, INPAINT, SDXL_MODEL_LOADER);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, INPAINT);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph, INPAINT);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph, INPAINT);
}
return graph;
};

View File

@ -0,0 +1,304 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import {
DenoiseLatentsInvocation,
ONNXTextToLatentsInvocation,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
DENOISE_LATENTS,
LATENTS_TO_IMAGE,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
NOISE,
ONNX_MODEL_LOADER,
POSITIVE_CONDITIONING,
SDXL_MODEL_LOADER,
TEXT_TO_IMAGE_GRAPH,
} from './constants';
/**
* Builds the Canvas tab's Text to Image graph.
*/
export const buildCanvasSDXLTextToImageGraph = (
state: RootState
): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
negativePrompt,
model,
cfgScale: cfg_scale,
scheduler,
steps,
clipSkip,
shouldUseCpuNoise,
shouldUseNoiseSettings,
} = state.generation;
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
const { shouldAutoSave } = state.canvas;
const {
positiveStylePrompt,
negativeStylePrompt,
shouldConcatSDXLStylePrompt,
shouldUseSDXLRefiner,
refinerStart,
} = state.sdxl;
if (!model) {
log.error('No model found in state');
throw new Error('No model found in state');
}
const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise;
const isUsingOnnxModel = model.model_type === 'onnx';
const modelLoaderNodeId = isUsingOnnxModel
? ONNX_MODEL_LOADER
: SDXL_MODEL_LOADER;
const modelLoaderNodeType = isUsingOnnxModel
? 'onnx_model_loader'
: 'sdxl_model_loader';
const t2lNode: DenoiseLatentsInvocation | ONNXTextToLatentsInvocation =
isUsingOnnxModel
? {
type: 't2l_onnx',
id: DENOISE_LATENTS,
is_intermediate: true,
cfg_scale,
scheduler,
steps,
}
: {
type: 'denoise_latents',
id: DENOISE_LATENTS,
is_intermediate: true,
cfg_scale,
scheduler,
steps,
denoising_start: 0,
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
};
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
* ids.
*
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
* the `fit` param. These are added to the graph at the end.
*/
// copy-pasted graph from node editor, filled in with state values & friendly node ids
// TODO: Actually create the graph correctly for ONNX
const graph: NonNullableGraph = {
id: TEXT_TO_IMAGE_GRAPH,
nodes: {
[POSITIVE_CONDITIONING]: {
type: isUsingOnnxModel ? 'prompt_onnx' : 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING,
is_intermediate: true,
prompt: positivePrompt,
style: shouldConcatSDXLStylePrompt
? `${positivePrompt} ${positiveStylePrompt}`
: positiveStylePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: isUsingOnnxModel ? 'prompt_onnx' : 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING,
is_intermediate: true,
prompt: negativePrompt,
style: shouldConcatSDXLStylePrompt
? `${negativePrompt} ${negativeStylePrompt}`
: negativeStylePrompt,
},
[NOISE]: {
type: 'noise',
id: NOISE,
is_intermediate: true,
width,
height,
use_cpu,
},
[t2lNode.id]: t2lNode,
[modelLoaderNodeId]: {
type: modelLoaderNodeType,
id: modelLoaderNodeId,
is_intermediate: true,
model,
},
[LATENTS_TO_IMAGE]: {
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
id: LATENTS_TO_IMAGE,
is_intermediate: !shouldAutoSave,
},
},
edges: [
{
source: {
node_id: modelLoaderNodeId,
field: 'unet',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: modelLoaderNodeId,
field: 'clip',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: modelLoaderNodeId,
field: 'clip2',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: modelLoaderNodeId,
field: 'clip',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: modelLoaderNodeId,
field: 'clip2',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'negative_conditioning',
},
},
{
source: {
node_id: POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
},
},
{
source: {
node_id: NOISE,
field: 'noise',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'noise',
},
},
],
};
// add metadata accumulator, which is only mostly populated - some fields are added later
graph.nodes[METADATA_ACCUMULATOR] = {
id: METADATA_ACCUMULATOR,
type: 'metadata_accumulator',
generation_mode: 'txt2img',
cfg_scale,
height,
width,
positive_prompt: '', // set in addDynamicPromptsToGraph
negative_prompt: negativePrompt,
model,
seed: 0, // set in addDynamicPromptsToGraph
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
scheduler,
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
clip_skip: clipSkip,
};
graph.edges.push({
source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'metadata',
},
});
// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, DENOISE_LATENTS);
}
// add LoRA support
addSDXLLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
// optionally add custom VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph);
}
return graph;
};

View File

@ -15,11 +15,11 @@ import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainM
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton'; import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { import {
useGetMainModelsQuery, useGetMainModelsQuery,
useGetOnnxModelsQuery, useGetOnnxModelsQuery,
} from 'services/api/endpoints/models'; } from 'services/api/endpoints/models';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus'; import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
const selector = createSelector( const selector = createSelector(
@ -52,10 +52,7 @@ const ParamMainModelSelect = () => {
const data: SelectItem[] = []; const data: SelectItem[] = [];
forEach(mainModels.entities, (model, id) => { forEach(mainModels.entities, (model, id) => {
if ( if (!model) {
!model ||
(activeTabName === 'unifiedCanvas' && model.base_model === 'sdxl')
) {
return; return;
} }

View File

@ -0,0 +1,29 @@
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
import ParamAdvancedCollapse from 'features/parameters/components/Parameters/Advanced/ParamAdvancedCollapse';
import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse';
import ParamMaskAdjustmentCollapse from 'features/parameters/components/Parameters/Canvas/MaskAdjustment/ParamMaskAdjustmentCollapse';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import UnifiedCanvasCoreParameters from 'features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters';
import ParamSDXLPromptArea from './ParamSDXLPromptArea';
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
export default function SDXLUnifiedCanvasTabParameters() {
return (
<>
<ParamSDXLPromptArea />
<ProcessButtons />
<UnifiedCanvasCoreParameters />
<ParamSDXLRefinerCollapse />
<ParamControlNetCollapse />
<ParamLoraCollapse />
<ParamDynamicPromptsCollapse />
<ParamNoiseCollapse />
<ParamMaskAdjustmentCollapse />
<ParamInfillAndScalingCollapse />
<ParamAdvancedCollapse />
</>
);
}

View File

@ -1,14 +1,22 @@
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import SDXLUnifiedCanvasTabParameters from 'features/sdxl/components/SDXLUnifiedCanvasTabParameters';
import { memo } from 'react'; import { memo } from 'react';
import ParametersPinnedWrapper from '../../ParametersPinnedWrapper'; import ParametersPinnedWrapper from '../../ParametersPinnedWrapper';
import UnifiedCanvasContent from './UnifiedCanvasContent'; import UnifiedCanvasContent from './UnifiedCanvasContent';
import UnifiedCanvasParameters from './UnifiedCanvasParameters'; import UnifiedCanvasParameters from './UnifiedCanvasParameters';
const UnifiedCanvasTab = () => { const UnifiedCanvasTab = () => {
const model = useAppSelector((state: RootState) => state.generation.model);
return ( return (
<Flex sx={{ gap: 4, w: 'full', h: 'full' }}> <Flex sx={{ gap: 4, w: 'full', h: 'full' }}>
<ParametersPinnedWrapper> <ParametersPinnedWrapper>
{model && model.base_model === 'sdxl' ? (
<SDXLUnifiedCanvasTabParameters />
) : (
<UnifiedCanvasParameters /> <UnifiedCanvasParameters />
)}
</ParametersPinnedWrapper> </ParametersPinnedWrapper>
<UnifiedCanvasContent /> <UnifiedCanvasContent />
</Flex> </Flex>