diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/tabChanged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/tabChanged.ts
index 6d3e599ae2..6791324fdd 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/tabChanged.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/tabChanged.ts
@@ -12,7 +12,10 @@ export const addTabChangedListener = () => {
if (activeTabName === 'unifiedCanvas') {
const currentBaseModel = getState().generation.model?.base_model;
- if (currentBaseModel && ['sd-1', 'sd-2'].includes(currentBaseModel)) {
+ if (
+ currentBaseModel &&
+ ['sd-1', 'sd-2', 'sdxl'].includes(currentBaseModel)
+ ) {
// if we're already on a valid model, no change needed
return;
}
@@ -36,7 +39,9 @@ export const addTabChangedListener = () => {
const validCanvasModels = mainModelsAdapter
.getSelectors()
.selectAll(models)
- .filter((model) => ['sd-1', 'sd-2'].includes(model.base_model));
+ .filter((model) =>
+        ['sd-1', 'sd-2', 'sdxl'].includes(model.base_model)
+ );
const firstValidCanvasModel = validCanvasModels[0];
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts
index 8a7716071f..dd0a5e6619 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts
@@ -3,6 +3,9 @@ import { NonNullableGraph } from 'features/nodes/types/types';
import { ImageDTO } from 'services/api/types';
import { buildCanvasImageToImageGraph } from './buildCanvasImageToImageGraph';
import { buildCanvasInpaintGraph } from './buildCanvasInpaintGraph';
+import { buildCanvasSDXLImageToImageGraph } from './buildCanvasSDXLImageToImageGraph';
+import { buildCanvasSDXLInpaintGraph } from './buildCanvasSDXLInpaintGraph';
+import { buildCanvasSDXLTextToImageGraph } from './buildCanvasSDXLTextToImageGraph';
import { buildCanvasTextToImageGraph } from './buildCanvasTextToImageGraph';
export const buildCanvasGraph = (
@@ -14,17 +17,43 @@ export const buildCanvasGraph = (
let graph: NonNullableGraph;
if (generationMode === 'txt2img') {
- graph = buildCanvasTextToImageGraph(state);
+ if (
+ state.generation.model &&
+ state.generation.model.base_model === 'sdxl'
+ ) {
+ graph = buildCanvasSDXLTextToImageGraph(state);
+ } else {
+ graph = buildCanvasTextToImageGraph(state);
+ }
} else if (generationMode === 'img2img') {
if (!canvasInitImage) {
throw new Error('Missing canvas init image');
}
- graph = buildCanvasImageToImageGraph(state, canvasInitImage);
+ if (
+ state.generation.model &&
+ state.generation.model.base_model === 'sdxl'
+ ) {
+ graph = buildCanvasSDXLImageToImageGraph(state, canvasInitImage);
+ } else {
+ graph = buildCanvasImageToImageGraph(state, canvasInitImage);
+ }
} else {
if (!canvasInitImage || !canvasMaskImage) {
throw new Error('Missing canvas init and mask images');
}
- graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
+
+ if (
+ state.generation.model &&
+ state.generation.model.base_model === 'sdxl'
+ ) {
+ graph = buildCanvasSDXLInpaintGraph(
+ state,
+ canvasInitImage,
+ canvasMaskImage
+ );
+ } else {
+ graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
+ }
}
return graph;
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLImageToImageGraph.ts
new file mode 100644
index 0000000000..b8322fd612
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLImageToImageGraph.ts
@@ -0,0 +1,373 @@
+import { logger } from 'app/logging/logger';
+import { RootState } from 'app/store/store';
+import { NonNullableGraph } from 'features/nodes/types/types';
+import { initialGenerationState } from 'features/parameters/store/generationSlice';
+import {
+ ImageDTO,
+ ImageResizeInvocation,
+ ImageToLatentsInvocation,
+} from 'services/api/types';
+import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
+import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
+import { addLoRAsToGraph } from './addLoRAsToGraph';
+import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
+import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
+import { addVAEToGraph } from './addVAEToGraph';
+import { addWatermarkerToGraph } from './addWatermarkerToGraph';
+import {
+ DENOISE_LATENTS,
+ IMAGE_TO_IMAGE_GRAPH,
+ IMAGE_TO_LATENTS,
+ LATENTS_TO_IMAGE,
+ METADATA_ACCUMULATOR,
+ NEGATIVE_CONDITIONING,
+ NOISE,
+ POSITIVE_CONDITIONING,
+ RESIZE,
+ SDXL_MODEL_LOADER,
+} from './constants';
+
+/**
+ * Builds the Canvas tab's Image to Image graph.
+ */
+export const buildCanvasSDXLImageToImageGraph = (
+ state: RootState,
+ initialImage: ImageDTO
+): NonNullableGraph => {
+ const log = logger('nodes');
+ const {
+ positivePrompt,
+ negativePrompt,
+ model,
+ cfgScale: cfg_scale,
+ scheduler,
+ steps,
+ clipSkip,
+ shouldUseCpuNoise,
+ shouldUseNoiseSettings,
+ } = state.generation;
+
+ const {
+ positiveStylePrompt,
+ negativeStylePrompt,
+ shouldConcatSDXLStylePrompt,
+ shouldUseSDXLRefiner,
+ refinerStart,
+ sdxlImg2ImgDenoisingStrength: strength,
+ } = state.sdxl;
+
+ // The bounding box determines width and height, not the width and height params
+ const { width, height } = state.canvas.boundingBoxDimensions;
+
+ const { shouldAutoSave } = state.canvas;
+
+ if (!model) {
+ log.error('No model found in state');
+ throw new Error('No model found in state');
+ }
+
+ const use_cpu = shouldUseNoiseSettings
+ ? shouldUseCpuNoise
+ : initialGenerationState.shouldUseCpuNoise;
+
+ /**
+ * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
+ * full graph here as a template. Then use the parameters from app state and set friendlier node
+ * ids.
+ *
+ * The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
+ * the `fit` param. These are added to the graph at the end.
+ */
+
+ // copy-pasted graph from node editor, filled in with state values & friendly node ids
+ const graph: NonNullableGraph = {
+ id: IMAGE_TO_IMAGE_GRAPH,
+ nodes: {
+ [SDXL_MODEL_LOADER]: {
+ type: 'sdxl_model_loader',
+ id: SDXL_MODEL_LOADER,
+ model,
+ },
+ [POSITIVE_CONDITIONING]: {
+ type: 'sdxl_compel_prompt',
+ id: POSITIVE_CONDITIONING,
+ prompt: positivePrompt,
+ style: shouldConcatSDXLStylePrompt
+ ? `${positivePrompt} ${positiveStylePrompt}`
+ : positiveStylePrompt,
+ },
+ [NEGATIVE_CONDITIONING]: {
+ type: 'sdxl_compel_prompt',
+ id: NEGATIVE_CONDITIONING,
+ prompt: negativePrompt,
+ style: shouldConcatSDXLStylePrompt
+ ? `${negativePrompt} ${negativeStylePrompt}`
+ : negativeStylePrompt,
+ },
+ [NOISE]: {
+ type: 'noise',
+ id: NOISE,
+ is_intermediate: true,
+ use_cpu,
+ },
+ [DENOISE_LATENTS]: {
+ type: 'denoise_latents',
+ id: DENOISE_LATENTS,
+ is_intermediate: true,
+ cfg_scale,
+ scheduler,
+ steps,
+ denoising_start: shouldUseSDXLRefiner
+ ? Math.min(refinerStart, 1 - strength)
+ : 1 - strength,
+ denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
+ },
+ [IMAGE_TO_LATENTS]: {
+ type: 'i2l',
+ id: IMAGE_TO_LATENTS,
+ is_intermediate: true,
+ // must be set manually later, bc `fit` parameter may require a resize node inserted
+ // image: {
+ // image_name: initialImage.image_name,
+ // },
+ },
+ [LATENTS_TO_IMAGE]: {
+ type: 'l2i',
+ id: LATENTS_TO_IMAGE,
+ is_intermediate: !shouldAutoSave,
+ },
+ },
+ edges: [
+ {
+ source: {
+ node_id: DENOISE_LATENTS,
+ field: 'latents',
+ },
+ destination: {
+ node_id: LATENTS_TO_IMAGE,
+ field: 'latents',
+ },
+ },
+ {
+ source: {
+ node_id: IMAGE_TO_LATENTS,
+ field: 'latents',
+ },
+ destination: {
+ node_id: DENOISE_LATENTS,
+ field: 'latents',
+ },
+ },
+ {
+ source: {
+ node_id: NOISE,
+ field: 'noise',
+ },
+ destination: {
+ node_id: DENOISE_LATENTS,
+ field: 'noise',
+ },
+ },
+ {
+ source: {
+ node_id: SDXL_MODEL_LOADER,
+ field: 'unet',
+ },
+ destination: {
+ node_id: DENOISE_LATENTS,
+ field: 'unet',
+ },
+ },
+ {
+ source: {
+ node_id: SDXL_MODEL_LOADER,
+ field: 'clip',
+ },
+ destination: {
+ node_id: POSITIVE_CONDITIONING,
+ field: 'clip',
+ },
+ },
+ {
+ source: {
+ node_id: SDXL_MODEL_LOADER,
+ field: 'clip2',
+ },
+ destination: {
+ node_id: POSITIVE_CONDITIONING,
+ field: 'clip2',
+ },
+ },
+ {
+ source: {
+ node_id: SDXL_MODEL_LOADER,
+ field: 'clip',
+ },
+ destination: {
+ node_id: NEGATIVE_CONDITIONING,
+ field: 'clip',
+ },
+ },
+ {
+ source: {
+ node_id: SDXL_MODEL_LOADER,
+ field: 'clip2',
+ },
+ destination: {
+ node_id: NEGATIVE_CONDITIONING,
+ field: 'clip2',
+ },
+ },
+ {
+ source: {
+ node_id: NEGATIVE_CONDITIONING,
+ field: 'conditioning',
+ },
+ destination: {
+ node_id: DENOISE_LATENTS,
+ field: 'negative_conditioning',
+ },
+ },
+ {
+ source: {
+ node_id: POSITIVE_CONDITIONING,
+ field: 'conditioning',
+ },
+ destination: {
+ node_id: DENOISE_LATENTS,
+ field: 'positive_conditioning',
+ },
+ },
+ ],
+ };
+
+ // handle `fit`
+ if (initialImage.width !== width || initialImage.height !== height) {
+ // The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
+
+ // Create a resize node, explicitly setting its image
+ const resizeNode: ImageResizeInvocation = {
+ id: RESIZE,
+ type: 'img_resize',
+ image: {
+ image_name: initialImage.image_name,
+ },
+ is_intermediate: true,
+ width,
+ height,
+ };
+
+ graph.nodes[RESIZE] = resizeNode;
+
+ // The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
+ graph.edges.push({
+ source: { node_id: RESIZE, field: 'image' },
+ destination: {
+ node_id: IMAGE_TO_LATENTS,
+ field: 'image',
+ },
+ });
+
+ // The `RESIZE` node also passes its width and height to `NOISE`
+ graph.edges.push({
+ source: { node_id: RESIZE, field: 'width' },
+ destination: {
+ node_id: NOISE,
+ field: 'width',
+ },
+ });
+
+ graph.edges.push({
+ source: { node_id: RESIZE, field: 'height' },
+ destination: {
+ node_id: NOISE,
+ field: 'height',
+ },
+ });
+ } else {
+ // We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
+ (graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = {
+ image_name: initialImage.image_name,
+ };
+
+ // Pass the image's dimensions to the `NOISE` node
+ graph.edges.push({
+ source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
+ destination: {
+ node_id: NOISE,
+ field: 'width',
+ },
+ });
+ graph.edges.push({
+ source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
+ destination: {
+ node_id: NOISE,
+ field: 'height',
+ },
+ });
+ }
+
+ // add metadata accumulator, which is only mostly populated - some fields are added later
+ graph.nodes[METADATA_ACCUMULATOR] = {
+ id: METADATA_ACCUMULATOR,
+ type: 'metadata_accumulator',
+ generation_mode: 'img2img',
+ cfg_scale,
+ height,
+ width,
+ positive_prompt: '', // set in addDynamicPromptsToGraph
+ negative_prompt: negativePrompt,
+ model,
+ seed: 0, // set in addDynamicPromptsToGraph
+ steps,
+ rand_device: use_cpu ? 'cpu' : 'cuda',
+ scheduler,
+ vae: undefined, // option; set in addVAEToGraph
+ controlnets: [], // populated in addControlNetToLinearGraph
+ loras: [], // populated in addLoRAsToGraph
+ clip_skip: clipSkip,
+ strength,
+ init_image: initialImage.image_name,
+ };
+
+ graph.edges.push({
+ source: {
+ node_id: METADATA_ACCUMULATOR,
+ field: 'metadata',
+ },
+ destination: {
+ node_id: LATENTS_TO_IMAGE,
+ field: 'metadata',
+ },
+ });
+
+ // add LoRA support
+  addLoRAsToGraph(state, graph, DENOISE_LATENTS); // FIXME(review): sibling SDXL builders use addSDXLLoRAsToGraph (also wires clip2) — confirm plain addLoRAsToGraph is intended here
+
+ // Add Refiner if enabled
+ if (shouldUseSDXLRefiner) {
+ addSDXLRefinerToGraph(state, graph, DENOISE_LATENTS);
+ }
+
+ // optionally add custom VAE
+ addVAEToGraph(state, graph, SDXL_MODEL_LOADER);
+
+ // add dynamic prompts - also sets up core iteration and seed
+ addDynamicPromptsToGraph(state, graph);
+
+ // add controlnet, mutating `graph`
+ addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
+
+ // NSFW & watermark - must be last thing added to graph
+ if (state.system.shouldUseNSFWChecker) {
+ // must add before watermarker!
+ addNSFWCheckerToGraph(state, graph);
+ }
+
+ if (state.system.shouldUseWatermarker) {
+ // must add after nsfw checker!
+ addWatermarkerToGraph(state, graph);
+ }
+
+ return graph;
+};
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts
new file mode 100644
index 0000000000..04cc120cbe
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts
@@ -0,0 +1,480 @@
+import { logger } from 'app/logging/logger';
+import { RootState } from 'app/store/store';
+import { NonNullableGraph } from 'features/nodes/types/types';
+import {
+ ImageDTO,
+ InfillPatchmatchInvocation,
+ InfillTileInvocation,
+ RandomIntInvocation,
+ RangeOfSizeInvocation,
+} from 'services/api/types';
+import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
+import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
+import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
+import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
+import { addVAEToGraph } from './addVAEToGraph';
+import { addWatermarkerToGraph } from './addWatermarkerToGraph';
+import {
+ COLOR_CORRECT,
+ INPAINT,
+ INPAINT_FINAL_IMAGE,
+ INPAINT_GRAPH,
+ INPAINT_IMAGE,
+ INPAINT_INFILL,
+ ITERATE,
+ LATENTS_TO_IMAGE,
+ MASK_BLUR,
+ MASK_COMBINE,
+ MASK_FROM_ALPHA,
+ NEGATIVE_CONDITIONING,
+ NOISE,
+ POSITIVE_CONDITIONING,
+ RANDOM_INT,
+ RANGE_OF_SIZE,
+ SDXL_MODEL_LOADER,
+} from './constants';
+
+/**
+ * Builds the Canvas tab's Inpaint graph.
+ */
+export const buildCanvasSDXLInpaintGraph = (
+ state: RootState,
+ canvasInitImage: ImageDTO,
+ canvasMaskImage: ImageDTO
+): NonNullableGraph => {
+ const log = logger('nodes');
+ const {
+ positivePrompt,
+ negativePrompt,
+ model,
+ cfgScale: cfg_scale,
+ scheduler,
+ steps,
+ img2imgStrength: strength,
+ shouldFitToWidthHeight,
+ iterations,
+ seed,
+ shouldRandomizeSeed,
+ vaePrecision,
+ shouldUseNoiseSettings,
+ shouldUseCpuNoise,
+ maskBlur,
+ maskBlurMethod,
+ tileSize,
+ infillMethod,
+ } = state.generation;
+
+ const {
+ positiveStylePrompt,
+ negativeStylePrompt,
+ shouldConcatSDXLStylePrompt,
+ shouldUseSDXLRefiner,
+ refinerStart,
+ } = state.sdxl;
+
+ if (!model) {
+ log.error('No model found in state');
+ throw new Error('No model found in state');
+ }
+
+ // The bounding box determines width and height, not the width and height params
+ const { width, height } = state.canvas.boundingBoxDimensions;
+
+ // We may need to set the inpaint width and height to scale the image
+ const {
+ scaledBoundingBoxDimensions,
+ boundingBoxScaleMethod,
+ shouldAutoSave,
+ } = state.canvas;
+
+  const use_cpu = shouldUseNoiseSettings
+    ? shouldUseCpuNoise
+    : shouldUseCpuNoise; // FIXME(review): both branches identical — fallback should likely be initialGenerationState.shouldUseCpuNoise, as in the other graph builders
+
+ let infillNode: InfillTileInvocation | InfillPatchmatchInvocation = {
+ type: 'infill_tile',
+ id: INPAINT_INFILL,
+ is_intermediate: true,
+ image: canvasInitImage,
+ tile_size: tileSize,
+ };
+
+ if (infillMethod === 'patchmatch') {
+ infillNode = {
+ type: 'infill_patchmatch',
+ id: INPAINT_INFILL,
+ is_intermediate: true,
+ image: canvasInitImage,
+ };
+ }
+
+ const graph: NonNullableGraph = {
+ id: INPAINT_GRAPH,
+ nodes: {
+ [INPAINT]: {
+ type: 'denoise_latents',
+ id: INPAINT,
+ is_intermediate: true,
+ steps: steps,
+ cfg_scale: cfg_scale,
+ scheduler: scheduler,
+ denoising_start: 1 - strength,
+ denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
+ },
+ [infillNode.id]: infillNode,
+ [MASK_FROM_ALPHA]: {
+ type: 'tomask',
+ id: MASK_FROM_ALPHA,
+ is_intermediate: true,
+ image: canvasInitImage,
+ },
+ [MASK_COMBINE]: {
+ type: 'mask_combine',
+ id: MASK_COMBINE,
+ is_intermediate: true,
+ mask2: canvasMaskImage,
+ },
+ [MASK_BLUR]: {
+ type: 'img_blur',
+ id: MASK_BLUR,
+ is_intermediate: true,
+ radius: maskBlur,
+ blur_type: maskBlurMethod,
+ },
+ [INPAINT_IMAGE]: {
+ type: 'i2l',
+ id: INPAINT_IMAGE,
+ is_intermediate: true,
+ fp32: vaePrecision === 'fp32' ? true : false,
+ },
+ [NOISE]: {
+ type: 'noise',
+ id: NOISE,
+ width,
+ height,
+ use_cpu,
+ is_intermediate: true,
+ },
+ [POSITIVE_CONDITIONING]: {
+ type: 'sdxl_compel_prompt',
+ id: POSITIVE_CONDITIONING,
+ prompt: positivePrompt,
+ style: shouldConcatSDXLStylePrompt
+ ? `${positivePrompt} ${positiveStylePrompt}`
+ : positiveStylePrompt,
+ },
+ [NEGATIVE_CONDITIONING]: {
+ type: 'sdxl_compel_prompt',
+ id: NEGATIVE_CONDITIONING,
+ prompt: negativePrompt,
+ style: shouldConcatSDXLStylePrompt
+ ? `${negativePrompt} ${negativeStylePrompt}`
+ : negativeStylePrompt,
+ },
+ [SDXL_MODEL_LOADER]: {
+ type: 'sdxl_model_loader',
+ id: SDXL_MODEL_LOADER,
+ model,
+ },
+ [LATENTS_TO_IMAGE]: {
+ type: 'l2i',
+ id: LATENTS_TO_IMAGE,
+ is_intermediate: true,
+ fp32: vaePrecision === 'fp32' ? true : false,
+ },
+ [COLOR_CORRECT]: {
+ type: 'color_correct',
+ id: COLOR_CORRECT,
+ is_intermediate: true,
+ },
+ [INPAINT_FINAL_IMAGE]: {
+ type: 'img_paste',
+ id: INPAINT_FINAL_IMAGE,
+ is_intermediate: true,
+ },
+ [RANGE_OF_SIZE]: {
+ type: 'range_of_size',
+ id: RANGE_OF_SIZE,
+ is_intermediate: true,
+ // seed - must be connected manually
+ // start: 0,
+ size: iterations,
+ step: 1,
+ },
+ [ITERATE]: {
+ type: 'iterate',
+ id: ITERATE,
+ is_intermediate: true,
+ },
+ },
+ edges: [
+ {
+ source: {
+ node_id: SDXL_MODEL_LOADER,
+ field: 'unet',
+ },
+ destination: {
+ node_id: INPAINT,
+ field: 'unet',
+ },
+ },
+ {
+ source: {
+ node_id: SDXL_MODEL_LOADER,
+ field: 'clip',
+ },
+ destination: {
+ node_id: POSITIVE_CONDITIONING,
+ field: 'clip',
+ },
+ },
+ {
+ source: {
+ node_id: SDXL_MODEL_LOADER,
+ field: 'clip2',
+ },
+ destination: {
+ node_id: POSITIVE_CONDITIONING,
+ field: 'clip2',
+ },
+ },
+ {
+ source: {
+ node_id: SDXL_MODEL_LOADER,
+ field: 'clip',
+ },
+ destination: {
+ node_id: NEGATIVE_CONDITIONING,
+ field: 'clip',
+ },
+ },
+ {
+ source: {
+ node_id: SDXL_MODEL_LOADER,
+ field: 'clip2',
+ },
+ destination: {
+ node_id: NEGATIVE_CONDITIONING,
+ field: 'clip2',
+ },
+ },
+ {
+ source: {
+ node_id: NEGATIVE_CONDITIONING,
+ field: 'conditioning',
+ },
+ destination: {
+ node_id: INPAINT,
+ field: 'negative_conditioning',
+ },
+ },
+ {
+ source: {
+ node_id: POSITIVE_CONDITIONING,
+ field: 'conditioning',
+ },
+ destination: {
+ node_id: INPAINT,
+ field: 'positive_conditioning',
+ },
+ },
+ {
+ source: {
+ node_id: NOISE,
+ field: 'noise',
+ },
+ destination: {
+ node_id: INPAINT,
+ field: 'noise',
+ },
+ },
+ {
+ source: {
+ node_id: INPAINT_INFILL,
+ field: 'image',
+ },
+ destination: {
+ node_id: INPAINT_IMAGE,
+ field: 'image',
+ },
+ },
+ {
+ source: {
+ node_id: INPAINT_IMAGE,
+ field: 'latents',
+ },
+ destination: {
+ node_id: INPAINT,
+ field: 'latents',
+ },
+ },
+ {
+ source: {
+ node_id: MASK_FROM_ALPHA,
+ field: 'mask',
+ },
+ destination: {
+ node_id: MASK_COMBINE,
+ field: 'mask1',
+ },
+ },
+ {
+ source: {
+ node_id: MASK_COMBINE,
+ field: 'image',
+ },
+ destination: {
+ node_id: MASK_BLUR,
+ field: 'image',
+ },
+ },
+ {
+ source: {
+ node_id: MASK_BLUR,
+ field: 'image',
+ },
+ destination: {
+ node_id: INPAINT,
+ field: 'mask',
+ },
+ },
+ {
+ source: {
+ node_id: RANGE_OF_SIZE,
+ field: 'collection',
+ },
+ destination: {
+ node_id: ITERATE,
+ field: 'collection',
+ },
+ },
+ {
+ source: {
+ node_id: ITERATE,
+ field: 'item',
+ },
+ destination: {
+ node_id: NOISE,
+ field: 'seed',
+ },
+ },
+ {
+ source: {
+ node_id: INPAINT,
+ field: 'latents',
+ },
+ destination: {
+ node_id: LATENTS_TO_IMAGE,
+ field: 'latents',
+ },
+ },
+ {
+ source: {
+ node_id: INPAINT_INFILL,
+ field: 'image',
+ },
+ destination: {
+ node_id: COLOR_CORRECT,
+ field: 'reference',
+ },
+ },
+ {
+ source: {
+ node_id: MASK_BLUR,
+ field: 'image',
+ },
+ destination: {
+ node_id: COLOR_CORRECT,
+ field: 'mask',
+ },
+ },
+ {
+ source: {
+ node_id: LATENTS_TO_IMAGE,
+ field: 'image',
+ },
+ destination: {
+ node_id: COLOR_CORRECT,
+ field: 'image',
+ },
+ },
+ {
+ source: {
+ node_id: INPAINT_INFILL,
+ field: 'image',
+ },
+ destination: {
+ node_id: INPAINT_FINAL_IMAGE,
+ field: 'base_image',
+ },
+ },
+ {
+ source: {
+ node_id: MASK_BLUR,
+ field: 'image',
+ },
+ destination: {
+ node_id: INPAINT_FINAL_IMAGE,
+ field: 'mask',
+ },
+ },
+ {
+ source: {
+ node_id: COLOR_CORRECT,
+ field: 'image',
+ },
+ destination: {
+ node_id: INPAINT_FINAL_IMAGE,
+ field: 'image',
+ },
+ },
+ ],
+ };
+
+ // Add Refiner if enabled
+ if (shouldUseSDXLRefiner) {
+ addSDXLRefinerToGraph(state, graph, INPAINT);
+ }
+
+ // Add VAE
+ addVAEToGraph(state, graph, SDXL_MODEL_LOADER);
+
+ // handle seed
+ if (shouldRandomizeSeed) {
+ // Random int node to generate the starting seed
+ const randomIntNode: RandomIntInvocation = {
+ id: RANDOM_INT,
+ type: 'rand_int',
+ };
+
+ graph.nodes[RANDOM_INT] = randomIntNode;
+
+ // Connect random int to the start of the range of size so the range starts on the random first seed
+ graph.edges.push({
+ source: { node_id: RANDOM_INT, field: 'a' },
+ destination: { node_id: RANGE_OF_SIZE, field: 'start' },
+ });
+ } else {
+ // User specified seed, so set the start of the range of size to the seed
+ (graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
+ }
+
+ // add LoRA support
+ addSDXLLoRAsToGraph(state, graph, INPAINT, SDXL_MODEL_LOADER);
+
+ // add controlnet, mutating `graph`
+ addControlNetToLinearGraph(state, graph, INPAINT);
+
+ // NSFW & watermark - must be last thing added to graph
+ if (state.system.shouldUseNSFWChecker) {
+ // must add before watermarker!
+ addNSFWCheckerToGraph(state, graph, INPAINT);
+ }
+
+ if (state.system.shouldUseWatermarker) {
+ // must add after nsfw checker!
+ addWatermarkerToGraph(state, graph, INPAINT);
+ }
+
+ return graph;
+};
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLTextToImageGraph.ts
new file mode 100644
index 0000000000..ed0fb74165
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLTextToImageGraph.ts
@@ -0,0 +1,304 @@
+import { logger } from 'app/logging/logger';
+import { RootState } from 'app/store/store';
+import { NonNullableGraph } from 'features/nodes/types/types';
+import { initialGenerationState } from 'features/parameters/store/generationSlice';
+import {
+ DenoiseLatentsInvocation,
+ ONNXTextToLatentsInvocation,
+} from 'services/api/types';
+import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
+import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
+import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
+import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
+import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
+import { addVAEToGraph } from './addVAEToGraph';
+import { addWatermarkerToGraph } from './addWatermarkerToGraph';
+import {
+ DENOISE_LATENTS,
+ LATENTS_TO_IMAGE,
+ METADATA_ACCUMULATOR,
+ NEGATIVE_CONDITIONING,
+ NOISE,
+ ONNX_MODEL_LOADER,
+ POSITIVE_CONDITIONING,
+ SDXL_MODEL_LOADER,
+ TEXT_TO_IMAGE_GRAPH,
+} from './constants';
+
+/**
+ * Builds the Canvas tab's Text to Image graph.
+ */
+export const buildCanvasSDXLTextToImageGraph = (
+ state: RootState
+): NonNullableGraph => {
+ const log = logger('nodes');
+ const {
+ positivePrompt,
+ negativePrompt,
+ model,
+ cfgScale: cfg_scale,
+ scheduler,
+ steps,
+ clipSkip,
+ shouldUseCpuNoise,
+ shouldUseNoiseSettings,
+ } = state.generation;
+
+ // The bounding box determines width and height, not the width and height params
+ const { width, height } = state.canvas.boundingBoxDimensions;
+
+ const { shouldAutoSave } = state.canvas;
+
+ const {
+ positiveStylePrompt,
+ negativeStylePrompt,
+ shouldConcatSDXLStylePrompt,
+ shouldUseSDXLRefiner,
+ refinerStart,
+ } = state.sdxl;
+
+ if (!model) {
+ log.error('No model found in state');
+ throw new Error('No model found in state');
+ }
+
+ const use_cpu = shouldUseNoiseSettings
+ ? shouldUseCpuNoise
+ : initialGenerationState.shouldUseCpuNoise;
+ const isUsingOnnxModel = model.model_type === 'onnx';
+ const modelLoaderNodeId = isUsingOnnxModel
+ ? ONNX_MODEL_LOADER
+ : SDXL_MODEL_LOADER;
+ const modelLoaderNodeType = isUsingOnnxModel
+ ? 'onnx_model_loader'
+ : 'sdxl_model_loader';
+ const t2lNode: DenoiseLatentsInvocation | ONNXTextToLatentsInvocation =
+ isUsingOnnxModel
+ ? {
+ type: 't2l_onnx',
+ id: DENOISE_LATENTS,
+ is_intermediate: true,
+ cfg_scale,
+ scheduler,
+ steps,
+ }
+ : {
+ type: 'denoise_latents',
+ id: DENOISE_LATENTS,
+ is_intermediate: true,
+ cfg_scale,
+ scheduler,
+ steps,
+ denoising_start: 0,
+ denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
+ };
+ /**
+ * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
+ * full graph here as a template. Then use the parameters from app state and set friendlier node
+ * ids.
+ *
+ * The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
+ * the `fit` param. These are added to the graph at the end.
+ */
+
+ // copy-pasted graph from node editor, filled in with state values & friendly node ids
+ // TODO: Actually create the graph correctly for ONNX
+ const graph: NonNullableGraph = {
+ id: TEXT_TO_IMAGE_GRAPH,
+ nodes: {
+ [POSITIVE_CONDITIONING]: {
+ type: isUsingOnnxModel ? 'prompt_onnx' : 'sdxl_compel_prompt',
+ id: POSITIVE_CONDITIONING,
+ is_intermediate: true,
+ prompt: positivePrompt,
+ style: shouldConcatSDXLStylePrompt
+ ? `${positivePrompt} ${positiveStylePrompt}`
+ : positiveStylePrompt,
+ },
+ [NEGATIVE_CONDITIONING]: {
+ type: isUsingOnnxModel ? 'prompt_onnx' : 'sdxl_compel_prompt',
+ id: NEGATIVE_CONDITIONING,
+ is_intermediate: true,
+ prompt: negativePrompt,
+ style: shouldConcatSDXLStylePrompt
+ ? `${negativePrompt} ${negativeStylePrompt}`
+ : negativeStylePrompt,
+ },
+ [NOISE]: {
+ type: 'noise',
+ id: NOISE,
+ is_intermediate: true,
+ width,
+ height,
+ use_cpu,
+ },
+ [t2lNode.id]: t2lNode,
+ [modelLoaderNodeId]: {
+ type: modelLoaderNodeType,
+ id: modelLoaderNodeId,
+ is_intermediate: true,
+ model,
+ },
+
+ [LATENTS_TO_IMAGE]: {
+ type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
+ id: LATENTS_TO_IMAGE,
+ is_intermediate: !shouldAutoSave,
+ },
+ },
+ edges: [
+ {
+ source: {
+ node_id: modelLoaderNodeId,
+ field: 'unet',
+ },
+ destination: {
+ node_id: DENOISE_LATENTS,
+ field: 'unet',
+ },
+ },
+ {
+ source: {
+ node_id: modelLoaderNodeId,
+ field: 'clip',
+ },
+ destination: {
+ node_id: POSITIVE_CONDITIONING,
+ field: 'clip',
+ },
+ },
+ {
+ source: {
+ node_id: modelLoaderNodeId,
+ field: 'clip2',
+ },
+ destination: {
+ node_id: POSITIVE_CONDITIONING,
+ field: 'clip2',
+ },
+ },
+ {
+ source: {
+ node_id: modelLoaderNodeId,
+ field: 'clip',
+ },
+ destination: {
+ node_id: NEGATIVE_CONDITIONING,
+ field: 'clip',
+ },
+ },
+ {
+ source: {
+ node_id: modelLoaderNodeId,
+ field: 'clip2',
+ },
+ destination: {
+ node_id: NEGATIVE_CONDITIONING,
+ field: 'clip2',
+ },
+ },
+ {
+ source: {
+ node_id: NEGATIVE_CONDITIONING,
+ field: 'conditioning',
+ },
+ destination: {
+ node_id: DENOISE_LATENTS,
+ field: 'negative_conditioning',
+ },
+ },
+ {
+ source: {
+ node_id: POSITIVE_CONDITIONING,
+ field: 'conditioning',
+ },
+ destination: {
+ node_id: DENOISE_LATENTS,
+ field: 'positive_conditioning',
+ },
+ },
+ {
+ source: {
+ node_id: DENOISE_LATENTS,
+ field: 'latents',
+ },
+ destination: {
+ node_id: LATENTS_TO_IMAGE,
+ field: 'latents',
+ },
+ },
+ {
+ source: {
+ node_id: NOISE,
+ field: 'noise',
+ },
+ destination: {
+ node_id: DENOISE_LATENTS,
+ field: 'noise',
+ },
+ },
+ ],
+ };
+
+ // add metadata accumulator, which is only mostly populated - some fields are added later
+ graph.nodes[METADATA_ACCUMULATOR] = {
+ id: METADATA_ACCUMULATOR,
+ type: 'metadata_accumulator',
+ generation_mode: 'txt2img',
+ cfg_scale,
+ height,
+ width,
+ positive_prompt: '', // set in addDynamicPromptsToGraph
+ negative_prompt: negativePrompt,
+ model,
+ seed: 0, // set in addDynamicPromptsToGraph
+ steps,
+ rand_device: use_cpu ? 'cpu' : 'cuda',
+ scheduler,
+ vae: undefined, // option; set in addVAEToGraph
+ controlnets: [], // populated in addControlNetToLinearGraph
+ loras: [], // populated in addLoRAsToGraph
+ clip_skip: clipSkip,
+ };
+
+ graph.edges.push({
+ source: {
+ node_id: METADATA_ACCUMULATOR,
+ field: 'metadata',
+ },
+ destination: {
+ node_id: LATENTS_TO_IMAGE,
+ field: 'metadata',
+ },
+ });
+
+ // Add Refiner if enabled
+ if (shouldUseSDXLRefiner) {
+ addSDXLRefinerToGraph(state, graph, DENOISE_LATENTS);
+ }
+
+ // add LoRA support
+ addSDXLLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
+
+ // optionally add custom VAE
+ addVAEToGraph(state, graph, modelLoaderNodeId);
+
+ // add dynamic prompts - also sets up core iteration and seed
+ addDynamicPromptsToGraph(state, graph);
+
+ // add controlnet, mutating `graph`
+ addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
+
+ // NSFW & watermark - must be last thing added to graph
+ if (state.system.shouldUseNSFWChecker) {
+ // must add before watermarker!
+ addNSFWCheckerToGraph(state, graph);
+ }
+
+ if (state.system.shouldUseWatermarker) {
+ // must add after nsfw checker!
+ addWatermarkerToGraph(state, graph);
+ }
+
+ return graph;
+};
diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx
index 0a18d4f556..05b5b6468a 100644
--- a/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx
+++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx
@@ -15,11 +15,11 @@ import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainM
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { forEach } from 'lodash-es';
+import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import {
useGetMainModelsQuery,
useGetOnnxModelsQuery,
} from 'services/api/endpoints/models';
-import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
const selector = createSelector(
@@ -52,10 +52,7 @@ const ParamMainModelSelect = () => {
const data: SelectItem[] = [];
forEach(mainModels.entities, (model, id) => {
- if (
- !model ||
- (activeTabName === 'unifiedCanvas' && model.base_model === 'sdxl')
- ) {
+ if (!model) {
return;
}
diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLUnifiedCanvasTabParameters.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLUnifiedCanvasTabParameters.tsx
new file mode 100644
index 0000000000..270e839894
--- /dev/null
+++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLUnifiedCanvasTabParameters.tsx
@@ -0,0 +1,29 @@
+import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
+import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
+import ParamAdvancedCollapse from 'features/parameters/components/Parameters/Advanced/ParamAdvancedCollapse';
+import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse';
+import ParamMaskAdjustmentCollapse from 'features/parameters/components/Parameters/Canvas/MaskAdjustment/ParamMaskAdjustmentCollapse';
+import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
+import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
+import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
+import UnifiedCanvasCoreParameters from 'features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters';
+import ParamSDXLPromptArea from './ParamSDXLPromptArea';
+import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
+
+export default function SDXLUnifiedCanvasTabParameters() {
+ return (
+ <>
+      <ParamSDXLPromptArea />
+      <ProcessButtons />
+      <UnifiedCanvasCoreParameters />
+      <ParamSDXLRefinerCollapse />
+      <ParamControlNetCollapse />
+      <ParamLoraCollapse />
+      <ParamDynamicPromptsCollapse />
+      <ParamNoiseCollapse />
+      <ParamMaskAdjustmentCollapse />
+      <ParamInfillAndScalingCollapse />
+      <ParamAdvancedCollapse />
+    </>
+ );
+}
diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasTab.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasTab.tsx
index 4c36c45e13..0a5b872e4b 100644
--- a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasTab.tsx
+++ b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasTab.tsx
@@ -1,14 +1,22 @@
import { Flex } from '@chakra-ui/react';
+import { RootState } from 'app/store/store';
+import { useAppSelector } from 'app/store/storeHooks';
+import SDXLUnifiedCanvasTabParameters from 'features/sdxl/components/SDXLUnifiedCanvasTabParameters';
import { memo } from 'react';
import ParametersPinnedWrapper from '../../ParametersPinnedWrapper';
import UnifiedCanvasContent from './UnifiedCanvasContent';
import UnifiedCanvasParameters from './UnifiedCanvasParameters';
const UnifiedCanvasTab = () => {
+ const model = useAppSelector((state: RootState) => state.generation.model);
return (
-      <UnifiedCanvasParameters />
+      {model && model.base_model === 'sdxl' ? (
+        <SDXLUnifiedCanvasTabParameters />
+      ) : (
+        <UnifiedCanvasParameters />
+      )}