diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/tabChanged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/tabChanged.ts index 6d3e599ae2..6791324fdd 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/tabChanged.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/tabChanged.ts @@ -12,7 +12,10 @@ export const addTabChangedListener = () => { if (activeTabName === 'unifiedCanvas') { const currentBaseModel = getState().generation.model?.base_model; - if (currentBaseModel && ['sd-1', 'sd-2'].includes(currentBaseModel)) { + if ( + currentBaseModel && + ['sd-1', 'sd-2', 'sdxl'].includes(currentBaseModel) + ) { // if we're already on a valid model, no change needed return; } @@ -36,7 +39,9 @@ export const addTabChangedListener = () => { const validCanvasModels = mainModelsAdapter .getSelectors() .selectAll(models) - .filter((model) => ['sd-1', 'sd-2'].includes(model.base_model)); + .filter((model) => + ['sd-1', 'sd-2', 'sxdl'].includes(model.base_model) + ); const firstValidCanvasModel = validCanvasModels[0]; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts index 8a7716071f..dd0a5e6619 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts @@ -3,6 +3,9 @@ import { NonNullableGraph } from 'features/nodes/types/types'; import { ImageDTO } from 'services/api/types'; import { buildCanvasImageToImageGraph } from './buildCanvasImageToImageGraph'; import { buildCanvasInpaintGraph } from './buildCanvasInpaintGraph'; +import { buildCanvasSDXLImageToImageGraph } from './buildCanvasSDXLImageToImageGraph'; +import { buildCanvasSDXLInpaintGraph } from './buildCanvasSDXLInpaintGraph'; +import { buildCanvasSDXLTextToImageGraph } from './buildCanvasSDXLTextToImageGraph'; import { buildCanvasTextToImageGraph } from './buildCanvasTextToImageGraph'; export const buildCanvasGraph = ( @@ -14,17 +17,43 @@ export const buildCanvasGraph = ( let graph: NonNullableGraph; if (generationMode === 'txt2img') { - graph = buildCanvasTextToImageGraph(state); + if ( + state.generation.model && + state.generation.model.base_model === 'sdxl' + ) { + graph = buildCanvasSDXLTextToImageGraph(state); + } else { + graph = buildCanvasTextToImageGraph(state); + } } else if (generationMode === 'img2img') { if (!canvasInitImage) { throw new Error('Missing canvas init image'); } - graph = buildCanvasImageToImageGraph(state, canvasInitImage); + if ( + state.generation.model && + state.generation.model.base_model === 'sdxl' + ) { + graph = buildCanvasSDXLImageToImageGraph(state, canvasInitImage); + } else { + graph = buildCanvasImageToImageGraph(state, canvasInitImage); + } } else { if (!canvasInitImage || !canvasMaskImage) { throw new Error('Missing canvas init and mask images'); } - graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage); + + if ( + state.generation.model && + state.generation.model.base_model === 'sdxl' + ) { + graph = buildCanvasSDXLInpaintGraph( + state, + canvasInitImage, + canvasMaskImage + ); + } else { + graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage); + } } return graph; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLImageToImageGraph.ts new file mode 100644 index 0000000000..b8322fd612 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLImageToImageGraph.ts @@ -0,0 +1,373 @@ +import { logger } from 'app/logging/logger'; +import { RootState } from 'app/store/store'; +import { NonNullableGraph } from 'features/nodes/types/types'; +import { initialGenerationState } from 'features/parameters/store/generationSlice'; +import { + ImageDTO, + ImageResizeInvocation, + ImageToLatentsInvocation, +} from 'services/api/types'; +import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; +import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; +import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; +import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; +import { addVAEToGraph } from './addVAEToGraph'; +import { addWatermarkerToGraph } from './addWatermarkerToGraph'; +import { + DENOISE_LATENTS, + IMAGE_TO_IMAGE_GRAPH, + IMAGE_TO_LATENTS, + LATENTS_TO_IMAGE, + METADATA_ACCUMULATOR, + NEGATIVE_CONDITIONING, + NOISE, + POSITIVE_CONDITIONING, + RESIZE, + SDXL_MODEL_LOADER, +} from './constants'; + +/** + * Builds the Canvas tab's Image to Image graph. + */ +export const buildCanvasSDXLImageToImageGraph = ( + state: RootState, + initialImage: ImageDTO +): NonNullableGraph => { + const log = logger('nodes'); + const { + positivePrompt, + negativePrompt, + model, + cfgScale: cfg_scale, + scheduler, + steps, + clipSkip, + shouldUseCpuNoise, + shouldUseNoiseSettings, + } = state.generation; + + const { + positiveStylePrompt, + negativeStylePrompt, + shouldConcatSDXLStylePrompt, + shouldUseSDXLRefiner, + refinerStart, + sdxlImg2ImgDenoisingStrength: strength, + } = state.sdxl; + + // The bounding box determines width and height, not the width and height params + const { width, height } = state.canvas.boundingBoxDimensions; + + const { shouldAutoSave } = state.canvas; + + if (!model) { + log.error('No model found in state'); + throw new Error('No model found in state'); + } + + const use_cpu = shouldUseNoiseSettings + ? shouldUseCpuNoise + : initialGenerationState.shouldUseCpuNoise; + + /** + * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the + * full graph here as a template. Then use the parameters from app state and set friendlier node + * ids. + * + * The only thing we need extra logic for is handling randomized seed, control net, and for img2img, + * the `fit` param. These are added to the graph at the end. + */ + + // copy-pasted graph from node editor, filled in with state values & friendly node ids + const graph: NonNullableGraph = { + id: IMAGE_TO_IMAGE_GRAPH, + nodes: { + [SDXL_MODEL_LOADER]: { + type: 'sdxl_model_loader', + id: SDXL_MODEL_LOADER, + model, + }, + [POSITIVE_CONDITIONING]: { + type: 'sdxl_compel_prompt', + id: POSITIVE_CONDITIONING, + prompt: positivePrompt, + style: shouldConcatSDXLStylePrompt + ? `${positivePrompt} ${positiveStylePrompt}` + : positiveStylePrompt, + }, + [NEGATIVE_CONDITIONING]: { + type: 'sdxl_compel_prompt', + id: NEGATIVE_CONDITIONING, + prompt: negativePrompt, + style: shouldConcatSDXLStylePrompt + ? `${negativePrompt} ${negativeStylePrompt}` + : negativeStylePrompt, + }, + [NOISE]: { + type: 'noise', + id: NOISE, + is_intermediate: true, + use_cpu, + }, + [DENOISE_LATENTS]: { + type: 'denoise_latents', + id: DENOISE_LATENTS, + is_intermediate: true, + cfg_scale, + scheduler, + steps, + denoising_start: shouldUseSDXLRefiner + ? Math.min(refinerStart, 1 - strength) + : 1 - strength, + denoising_end: shouldUseSDXLRefiner ? refinerStart : 1, + }, + [IMAGE_TO_LATENTS]: { + type: 'i2l', + id: IMAGE_TO_LATENTS, + is_intermediate: true, + // must be set manually later, bc `fit` parameter may require a resize node inserted + // image: { + // image_name: initialImage.image_name, + // }, + }, + [LATENTS_TO_IMAGE]: { + type: 'l2i', + id: LATENTS_TO_IMAGE, + is_intermediate: !shouldAutoSave, + }, + }, + edges: [ + { + source: { + node_id: DENOISE_LATENTS, + field: 'latents', + }, + destination: { + node_id: LATENTS_TO_IMAGE, + field: 'latents', + }, + }, + { + source: { + node_id: IMAGE_TO_LATENTS, + field: 'latents', + }, + destination: { + node_id: DENOISE_LATENTS, + field: 'latents', + }, + }, + { + source: { + node_id: NOISE, + field: 'noise', + }, + destination: { + node_id: DENOISE_LATENTS, + field: 'noise', + }, + }, + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'unet', + }, + destination: { + node_id: DENOISE_LATENTS, + field: 'unet', + }, + }, + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'clip', + }, + destination: { + node_id: POSITIVE_CONDITIONING, + field: 'clip', + }, + }, + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'clip2', + }, + destination: { + node_id: POSITIVE_CONDITIONING, + field: 'clip2', + }, + }, + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'clip', + }, + destination: { + node_id: NEGATIVE_CONDITIONING, + field: 'clip', + }, + }, + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'clip2', + }, + destination: { + node_id: NEGATIVE_CONDITIONING, + field: 'clip2', + }, + }, + { + source: { + node_id: NEGATIVE_CONDITIONING, + field: 'conditioning', + }, + destination: { + node_id: DENOISE_LATENTS, + field: 'negative_conditioning', + }, + }, + { + source: { + node_id: POSITIVE_CONDITIONING, + field: 'conditioning', + }, + destination: { + node_id: DENOISE_LATENTS, + field: 'positive_conditioning', + }, + }, + ], + }; + + // handle `fit` + if (initialImage.width !== width || initialImage.height !== height) { + // The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS` + + // Create a resize node, explicitly setting its image + const resizeNode: ImageResizeInvocation = { + id: RESIZE, + type: 'img_resize', + image: { + image_name: initialImage.image_name, + }, + is_intermediate: true, + width, + height, + }; + + graph.nodes[RESIZE] = resizeNode; + + // The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS` + graph.edges.push({ + source: { node_id: RESIZE, field: 'image' }, + destination: { + node_id: IMAGE_TO_LATENTS, + field: 'image', + }, + }); + + // The `RESIZE` node also passes its width and height to `NOISE` + graph.edges.push({ + source: { node_id: RESIZE, field: 'width' }, + destination: { + node_id: NOISE, + field: 'width', + }, + }); + + graph.edges.push({ + source: { node_id: RESIZE, field: 'height' }, + destination: { + node_id: NOISE, + field: 'height', + }, + }); + } else { + // We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly + (graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = { + image_name: initialImage.image_name, + }; + + // Pass the image's dimensions to the `NOISE` node + graph.edges.push({ + source: { node_id: IMAGE_TO_LATENTS, field: 'width' }, + destination: { + node_id: NOISE, + field: 'width', + }, + }); + graph.edges.push({ + source: { node_id: IMAGE_TO_LATENTS, field: 'height' }, + destination: { + node_id: NOISE, + field: 'height', + }, + }); + } + + // add metadata accumulator, which is only mostly populated - some fields are added later + graph.nodes[METADATA_ACCUMULATOR] = { + id: METADATA_ACCUMULATOR, + type: 'metadata_accumulator', + generation_mode: 'img2img', + cfg_scale, + height, + width, + positive_prompt: '', // set in addDynamicPromptsToGraph + negative_prompt: negativePrompt, + model, + seed: 0, // set in addDynamicPromptsToGraph + steps, + rand_device: use_cpu ? 'cpu' : 'cuda', + scheduler, + vae: undefined, // option; set in addVAEToGraph + controlnets: [], // populated in addControlNetToLinearGraph + loras: [], // populated in addLoRAsToGraph + clip_skip: clipSkip, + strength, + init_image: initialImage.image_name, + }; + + graph.edges.push({ + source: { + node_id: METADATA_ACCUMULATOR, + field: 'metadata', + }, + destination: { + node_id: LATENTS_TO_IMAGE, + field: 'metadata', + }, + }); + + // add LoRA support + addLoRAsToGraph(state, graph, DENOISE_LATENTS); + + // Add Refiner if enabled + if (shouldUseSDXLRefiner) { + addSDXLRefinerToGraph(state, graph, DENOISE_LATENTS); + } + + // optionally add custom VAE + addVAEToGraph(state, graph, SDXL_MODEL_LOADER); + + // add dynamic prompts - also sets up core iteration and seed + addDynamicPromptsToGraph(state, graph); + + // add controlnet, mutating `graph` + addControlNetToLinearGraph(state, graph, DENOISE_LATENTS); + + // NSFW & watermark - must be last thing added to graph + if (state.system.shouldUseNSFWChecker) { + // must add before watermarker! + addNSFWCheckerToGraph(state, graph); + } + + if (state.system.shouldUseWatermarker) { + // must add after nsfw checker! + addWatermarkerToGraph(state, graph); + } + + return graph; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts new file mode 100644 index 0000000000..04cc120cbe --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts @@ -0,0 +1,480 @@ +import { logger } from 'app/logging/logger'; +import { RootState } from 'app/store/store'; +import { NonNullableGraph } from 'features/nodes/types/types'; +import { + ImageDTO, + InfillPatchmatchInvocation, + InfillTileInvocation, + RandomIntInvocation, + RangeOfSizeInvocation, +} from 'services/api/types'; +import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; +import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; +import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; +import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; +import { addVAEToGraph } from './addVAEToGraph'; +import { addWatermarkerToGraph } from './addWatermarkerToGraph'; +import { + COLOR_CORRECT, + INPAINT, + INPAINT_FINAL_IMAGE, + INPAINT_GRAPH, + INPAINT_IMAGE, + INPAINT_INFILL, + ITERATE, + LATENTS_TO_IMAGE, + MASK_BLUR, + MASK_COMBINE, + MASK_FROM_ALPHA, + NEGATIVE_CONDITIONING, + NOISE, + POSITIVE_CONDITIONING, + RANDOM_INT, + RANGE_OF_SIZE, + SDXL_MODEL_LOADER, +} from './constants'; + +/** + * Builds the Canvas tab's Inpaint graph. + */ +export const buildCanvasSDXLInpaintGraph = ( + state: RootState, + canvasInitImage: ImageDTO, + canvasMaskImage: ImageDTO +): NonNullableGraph => { + const log = logger('nodes'); + const { + positivePrompt, + negativePrompt, + model, + cfgScale: cfg_scale, + scheduler, + steps, + img2imgStrength: strength, + shouldFitToWidthHeight, + iterations, + seed, + shouldRandomizeSeed, + vaePrecision, + shouldUseNoiseSettings, + shouldUseCpuNoise, + maskBlur, + maskBlurMethod, + tileSize, + infillMethod, + } = state.generation; + + const { + positiveStylePrompt, + negativeStylePrompt, + shouldConcatSDXLStylePrompt, + shouldUseSDXLRefiner, + refinerStart, + } = state.sdxl; + + if (!model) { + log.error('No model found in state'); + throw new Error('No model found in state'); + } + + // The bounding box determines width and height, not the width and height params + const { width, height } = state.canvas.boundingBoxDimensions; + + // We may need to set the inpaint width and height to scale the image + const { + scaledBoundingBoxDimensions, + boundingBoxScaleMethod, + shouldAutoSave, + } = state.canvas; + + const use_cpu = shouldUseNoiseSettings + ? shouldUseCpuNoise + : shouldUseCpuNoise; + + let infillNode: InfillTileInvocation | InfillPatchmatchInvocation = { + type: 'infill_tile', + id: INPAINT_INFILL, + is_intermediate: true, + image: canvasInitImage, + tile_size: tileSize, + }; + + if (infillMethod === 'patchmatch') { + infillNode = { + type: 'infill_patchmatch', + id: INPAINT_INFILL, + is_intermediate: true, + image: canvasInitImage, + }; + } + + const graph: NonNullableGraph = { + id: INPAINT_GRAPH, + nodes: { + [INPAINT]: { + type: 'denoise_latents', + id: INPAINT, + is_intermediate: true, + steps: steps, + cfg_scale: cfg_scale, + scheduler: scheduler, + denoising_start: 1 - strength, + denoising_end: shouldUseSDXLRefiner ? refinerStart : 1, + }, + [infillNode.id]: infillNode, + [MASK_FROM_ALPHA]: { + type: 'tomask', + id: MASK_FROM_ALPHA, + is_intermediate: true, + image: canvasInitImage, + }, + [MASK_COMBINE]: { + type: 'mask_combine', + id: MASK_COMBINE, + is_intermediate: true, + mask2: canvasMaskImage, + }, + [MASK_BLUR]: { + type: 'img_blur', + id: MASK_BLUR, + is_intermediate: true, + radius: maskBlur, + blur_type: maskBlurMethod, + }, + [INPAINT_IMAGE]: { + type: 'i2l', + id: INPAINT_IMAGE, + is_intermediate: true, + fp32: vaePrecision === 'fp32' ? true : false, + }, + [NOISE]: { + type: 'noise', + id: NOISE, + width, + height, + use_cpu, + is_intermediate: true, + }, + [POSITIVE_CONDITIONING]: { + type: 'sdxl_compel_prompt', + id: POSITIVE_CONDITIONING, + prompt: positivePrompt, + style: shouldConcatSDXLStylePrompt + ? `${positivePrompt} ${positiveStylePrompt}` + : positiveStylePrompt, + }, + [NEGATIVE_CONDITIONING]: { + type: 'sdxl_compel_prompt', + id: NEGATIVE_CONDITIONING, + prompt: negativePrompt, + style: shouldConcatSDXLStylePrompt + ? `${negativePrompt} ${negativeStylePrompt}` + : negativeStylePrompt, + }, + [SDXL_MODEL_LOADER]: { + type: 'sdxl_model_loader', + id: SDXL_MODEL_LOADER, + model, + }, + [LATENTS_TO_IMAGE]: { + type: 'l2i', + id: LATENTS_TO_IMAGE, + is_intermediate: true, + fp32: vaePrecision === 'fp32' ? true : false, + }, + [COLOR_CORRECT]: { + type: 'color_correct', + id: COLOR_CORRECT, + is_intermediate: true, + }, + [INPAINT_FINAL_IMAGE]: { + type: 'img_paste', + id: INPAINT_FINAL_IMAGE, + is_intermediate: true, + }, + [RANGE_OF_SIZE]: { + type: 'range_of_size', + id: RANGE_OF_SIZE, + is_intermediate: true, + // seed - must be connected manually + // start: 0, + size: iterations, + step: 1, + }, + [ITERATE]: { + type: 'iterate', + id: ITERATE, + is_intermediate: true, + }, + }, + edges: [ + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'unet', + }, + destination: { + node_id: INPAINT, + field: 'unet', + }, + }, + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'clip', + }, + destination: { + node_id: POSITIVE_CONDITIONING, + field: 'clip', + }, + }, + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'clip2', + }, + destination: { + node_id: POSITIVE_CONDITIONING, + field: 'clip2', + }, + }, + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'clip', + }, + destination: { + node_id: NEGATIVE_CONDITIONING, + field: 'clip', + }, + }, + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'clip2', + }, + destination: { + node_id: NEGATIVE_CONDITIONING, + field: 'clip2', + }, + }, + { + source: { + node_id: NEGATIVE_CONDITIONING, + field: 'conditioning', + }, + destination: { + node_id: INPAINT, + field: 'negative_conditioning', + }, + }, + { + source: { + node_id: POSITIVE_CONDITIONING, + field: 'conditioning', + }, + destination: { + node_id: INPAINT, + field: 'positive_conditioning', + }, + }, + { + source: { + node_id: NOISE, + field: 'noise', + }, + destination: { + node_id: INPAINT, + field: 'noise', + }, + }, + { + source: { + node_id: INPAINT_INFILL, + field: 'image', + }, + destination: { + node_id: INPAINT_IMAGE, + field: 'image', + }, + }, + { + source: { + node_id: INPAINT_IMAGE, + field: 'latents', + }, + destination: { + node_id: INPAINT, + field: 'latents', + }, + }, + { + source: { + node_id: MASK_FROM_ALPHA, + field: 'mask', + }, + destination: { + node_id: MASK_COMBINE, + field: 'mask1', + }, + }, + { + source: { + node_id: MASK_COMBINE, + field: 'image', + }, + destination: { + node_id: MASK_BLUR, + field: 'image', + }, + }, + { + source: { + node_id: MASK_BLUR, + field: 'image', + }, + destination: { + node_id: INPAINT, + field: 'mask', + }, + }, + { + source: { + node_id: RANGE_OF_SIZE, + field: 'collection', + }, + destination: { + node_id: ITERATE, + field: 'collection', + }, + }, + { + source: { + node_id: ITERATE, + field: 'item', + }, + destination: { + node_id: NOISE, + field: 'seed', + }, + }, + { + source: { + node_id: INPAINT, + field: 'latents', + }, + destination: { + node_id: LATENTS_TO_IMAGE, + field: 'latents', + }, + }, + { + source: { + node_id: INPAINT_INFILL, + field: 'image', + }, + destination: { + node_id: COLOR_CORRECT, + field: 'reference', + }, + }, + { + source: { + node_id: MASK_BLUR, + field: 'image', + }, + destination: { + node_id: COLOR_CORRECT, + field: 'mask', + }, + }, + { + source: { + node_id: LATENTS_TO_IMAGE, + field: 'image', + }, + destination: { + node_id: COLOR_CORRECT, + field: 'image', + }, + }, + { + source: { + node_id: INPAINT_INFILL, + field: 'image', + }, + destination: { + node_id: INPAINT_FINAL_IMAGE, + field: 'base_image', + }, + }, + { + source: { + node_id: MASK_BLUR, + field: 'image', + }, + destination: { + node_id: INPAINT_FINAL_IMAGE, + field: 'mask', + }, + }, + { + source: { + node_id: COLOR_CORRECT, + field: 'image', + }, + destination: { + node_id: INPAINT_FINAL_IMAGE, + field: 'image', + }, + }, + ], + }; + + // Add Refiner if enabled + if (shouldUseSDXLRefiner) { + addSDXLRefinerToGraph(state, graph, INPAINT); + } + + // Add VAE + addVAEToGraph(state, graph, SDXL_MODEL_LOADER); + + // handle seed + if (shouldRandomizeSeed) { + // Random int node to generate the starting seed + const randomIntNode: RandomIntInvocation = { + id: RANDOM_INT, + type: 'rand_int', + }; + + graph.nodes[RANDOM_INT] = randomIntNode; + + // Connect random int to the start of the range of size so the range starts on the random first seed + graph.edges.push({ + source: { node_id: RANDOM_INT, field: 'a' }, + destination: { node_id: RANGE_OF_SIZE, field: 'start' }, + }); + } else { + // User specified seed, so set the start of the range of size to the seed + (graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed; + } + + // add LoRA support + addSDXLLoRAsToGraph(state, graph, INPAINT, SDXL_MODEL_LOADER); + + // add controlnet, mutating `graph` + addControlNetToLinearGraph(state, graph, INPAINT); + + // NSFW & watermark - must be last thing added to graph + if (state.system.shouldUseNSFWChecker) { + // must add before watermarker! + addNSFWCheckerToGraph(state, graph, INPAINT); + } + + if (state.system.shouldUseWatermarker) { + // must add after nsfw checker! + addWatermarkerToGraph(state, graph, INPAINT); + } + + return graph; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLTextToImageGraph.ts new file mode 100644 index 0000000000..ed0fb74165 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLTextToImageGraph.ts @@ -0,0 +1,304 @@ +import { logger } from 'app/logging/logger'; +import { RootState } from 'app/store/store'; +import { NonNullableGraph } from 'features/nodes/types/types'; +import { initialGenerationState } from 'features/parameters/store/generationSlice'; +import { + DenoiseLatentsInvocation, + ONNXTextToLatentsInvocation, +} from 'services/api/types'; +import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; +import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; +import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; +import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; +import { addVAEToGraph } from './addVAEToGraph'; +import { addWatermarkerToGraph } from './addWatermarkerToGraph'; +import { + DENOISE_LATENTS, + LATENTS_TO_IMAGE, + METADATA_ACCUMULATOR, + NEGATIVE_CONDITIONING, + NOISE, + ONNX_MODEL_LOADER, + POSITIVE_CONDITIONING, + SDXL_MODEL_LOADER, + TEXT_TO_IMAGE_GRAPH, +} from './constants'; + +/** + * Builds the Canvas tab's Text to Image graph. + */ +export const buildCanvasSDXLTextToImageGraph = ( + state: RootState +): NonNullableGraph => { + const log = logger('nodes'); + const { + positivePrompt, + negativePrompt, + model, + cfgScale: cfg_scale, + scheduler, + steps, + clipSkip, + shouldUseCpuNoise, + shouldUseNoiseSettings, + } = state.generation; + + // The bounding box determines width and height, not the width and height params + const { width, height } = state.canvas.boundingBoxDimensions; + + const { shouldAutoSave } = state.canvas; + + const { + positiveStylePrompt, + negativeStylePrompt, + shouldConcatSDXLStylePrompt, + shouldUseSDXLRefiner, + refinerStart, + } = state.sdxl; + + if (!model) { + log.error('No model found in state'); + throw new Error('No model found in state'); + } + + const use_cpu = shouldUseNoiseSettings + ? shouldUseCpuNoise + : initialGenerationState.shouldUseCpuNoise; + const isUsingOnnxModel = model.model_type === 'onnx'; + const modelLoaderNodeId = isUsingOnnxModel + ? ONNX_MODEL_LOADER + : SDXL_MODEL_LOADER; + const modelLoaderNodeType = isUsingOnnxModel + ? 'onnx_model_loader' + : 'sdxl_model_loader'; + const t2lNode: DenoiseLatentsInvocation | ONNXTextToLatentsInvocation = + isUsingOnnxModel + ? { + type: 't2l_onnx', + id: DENOISE_LATENTS, + is_intermediate: true, + cfg_scale, + scheduler, + steps, + } + : { + type: 'denoise_latents', + id: DENOISE_LATENTS, + is_intermediate: true, + cfg_scale, + scheduler, + steps, + denoising_start: 0, + denoising_end: shouldUseSDXLRefiner ? refinerStart : 1, + }; + /** + * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the + * full graph here as a template. Then use the parameters from app state and set friendlier node + * ids. + * + * The only thing we need extra logic for is handling randomized seed, control net, and for img2img, + * the `fit` param. These are added to the graph at the end. + */ + + // copy-pasted graph from node editor, filled in with state values & friendly node ids + // TODO: Actually create the graph correctly for ONNX + const graph: NonNullableGraph = { + id: TEXT_TO_IMAGE_GRAPH, + nodes: { + [POSITIVE_CONDITIONING]: { + type: isUsingOnnxModel ? 'prompt_onnx' : 'sdxl_compel_prompt', + id: POSITIVE_CONDITIONING, + is_intermediate: true, + prompt: positivePrompt, + style: shouldConcatSDXLStylePrompt + ? `${positivePrompt} ${positiveStylePrompt}` + : positiveStylePrompt, + }, + [NEGATIVE_CONDITIONING]: { + type: isUsingOnnxModel ? 'prompt_onnx' : 'sdxl_compel_prompt', + id: NEGATIVE_CONDITIONING, + is_intermediate: true, + prompt: negativePrompt, + style: shouldConcatSDXLStylePrompt + ? `${negativePrompt} ${negativeStylePrompt}` + : negativeStylePrompt, + }, + [NOISE]: { + type: 'noise', + id: NOISE, + is_intermediate: true, + width, + height, + use_cpu, + }, + [t2lNode.id]: t2lNode, + [modelLoaderNodeId]: { + type: modelLoaderNodeType, + id: modelLoaderNodeId, + is_intermediate: true, + model, + }, + + [LATENTS_TO_IMAGE]: { + type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i', + id: LATENTS_TO_IMAGE, + is_intermediate: !shouldAutoSave, + }, + }, + edges: [ + { + source: { + node_id: modelLoaderNodeId, + field: 'unet', + }, + destination: { + node_id: DENOISE_LATENTS, + field: 'unet', + }, + }, + { + source: { + node_id: modelLoaderNodeId, + field: 'clip', + }, + destination: { + node_id: POSITIVE_CONDITIONING, + field: 'clip', + }, + }, + { + source: { + node_id: modelLoaderNodeId, + field: 'clip2', + }, + destination: { + node_id: POSITIVE_CONDITIONING, + field: 'clip2', + }, + }, + { + source: { + node_id: modelLoaderNodeId, + field: 'clip', + }, + destination: { + node_id: NEGATIVE_CONDITIONING, + field: 'clip', + }, + }, + { + source: { + node_id: modelLoaderNodeId, + field: 'clip2', + }, + destination: { + node_id: NEGATIVE_CONDITIONING, + field: 'clip2', + }, + }, + { + source: { + node_id: NEGATIVE_CONDITIONING, + field: 'conditioning', + }, + destination: { + node_id: DENOISE_LATENTS, + field: 'negative_conditioning', + }, + }, + { + source: { + node_id: POSITIVE_CONDITIONING, + field: 'conditioning', + }, + destination: { + node_id: DENOISE_LATENTS, + field: 'positive_conditioning', + }, + }, + { + source: { + node_id: DENOISE_LATENTS, + field: 'latents', + }, + destination: { + node_id: LATENTS_TO_IMAGE, + field: 'latents', + }, + }, + { + source: { + node_id: NOISE, + field: 'noise', + }, + destination: { + node_id: DENOISE_LATENTS, + field: 'noise', + }, + }, + ], + }; + + // add metadata accumulator, which is only mostly populated - some fields are added later + graph.nodes[METADATA_ACCUMULATOR] = { + id: METADATA_ACCUMULATOR, + type: 'metadata_accumulator', + generation_mode: 'txt2img', + cfg_scale, + height, + width, + positive_prompt: '', // set in addDynamicPromptsToGraph + negative_prompt: negativePrompt, + model, + seed: 0, // set in addDynamicPromptsToGraph + steps, + rand_device: use_cpu ? 'cpu' : 'cuda', + scheduler, + vae: undefined, // option; set in addVAEToGraph + controlnets: [], // populated in addControlNetToLinearGraph + loras: [], // populated in addLoRAsToGraph + clip_skip: clipSkip, + }; + + graph.edges.push({ + source: { + node_id: METADATA_ACCUMULATOR, + field: 'metadata', + }, + destination: { + node_id: LATENTS_TO_IMAGE, + field: 'metadata', + }, + }); + + // Add Refiner if enabled + if (shouldUseSDXLRefiner) { + addSDXLRefinerToGraph(state, graph, DENOISE_LATENTS); + } + + // add LoRA support + addSDXLLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId); + + // optionally add custom VAE + addVAEToGraph(state, graph, modelLoaderNodeId); + + // add dynamic prompts - also sets up core iteration and seed + addDynamicPromptsToGraph(state, graph); + + // add controlnet, mutating `graph` + addControlNetToLinearGraph(state, graph, DENOISE_LATENTS); + + // NSFW & watermark - must be last thing added to graph + if (state.system.shouldUseNSFWChecker) { + // must add before watermarker! + addNSFWCheckerToGraph(state, graph); + } + + if (state.system.shouldUseWatermarker) { + // must add after nsfw checker! + addWatermarkerToGraph(state, graph); + } + + return graph; +}; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx index 0a18d4f556..05b5b6468a 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx @@ -15,11 +15,11 @@ import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainM import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { forEach } from 'lodash-es'; +import { NON_REFINER_BASE_MODELS } from 'services/api/constants'; import { useGetMainModelsQuery, useGetOnnxModelsQuery, } from 'services/api/endpoints/models'; -import { NON_REFINER_BASE_MODELS } from 'services/api/constants'; import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus'; const selector = createSelector( @@ -52,10 +52,7 @@ const ParamMainModelSelect = () => { const data: SelectItem[] = []; forEach(mainModels.entities, (model, id) => { - if ( - !model || - (activeTabName === 'unifiedCanvas' && model.base_model === 'sdxl') - ) { + if (!model) { return; } diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLUnifiedCanvasTabParameters.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLUnifiedCanvasTabParameters.tsx new file mode 100644 index 0000000000..270e839894 --- /dev/null +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLUnifiedCanvasTabParameters.tsx @@ -0,0 +1,29 @@ +import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; +import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse'; +import ParamAdvancedCollapse from 'features/parameters/components/Parameters/Advanced/ParamAdvancedCollapse'; +import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse'; +import ParamMaskAdjustmentCollapse from 'features/parameters/components/Parameters/Canvas/MaskAdjustment/ParamMaskAdjustmentCollapse'; +import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; +import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; +import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; +import UnifiedCanvasCoreParameters from 'features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters'; +import ParamSDXLPromptArea from './ParamSDXLPromptArea'; +import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse'; + +export default function SDXLUnifiedCanvasTabParameters() { + return ( + <> + + + + + + + + + + + + + ); +} diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasTab.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasTab.tsx index 4c36c45e13..0a5b872e4b 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasTab.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasTab.tsx @@ -1,14 +1,22 @@ import { Flex } from '@chakra-ui/react'; +import { RootState } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import SDXLUnifiedCanvasTabParameters from 'features/sdxl/components/SDXLUnifiedCanvasTabParameters'; import { memo } from 'react'; import ParametersPinnedWrapper from '../../ParametersPinnedWrapper'; import UnifiedCanvasContent from './UnifiedCanvasContent'; import UnifiedCanvasParameters from './UnifiedCanvasParameters'; const UnifiedCanvasTab = () => { + const model = useAppSelector((state: RootState) => state.generation.model); return ( - + {model && model.base_model === 'sdxl' ? ( + + ) : ( + + )}