diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildMultidiffusionUpscaleGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildMultidiffusionUpscaleGraph.ts index c18d31fd8d..d80b20c770 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildMultidiffusionUpscaleGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildMultidiffusionUpscaleGraph.ts @@ -1,11 +1,12 @@ import { Graph, GraphType } from 'features/nodes/util/graph/generation/Graph'; import { RootState } from '../../../../app/store/store'; import { assert } from 'tsafe'; -import { ControlNetModelConfig, Invocation, NonNullableGraph } from '../../../../services/api/types'; -import { ESRGAN, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING } from './constants'; +import { CLIP_SKIP, CONTROL_NET_COLLECT, ESRGAN, IMAGE_TO_LATENTS, LATENTS_TO_IMAGE, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, NOISE, POSITIVE_CONDITIONING, RESIZE, SDXL_MODEL_LOADER, TILED_MULTI_DIFFUSION_DENOISE_LATENTS, UNSHARP_MASK, VAE_LOADER } from './constants'; import { isParamESRGANModelName } from '../../../parameters/store/postprocessingSlice'; -import { ControlNetConfig } from '../../../controlAdapters/store/types'; -import { MODEL_TYPES } from '../../types/constants'; +import { getSDXLStylePrompts } from './graphBuilderUtils'; +import { addLoRAs } from './generation/addLoRAs'; +import { addSDXLLoRas } from './generation/addSDXLLoRAs'; +import { modelsApi } from '../../../../services/api/endpoints/models'; export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promise => { @@ -24,63 +25,75 @@ export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promis assert(model, 'No model found in state'); assert(upscaleModel, 'No upscale model found in state'); assert(upscaleInitialImage, 'No initial image found in state'); - - if (!isParamESRGANModelName(upscaleModel.name)) { - throw new Error() - } + assert(isParamESRGANModelName(upscaleModel.name), "") const g = new Graph() - const unsharp_mask_1 = g.addNode({ - id: 'unsharp_mask_1', + const unsharpMaskNode1 = g.addNode({ + id: `${UNSHARP_MASK}_1`, type: 'unsharp_mask', image: upscaleInitialImage, radius: 2, strength: ((sharpness + 10) * 3.75) + 25 }) - const esrgan = g.addNode({ + const upscaleNode = g.addNode({ id: ESRGAN, type: 'esrgan', model_name: upscaleModel.name, tile_size: 500 }) - g.addEdge(unsharp_mask_1, 'image', esrgan, 'image') + g.addEdge(unsharpMaskNode1, 'image', upscaleNode, 'image') - const unsharp_mask_2 = g.addNode({ - id: 'unsharp_mask_2', + const unsharpMaskNode2 = g.addNode({ + id: `${UNSHARP_MASK}_2`, type: 'unsharp_mask', radius: 2, strength: 50 }) - g.addEdge(esrgan, 'image', unsharp_mask_2, 'image',) + g.addEdge(upscaleNode, 'image', unsharpMaskNode2, 'image',) - const SCALE = 2 + const SCALE = 4 const resizeNode = g.addNode({ - id: 'img_resize', + id: RESIZE, type: 'img_resize', width: upscaleInitialImage.width * SCALE, // TODO: handle floats height: upscaleInitialImage.height * SCALE, // TODO: handle floats - resample_mode: "lanczos" + resample_mode: "lanczos", + is_intermediate: false }) - g.addEdge(unsharp_mask_2, 'image', resizeNode, "image") + g.addEdge(unsharpMaskNode2, 'image', resizeNode, "image") + const noiseNode = g.addNode({ + id: NOISE, + type: "noise", + seed, + }) + g.addEdge(resizeNode, 'width', noiseNode, "width") + g.addEdge(resizeNode, 'height', noiseNode, "height") - const sharpnessNode: Invocation<'unsharp_mask'> = { //before and after esrgan - id: 'unsharp_mask', - type: 'unsharp_mask', - image: upscaleInitialImage, - radius: 2, - strength: ((sharpness + 10) * 3.75) + 25 - }; + const i2lNode = g.addNode({ + id: IMAGE_TO_LATENTS, + type: "i2l", + is_intermediate: false, + fp32: vaePrecision === "fp32" + }) - const creativityNode: Invocation<'tiled_multi_diffusion_denoise_latents'> = { //before and after esrgan - id: 'tiled_multi_diffusion_denoise_latents', + g.addEdge(resizeNode, 'image', i2lNode, "image") + + const l2iNode = g.addNode({ + type: "l2i", + id: LATENTS_TO_IMAGE, + fp32: vaePrecision === "fp32" + }) + + const tiledMultidiffusionNode = g.addNode({ + id: TILED_MULTI_DIFFUSION_DENOISE_LATENTS, type: 'tiled_multi_diffusion_denoise_latents', tile_height: 1024, tile_width: 1024, @@ -90,45 +103,124 @@ export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promis scheduler, denoising_start: (((creativity * -1) + 10) * 4.99) / 100, denoising_end: 1 - }; + }); - const controlnetModel = { - key: "placeholder", - hash: "placeholder", + const clipSkipNode = g.addNode({ + type: 'clip_skip', + id: CLIP_SKIP, + }); + + + let posCondNode, negCondNode, modelNode; + + if (model.base === "sdxl") { + const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state); + + posCondNode = g.addNode({ + type: 'sdxl_compel_prompt', + id: POSITIVE_CONDITIONING, + prompt: positivePrompt, + style: positiveStylePrompt + }); + negCondNode = g.addNode({ + type: 'sdxl_compel_prompt', + id: NEGATIVE_CONDITIONING, + prompt: negativePrompt, + style: negativeStylePrompt + }); + modelNode = g.addNode({ + type: 'sdxl_model_loader', + id: SDXL_MODEL_LOADER, + model, + }); + addSDXLLoRas(state, g, tiledMultidiffusionNode, modelNode, null, posCondNode, negCondNode); + } else { + posCondNode = g.addNode({ + type: 'compel', + id: POSITIVE_CONDITIONING, + prompt: positivePrompt, + }); + negCondNode = g.addNode({ + type: 'compel', + id: NEGATIVE_CONDITIONING, + prompt: negativePrompt, + }); + modelNode = g.addNode({ + type: 'main_model_loader', + id: MAIN_MODEL_LOADER, + model, + }); + addLoRAs(state, g, tiledMultidiffusionNode, modelNode, null, clipSkipNode, posCondNode, negCondNode); + } + + g.addEdge(modelNode, 'clip', clipSkipNode, 'clip'); + g.addEdge(clipSkipNode, 'clip', posCondNode, 'clip'); + g.addEdge(clipSkipNode, 'clip', negCondNode, 'clip'); + + let vaeNode; + if (vae) { + vaeNode = g.addNode({ + id: VAE_LOADER, + type: "vae_loader", + vae_model: vae + }) + } + + g.addEdge(vaeNode || modelNode, "vae", i2lNode, "vae") + g.addEdge(vaeNode || modelNode, "vae", l2iNode, "vae") + + + g.addEdge(noiseNode, "noise", tiledMultidiffusionNode, "noise") + g.addEdge(i2lNode, "latents", tiledMultidiffusionNode, "latents") + g.addEdge(posCondNode, 'conditioning', tiledMultidiffusionNode, 'positive_conditioning'); + g.addEdge(negCondNode, 'conditioning', tiledMultidiffusionNode, 'negative_conditioning'); + g.addEdge(modelNode, "unet", tiledMultidiffusionNode, "unet") + g.addEdge(tiledMultidiffusionNode, "latents", l2iNode, "latents") + + + const controlnetTileModel = { // TODO: figure out how to handle this, can't assume name is `tile` or that they have it installed + key: "", + hash: "", type: "controlnet" as any, name: "tile", base: model.base } - const controlnet: Invocation<"controlnet"> = { - id: "controlnet", + const controlnetNode1 = g.addNode({ + id: 'controlnet_1', type: "controlnet", - control_model: controlnetModel, + control_model: controlnetTileModel, control_mode: "balanced", resize_mode: "just_resize", - control_weight: ((((structure + 10) * 0.025) + 0.3) * 0.013) + 0.35 - } + control_weight: ((((structure + 10) * 0.025) + 0.3) * 0.013) + 0.35, + begin_step_percent: 0, + end_step_percent: ((structure + 10) * 0.025) + 0.3 + }) + g.addEdge(resizeNode, "image", controlnetNode1, "image") - const noiseNode: Invocation<'noise'> = { - id: "noise", - type: "noise", - seed, - // width: resized output width - // height: resized output height - } + const controlnetNode2 = g.addNode({ + id: "controlnet_2", + type: "controlnet", + control_model: controlnetTileModel, + control_mode: "balanced", + resize_mode: "just_resize", + control_weight: (((structure + 10) * 0.025) + 0.3) * 0.013, + begin_step_percent: ((structure + 10) * 0.025) + 0.3, + end_step_percent: 0.8 + }) - const posPrompt: Invocation<"compel"> = { - type: 'compel', - id: POSITIVE_CONDITIONING, - prompt: positivePrompt, - } + g.addEdge(resizeNode, "image", controlnetNode2, "image") + + const collectNode = g.addNode({ + id: CONTROL_NET_COLLECT, + type: "collect", + }) + g.addEdge(controlnetNode1, "control", collectNode, "item") + g.addEdge(controlnetNode2, "control", collectNode, "item") + + g.addEdge(collectNode, "collection", tiledMultidiffusionNode, "control") - const negPrompt: Invocation<"compel"> = { - type: 'compel', - id: NEGATIVE_CONDITIONING, - prompt: negativePrompt, - } return g.getGraph(); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts b/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts index 53d7d742ab..e7d62897ef 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts @@ -53,6 +53,8 @@ export const PROMPT_REGION_NEGATIVE_COND_PREFIX = 'prompt_region_negative_cond'; export const PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX = 'prompt_region_positive_cond_inverted'; export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect'; export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect'; +export const UNSHARP_MASK = 'unsharp_mask' +export const TILED_MULTI_DIFFUSION_DENOISE_LATENTS = "tiled_multi_diffusion_denoise_latents" // friendly graph ids export const CONTROL_LAYERS_GRAPH = 'control_layers_graph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts index 3623343367..3335e0f80d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts @@ -8,7 +8,7 @@ import type { Invocation, S } from 'services/api/types'; export const addLoRAs = ( state: RootState, g: Graph, - denoise: Invocation<'denoise_latents'>, + denoise: Invocation<'denoise_latents'> | Invocation<'tiled_multi_diffusion_denoise_latents'>, modelLoader: Invocation<'main_model_loader'>, seamless: Invocation<'seamless'> | null, clipSkip: Invocation<'clip_skip'>, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts index f38e8de570..3125ab5ac3 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts @@ -8,7 +8,7 @@ import type { Invocation, S } from 'services/api/types'; export const addSDXLLoRas = ( state: RootState, g: Graph, - denoise: Invocation<'denoise_latents'>, + denoise: Invocation<'denoise_latents'> | Invocation<'tiled_multi_diffusion_denoise_latents'>, modelLoader: Invocation<'sdxl_model_loader'>, seamless: Invocation<'seamless'> | null, posCond: Invocation<'sdxl_compel_prompt'>,