From 43672a53abd93d14be641dc465a3528052a2241e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 23 Aug 2024 14:49:16 +1000 Subject: [PATCH] feat(ui): revise graph building for control layers, fix issues w/ invocation complete events --- .../listeners/enqueueRequestedLinear.ts | 23 ++- .../src/features/controlLayers/konva/util.ts | 28 +--- .../graph/buildAdHocPostProcessingGraph.ts | 5 +- .../util/graph/buildLinearBatchConfig.ts | 60 ++++---- .../graph/buildMultidiffusionUpscaleGraph.ts | 46 +++--- .../features/nodes/util/graph/constants.ts | 69 --------- .../graph/generation/addControlAdapters.ts | 103 ++++++------- .../util/graph/generation/addIPAdapters.ts | 42 +++--- .../util/graph/generation/addImageToImage.ts | 5 +- .../nodes/util/graph/generation/addInpaint.ts | 25 ++-- .../nodes/util/graph/generation/addLoRAs.ts | 9 +- .../util/graph/generation/addNSFWChecker.ts | 4 +- .../util/graph/generation/addOutpaint.ts | 33 ++--- .../nodes/util/graph/generation/addRegions.ts | 61 +++++--- .../util/graph/generation/addSDXLLoRAs.ts | 9 +- .../util/graph/generation/addSDXLRefiner.ts | 18 +-- .../util/graph/generation/addSeamless.ts | 4 +- .../util/graph/generation/addTextToImage.ts | 3 +- .../util/graph/generation/addWatermarker.ts | 4 +- .../util/graph/generation/buildSD1Graph.ts | 103 ++++++++----- .../util/graph/generation/buildSDXLGraph.ts | 94 +++++++----- .../services/events/onInvocationComplete.ts | 136 +++++++++++------- .../src/services/events/setEventListeners.tsx | 7 +- .../frontend/web/src/services/events/types.ts | 86 ++++------- 24 files changed, 469 insertions(+), 508 deletions(-) delete mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/constants.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts index 9111eba123..5412ccf28f 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts @@ -1,3 +1,4 @@ +import { logger } from 'app/logging/logger'; import { enqueueRequested } from 'app/store/actions'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { $canvasManager } from 'features/controlLayers/konva/CanvasManager'; @@ -5,9 +6,14 @@ import { sessionStagingAreaReset, sessionStartedStaging } from 'features/control import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph'; import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph'; +import type { Graph } from 'features/nodes/util/graph/generation/Graph'; +import { serializeError } from 'serialize-error'; import { queueApi } from 'services/api/endpoints/queue'; +import type { Invocation } from 'services/api/types'; import { assert } from 'tsafe'; +const log = logger('generation'); + export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => { startAppListening({ predicate: (action): action is ReturnType => @@ -27,20 +33,28 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) } try { - let g; + let g: Graph; + let noise: Invocation<'noise'>; + let posCond: Invocation<'compel' | 'sdxl_compel_prompt'>; assert(model, 'No model found in state'); const base = model.base; if (base === 'sdxl') { - g = await buildSDXLGraph(state, manager); + const result = await buildSDXLGraph(state, manager); + g = result.g; + noise = result.noise; + posCond = result.posCond; } else if (base === 'sd-1' || base === 'sd-2') { - g = await buildSD1Graph(state, manager); + const result = await buildSD1Graph(state, manager); + g = result.g; + noise = result.noise; + posCond = result.posCond; } else { assert(false, `No graph builders for base ${base}`); } - const batchConfig = prepareLinearUIBatch(state, g, prepend); + const batchConfig = prepareLinearUIBatch(state, g, prepend, noise, posCond); const req = dispatch( queueApi.endpoints.enqueueBatch.initiate(batchConfig, { @@ -50,6 +64,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) req.reset(); await req.unwrap(); } catch (error) { + log.error({ error: serializeError(error) }, 'Failed to enqueue batch'); if (didStartStaging && getState().canvasV2.session.isStaging) { dispatch(sessionStagingAreaReset()); } diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/util.ts b/invokeai/frontend/web/src/features/controlLayers/konva/util.ts index 5c917dddb7..adc8504bf1 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/util.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/util.ts @@ -1,5 +1,5 @@ -import type { Coordinate, Rect, RgbaColor } from 'features/controlLayers/store/types'; -import Konva from 'konva'; +import type { Coordinate, Rect } from 'features/controlLayers/store/types'; +import type Konva from 'konva'; import type { KonvaEventObject } from 'konva/lib/Node'; import type { Vector2d } from 'konva/lib/types'; import { customAlphabet } from 'nanoid'; @@ -279,30 +279,6 @@ export const konvaNodeToBlob = (node: Konva.Node, bbox?: Rect): Promise => return canvasToBlob(canvas); }; -/** - * Gets the pixel under the cursor on the stage, or null if the cursor is not over the stage. - * @param stage The konva stage - */ -export const getPixelUnderCursor = (stage: Konva.Stage): RgbaColor | null => { - const cursorPos = stage.getPointerPosition(); - const pixelRatio = Konva.pixelRatio; - if (!cursorPos) { - return null; - } - const ctx = stage.toCanvas().getContext('2d'); - - if (!ctx) { - return null; - } - const [r, g, b, a] = ctx.getImageData(cursorPos.x * pixelRatio, cursorPos.y * pixelRatio, 1, 1).data; - - if (r === undefined || g === undefined || b === undefined || a === undefined) { - return null; - } - - return { r, g, b, a }; -}; - export const previewBlob = (blob: Blob, label?: string) => { const url = URL.createObjectURL(blob); const w = window.open(''); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocPostProcessingGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocPostProcessingGraph.ts index 30ca457cda..35bf63e87a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocPostProcessingGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocPostProcessingGraph.ts @@ -1,4 +1,5 @@ import type { RootState } from 'app/store/store'; +import { getPrefixedId } from 'features/controlLayers/konva/util'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import type { GraphType } from 'features/nodes/util/graph/generation/Graph'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; @@ -7,8 +8,6 @@ import type { ImageDTO } from 'services/api/types'; import { isSpandrelImageToImageModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; -import { SPANDREL } from './constants'; - type Arg = { image: ImageDTO; state: RootState; @@ -21,8 +20,8 @@ export const buildAdHocPostProcessingGraph = async ({ image, state }: Arg): Prom const g = new Graph('adhoc-post-processing-graph'); g.addNode({ - id: SPANDREL, type: 'spandrel_image_to_image', + id: getPrefixedId('spandrel'), image_to_image_model: postProcessingModel, image, board: getBoardField(state), diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts index 3cd80862ab..c38d9e252a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts @@ -3,11 +3,15 @@ import { generateSeeds } from 'common/util/generateSeeds'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import { range } from 'lodash-es'; import type { components } from 'services/api/schema'; -import type { Batch, BatchConfig } from 'services/api/types'; +import type { Batch, BatchConfig, Invocation } from 'services/api/types'; -import { NOISE, POSITIVE_CONDITIONING } from './constants'; - -export const prepareLinearUIBatch = (state: RootState, g: Graph, prepend: boolean): BatchConfig => { +export const prepareLinearUIBatch = ( + state: RootState, + g: Graph, + prepend: boolean, + noise: Invocation<'noise'>, + posCond: Invocation<'compel' | 'sdxl_compel_prompt'> +): BatchConfig => { const { iterations, model, shouldRandomizeSeed, seed, shouldConcatPrompts } = state.canvasV2.params; const { prompts, seedBehaviour } = state.dynamicPrompts; @@ -22,13 +26,11 @@ export const prepareLinearUIBatch = (state: RootState, g: Graph, prepend: boolea start: shouldRandomizeSeed ? undefined : seed, }); - if (g.hasNode(NOISE)) { - firstBatchDatumList.push({ - node_path: NOISE, - field_name: 'seed', - items: seeds, - }); - } + firstBatchDatumList.push({ + node_path: noise.id, + field_name: 'seed', + items: seeds, + }); // add to metadata g.removeMetadata(['seed']); @@ -44,13 +46,11 @@ export const prepareLinearUIBatch = (state: RootState, g: Graph, prepend: boolea start: shouldRandomizeSeed ? undefined : seed, }); - if (g.hasNode(NOISE)) { - secondBatchDatumList.push({ - node_path: NOISE, - field_name: 'seed', - items: seeds, - }); - } + secondBatchDatumList.push({ + node_path: noise.id, + field_name: 'seed', + items: seeds, + }); // add to metadata g.removeMetadata(['seed']); @@ -65,13 +65,11 @@ export const prepareLinearUIBatch = (state: RootState, g: Graph, prepend: boolea const extendedPrompts = seedBehaviour === 'PER_PROMPT' ? range(iterations).flatMap(() => prompts) : prompts; // zipped batch of prompts - if (g.hasNode(POSITIVE_CONDITIONING)) { - firstBatchDatumList.push({ - node_path: POSITIVE_CONDITIONING, - field_name: 'prompt', - items: extendedPrompts, - }); - } + firstBatchDatumList.push({ + node_path: posCond.id, + field_name: 'prompt', + items: extendedPrompts, + }); // add to metadata g.removeMetadata(['positive_prompt']); @@ -82,13 +80,11 @@ export const prepareLinearUIBatch = (state: RootState, g: Graph, prepend: boolea }); if (shouldConcatPrompts && model?.base === 'sdxl') { - if (g.hasNode(POSITIVE_CONDITIONING)) { - firstBatchDatumList.push({ - node_path: POSITIVE_CONDITIONING, - field_name: 'style', - items: extendedPrompts, - }); - } + firstBatchDatumList.push({ + node_path: posCond.id, + field_name: 'style', + items: extendedPrompts, + }); // add to metadata g.removeMetadata(['positive_style_prompt']); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildMultidiffusionUpscaleGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildMultidiffusionUpscaleGraph.ts index 31a22fa6fb..f07d4cbc79 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildMultidiffusionUpscaleGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildMultidiffusionUpscaleGraph.ts @@ -1,25 +1,11 @@ import type { RootState } from 'app/store/store'; +import { getPrefixedId } from 'features/controlLayers/konva/util'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { addSDXLLoRAs } from 'features/nodes/util/graph/generation/addSDXLLoRAs'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { isNonRefinerMainModelConfig, isSpandrelImageToImageModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; -import { - CLIP_SKIP, - CONTROL_NET_COLLECT, - IMAGE_TO_LATENTS, - LATENTS_TO_IMAGE, - MAIN_MODEL_LOADER, - NEGATIVE_CONDITIONING, - NOISE, - POSITIVE_CONDITIONING, - SDXL_MODEL_LOADER, - SPANDREL, - TILED_MULTI_DIFFUSION_DENOISE_LATENTS, - UNSHARP_MASK, - VAE_LOADER, -} from './constants'; import { addLoRAs } from './generation/addLoRAs'; import { getBoardField, getPresetModifiedPrompts } from './graphBuilderUtils'; @@ -35,8 +21,8 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise const g = new Graph(); const upscaleNode = g.addNode({ - id: SPANDREL, type: 'spandrel_image_to_image_autoscale', + id: getPrefixedId('spandrel_autoscale'), image: upscaleInitialImage, image_to_image_model: upscaleModel, fit_to_multiple_of_8: true, @@ -44,8 +30,8 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise }); const unsharpMaskNode2 = g.addNode({ - id: `${UNSHARP_MASK}_2`, type: 'unsharp_mask', + id: getPrefixedId('unsharp_2'), radius: 2, strength: 60, }); @@ -53,8 +39,8 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise g.addEdge(upscaleNode, 'image', unsharpMaskNode2, 'image'); const noiseNode = g.addNode({ - id: NOISE, type: 'noise', + id: getPrefixedId('noise'), seed, }); @@ -62,8 +48,8 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise g.addEdge(unsharpMaskNode2, 'height', noiseNode, 'height'); const i2lNode = g.addNode({ - id: IMAGE_TO_LATENTS, type: 'i2l', + id: getPrefixedId('i2l'), fp32: vaePrecision === 'fp32', tiled: true, }); @@ -72,7 +58,7 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise const l2iNode = g.addNode({ type: 'l2i', - id: LATENTS_TO_IMAGE, + id: getPrefixedId('l2i'), fp32: vaePrecision === 'fp32', tiled: true, board: getBoardField(state), @@ -80,8 +66,8 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise }); const tiledMultidiffusionNode = g.addNode({ - id: TILED_MULTI_DIFFUSION_DENOISE_LATENTS, type: 'tiled_multi_diffusion_denoise_latents', + id: getPrefixedId('tiled_multidiffusion_denoise_latents'), tile_height: 1024, // is this dependent on base model tile_width: 1024, // is this dependent on base model tile_overlap: 128, @@ -102,19 +88,19 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise posCondNode = g.addNode({ type: 'sdxl_compel_prompt', - id: POSITIVE_CONDITIONING, + id: getPrefixedId('pos_cond'), prompt: positivePrompt, style: positiveStylePrompt, }); negCondNode = g.addNode({ type: 'sdxl_compel_prompt', - id: NEGATIVE_CONDITIONING, + id: getPrefixedId('neg_cond'), prompt: negativePrompt, style: negativeStylePrompt, }); modelNode = g.addNode({ type: 'sdxl_model_loader', - id: SDXL_MODEL_LOADER, + id: getPrefixedId('sdxl_model_loader'), model, }); g.addEdge(modelNode, 'clip', posCondNode, 'clip'); @@ -135,22 +121,22 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise posCondNode = g.addNode({ type: 'compel', - id: POSITIVE_CONDITIONING, + id: getPrefixedId('pos_cond'), prompt: positivePrompt, }); negCondNode = g.addNode({ type: 'compel', - id: NEGATIVE_CONDITIONING, + id: getPrefixedId('neg_cond'), prompt: negativePrompt, }); modelNode = g.addNode({ type: 'main_model_loader', - id: MAIN_MODEL_LOADER, + id: getPrefixedId('sd1_model_loader'), model, }); const clipSkipNode = g.addNode({ type: 'clip_skip', - id: CLIP_SKIP, + id: getPrefixedId('clip_skip'), }); g.addEdge(modelNode, 'clip', clipSkipNode, 'clip'); @@ -193,8 +179,8 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise let vaeNode; if (vae) { vaeNode = g.addNode({ - id: VAE_LOADER, type: 'vae_loader', + id: getPrefixedId('vae'), vae_model: vae, }); } @@ -236,8 +222,8 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise g.addEdge(unsharpMaskNode2, 'image', controlnetNode2, 'image'); const collectNode = g.addNode({ - id: CONTROL_NET_COLLECT, type: 'collect', + id: getPrefixedId('controlnet_collector'), }); g.addEdge(controlnetNode1, 'control', collectNode, 'item'); g.addEdge(controlnetNode2, 'control', collectNode, 'item'); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts b/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts deleted file mode 100644 index 5bc2192f53..0000000000 --- a/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts +++ /dev/null @@ -1,69 +0,0 @@ -// friendly node ids -export const POSITIVE_CONDITIONING = 'positive_conditioning'; -export const NEGATIVE_CONDITIONING = 'negative_conditioning'; -export const DENOISE_LATENTS = 'denoise_latents'; -export const DENOISE_LATENTS_HRF = 'denoise_latents_hrf'; -export const LATENTS_TO_IMAGE = 'latents_to_image'; -export const LATENTS_TO_IMAGE_HRF_HR = 'latents_to_image_hrf_hr'; -export const LATENTS_TO_IMAGE_HRF_LR = 'latents_to_image_hrf_lr'; -export const IMAGE_TO_LATENTS_HRF = 'image_to_latents_hrf'; -export const RESIZE_HRF = 'resize_hrf'; -export const ESRGAN_HRF = 'esrgan_hrf'; -export const NSFW_CHECKER = 'nsfw_checker'; -export const WATERMARKER = 'invisible_watermark'; -export const NOISE = 'noise'; -export const NOISE_HRF = 'noise_hrf'; -export const MAIN_MODEL_LOADER = 'main_model_loader'; -export const VAE_LOADER = 'vae_loader'; -export const LORA_LOADER = 'lora_loader'; -export const CLIP_SKIP = 'clip_skip'; -export const IMAGE_TO_LATENTS = 'image_to_latents'; -export const RESIZE = 'resize_image'; -export const IMG2IMG_RESIZE = 'img2img_resize'; -export const CANVAS_OUTPUT = 'canvas_output'; -export const INPAINT_IMAGE = 'inpaint_image'; -export const INPAINT_IMAGE_RESIZE_UP = 'inpaint_image_resize_up'; -export const INPAINT_IMAGE_RESIZE_DOWN = 'inpaint_image_resize_down'; -export const INPAINT_INFILL = 'inpaint_infill'; -export const INPAINT_INFILL_RESIZE_DOWN = 'inpaint_infill_resize_down'; -export const INPAINT_CREATE_MASK = 'inpaint_create_mask'; -export const CANVAS_COHERENCE_NOISE = 'canvas_coherence_noise'; -export const MASK_FROM_ALPHA = 'tomask'; -export const MASK_COMBINE = 'mask_combine'; -export const MASK_RESIZE_UP = 'mask_resize_up'; -export const MASK_RESIZE_DOWN = 'mask_resize_down'; -export const CONTROL_NET_COLLECT = 'control_net_collect'; -export const IP_ADAPTER_COLLECT = 'ip_adapter_collect'; -export const T2I_ADAPTER_COLLECT = 't2i_adapter_collect'; -export const METADATA = 'core_metadata'; -export const SPANDREL = 'spandrel'; -export const SDXL_MODEL_LOADER = 'sdxl_model_loader'; -export const SDXL_DENOISE_LATENTS = 'sdxl_denoise_latents'; -export const SDXL_REFINER_MODEL_LOADER = 'sdxl_refiner_model_loader'; -export const SDXL_REFINER_POSITIVE_CONDITIONING = 'sdxl_refiner_positive_conditioning'; -export const SDXL_REFINER_NEGATIVE_CONDITIONING = 'sdxl_refiner_negative_conditioning'; -export const SDXL_REFINER_DENOISE_LATENTS = 'sdxl_refiner_denoise_latents'; -export const SDXL_REFINER_INPAINT_CREATE_MASK = 'refiner_inpaint_create_mask'; -export const SEAMLESS = 'seamless'; -export const SDXL_REFINER_SEAMLESS = 'refiner_seamless'; -export const PROMPT_REGION_MASK_TO_TENSOR_PREFIX = 'prompt_region_mask_to_tensor'; -export const PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX = 'prompt_region_invert_tensor_mask'; -export const PROMPT_REGION_POSITIVE_COND_PREFIX = 'prompt_region_positive_cond'; -export const PROMPT_REGION_NEGATIVE_COND_PREFIX = 'prompt_region_negative_cond'; -export const PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX = 'prompt_region_positive_cond_inverted'; -export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect'; -export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect'; -export const UNSHARP_MASK = 'unsharp_mask'; -export const TILED_MULTI_DIFFUSION_DENOISE_LATENTS = 'tiled_multi_diffusion_denoise_latents'; - -// friendly graph ids -export const CONTROL_LAYERS_GRAPH = 'control_layers_graph'; -export const SDXL_CONTROL_LAYERS_GRAPH = 'sdxl_control_layers_graph'; -export const CANVAS_TEXT_TO_IMAGE_GRAPH = 'canvas_text_to_image_graph'; -export const CANVAS_IMAGE_TO_IMAGE_GRAPH = 'canvas_image_to_image_graph'; -export const CANVAS_INPAINT_GRAPH = 'canvas_inpaint_graph'; -export const CANVAS_OUTPAINT_GRAPH = 'canvas_outpaint_graph'; -export const SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH = 'sdxl_canvas_text_to_image_graph'; -export const SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH = 'sdxl_canvas_image_to_image_graph'; -export const SDXL_CANVAS_INPAINT_GRAPH = 'sdxl_canvas_inpaint_graph'; -export const SDXL_CANVAS_OUTPAINT_GRAPH = 'sdxl_canvas_outpaint_graph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts index ba0be20c46..8b3630574b 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts @@ -5,58 +5,81 @@ import type { Rect, T2IAdapterConfig, } from 'features/controlLayers/store/types'; -import { CONTROL_NET_COLLECT, T2I_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types'; import { assert } from 'tsafe'; -export const addControlAdapters = async ( +type AddControlNetsResult = { + addedControlNets: number; +}; + +export const addControlNets = async ( manager: CanvasManager, layers: CanvasControlLayerState[], g: Graph, bbox: Rect, - denoise: Invocation<'denoise_latents'>, + collector: Invocation<'collect'>, base: BaseModelType -): Promise => { +): Promise => { const validControlLayers = layers .filter((layer) => layer.isEnabled) - .filter((layer) => isValidControlAdapter(layer.controlAdapter, base)); + .filter((layer) => isValidControlAdapter(layer.controlAdapter, base)) + .filter((layer) => layer.controlAdapter.type === 'controlnet'); + + const result: AddControlNetsResult = { + addedControlNets: 0, + }; for (const layer of validControlLayers) { + result.addedControlNets++; + const adapter = manager.adapters.controlLayers.get(layer.id); assert(adapter, 'Adapter not found'); const imageDTO = await adapter.renderer.rasterize({ rect: bbox, attrs: { opacity: 1, filters: [] } }); - if (layer.controlAdapter.type === 'controlnet') { - await addControlNetToGraph(g, layer, imageDTO, denoise); - } else { - await addT2IAdapterToGraph(g, layer, imageDTO, denoise); - } + await addControlNetToGraph(g, layer, imageDTO, collector); } - return validControlLayers; + + return result; }; -const addControlNetCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => { - try { - // Attempt to retrieve the collector - const controlNetCollect = g.getNode(CONTROL_NET_COLLECT); - assert(controlNetCollect.type === 'collect'); - return controlNetCollect; - } catch { - // Add the ControlNet collector - const controlNetCollect = g.addNode({ - id: CONTROL_NET_COLLECT, - type: 'collect', - }); - g.addEdge(controlNetCollect, 'collection', denoise, 'control'); - return controlNetCollect; +type AddT2IAdaptersResult = { + addedT2IAdapters: number; +}; + +export const addT2IAdapters = async ( + manager: CanvasManager, + layers: CanvasControlLayerState[], + g: Graph, + bbox: Rect, + collector: Invocation<'collect'>, + base: BaseModelType +): Promise => { + const validControlLayers = layers + .filter((layer) => layer.isEnabled) + .filter((layer) => isValidControlAdapter(layer.controlAdapter, base)) + .filter((layer) => layer.controlAdapter.type === 't2i_adapter'); + + const result: AddT2IAdaptersResult = { + addedT2IAdapters: 0, + }; + + for (const layer of validControlLayers) { + result.addedT2IAdapters++; + + const adapter = manager.adapters.controlLayers.get(layer.id); + assert(adapter, 'Adapter not found'); + const imageDTO = await adapter.renderer.rasterize({ rect: bbox, attrs: { opacity: 1, filters: [] } }); + await addT2IAdapterToGraph(g, layer, imageDTO, collector); } + + return result; }; const addControlNetToGraph = ( g: Graph, layer: CanvasControlLayerState, imageDTO: ImageDTO, - denoise: Invocation<'denoise_latents'> + collector: Invocation<'collect'> ) => { const { id, controlAdapter } = layer; assert(controlAdapter.type === 'controlnet'); @@ -64,8 +87,6 @@ const addControlNetToGraph = ( assert(model !== null); const { image_name } = imageDTO; - const controlNetCollect = addControlNetCollectorSafe(g, denoise); - const controlNet = g.addNode({ id: `control_net_${id}`, type: 'controlnet', @@ -77,32 +98,14 @@ const addControlNetToGraph = ( control_weight: weight, image: { image_name }, }); - g.addEdge(controlNet, 'control', controlNetCollect, 'item'); -}; - -const addT2IAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => { - try { - // You see, we've already got one! - const t2iAdapterCollect = g.getNode(T2I_ADAPTER_COLLECT); - assert(t2iAdapterCollect.type === 'collect'); - return t2iAdapterCollect; - } catch { - const t2iAdapterCollect = g.addNode({ - id: T2I_ADAPTER_COLLECT, - type: 'collect', - }); - - g.addEdge(t2iAdapterCollect, 'collection', denoise, 't2i_adapter'); - - return t2iAdapterCollect; - } + g.addEdge(controlNet, 'control', collector, 'item'); }; const addT2IAdapterToGraph = ( g: Graph, layer: CanvasControlLayerState, imageDTO: ImageDTO, - denoise: Invocation<'denoise_latents'> + collector: Invocation<'collect'> ) => { const { id, controlAdapter } = layer; assert(controlAdapter.type === 't2i_adapter'); @@ -110,8 +113,6 @@ const addT2IAdapterToGraph = ( assert(model !== null); const { image_name } = imageDTO; - const t2iAdapterCollect = addT2IAdapterCollectorSafe(g, denoise); - const t2iAdapter = g.addNode({ id: `t2i_adapter_${id}`, type: 't2i_adapter', @@ -123,7 +124,7 @@ const addT2IAdapterToGraph = ( image: { image_name }, }); - g.addEdge(t2iAdapter, 't2i_adapter', t2iAdapterCollect, 'item'); + g.addEdge(t2iAdapter, 't2i_adapter', collector, 'item'); }; const isValidControlAdapter = (controlAdapter: ControlNetConfig | T2IAdapterConfig, base: BaseModelType): boolean => { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addIPAdapters.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addIPAdapters.ts index 99d5009220..90a314c11b 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addIPAdapters.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addIPAdapters.ts @@ -1,44 +1,38 @@ import type { CanvasIPAdapterState, IPAdapterConfig } from 'features/controlLayers/store/types'; -import { IP_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { BaseModelType, Invocation } from 'services/api/types'; import { assert } from 'tsafe'; +type AddIPAdaptersResult = { + addedIPAdapters: number; +}; + export const addIPAdapters = ( ipAdapters: CanvasIPAdapterState[], g: Graph, - denoise: Invocation<'denoise_latents'>, + collector: Invocation<'collect'>, base: BaseModelType -): CanvasIPAdapterState[] => { +): AddIPAdaptersResult => { const validIPAdapters = ipAdapters.filter((entity) => isValidIPAdapter(entity.ipAdapter, base)); + + const result: AddIPAdaptersResult = { + addedIPAdapters: 0, + }; + for (const ipa of validIPAdapters) { - addIPAdapter(ipa, g, denoise); + result.addedIPAdapters++; + + addIPAdapter(ipa, g, collector); } - return validIPAdapters; + + return result; }; -export const addIPAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => { - try { - // You see, we've already got one! - const ipAdapterCollect = g.getNode(IP_ADAPTER_COLLECT); - assert(ipAdapterCollect.type === 'collect'); - return ipAdapterCollect; - } catch { - const ipAdapterCollect = g.addNode({ - id: IP_ADAPTER_COLLECT, - type: 'collect', - }); - g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter'); - return ipAdapterCollect; - } -}; - -const addIPAdapter = (entity: CanvasIPAdapterState, g: Graph, denoise: Invocation<'denoise_latents'>) => { +const addIPAdapter = (entity: CanvasIPAdapterState, g: Graph, collector: Invocation<'collect'>) => { const { id, ipAdapter } = entity; const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter; assert(image, 'IP Adapter image is required'); assert(model, 'IP Adapter model is required'); - const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise); const ipAdapterNode = g.addNode({ id: `ip_adapter_${id}`, @@ -53,7 +47,7 @@ const addIPAdapter = (entity: CanvasIPAdapterState, g: Graph, denoise: Invocatio image_name: image.image_name, }, }); - g.addEdge(ipAdapterNode, 'ip_adapter', ipAdapterCollect, 'item'); + g.addEdge(ipAdapterNode, 'ip_adapter', collector, 'item'); }; export const isValidIPAdapter = (ipAdapter: IPAdapterConfig, base: BaseModelType): boolean => { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addImageToImage.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addImageToImage.ts index 6314ef9df4..9bcd148acf 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addImageToImage.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addImageToImage.ts @@ -1,4 +1,5 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; +import { getPrefixedId } from 'features/controlLayers/konva/util'; import type { CanvasV2State, Dimensions } from 'features/controlLayers/store/types'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import { isEqual } from 'lodash-es'; @@ -22,15 +23,15 @@ export const addImageToImage = async ( if (!isEqual(scaledSize, originalSize)) { // Resize the initial image to the scaled size, denoise, then resize back to the original size const resizeImageToScaledSize = g.addNode({ - id: 'initial_image_resize_in', type: 'img_resize', + id: getPrefixedId('initial_image_resize_in'), image: { image_name }, ...scaledSize, }); const i2l = g.addNode({ id: 'i2l', type: 'i2l' }); const resizeImageToOriginalSize = g.addNode({ - id: 'initial_image_resize_out', type: 'img_resize', + id: getPrefixedId('initial_image_resize_out'), ...originalSize, }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addInpaint.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addInpaint.ts index b15e55ce25..ef0aed835b 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addInpaint.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addInpaint.ts @@ -1,4 +1,5 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; +import { getPrefixedId } from 'features/controlLayers/konva/util'; import type { CanvasV2State, Dimensions } from 'features/controlLayers/store/types'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { ParameterPrecision } from 'features/parameters/types/parameterSchemas'; @@ -26,36 +27,36 @@ export const addInpaint = async ( if (!isEqual(scaledSize, originalSize)) { // Scale before processing requires some resizing - const i2l = g.addNode({ id: 'i2l', type: 'i2l' }); + const i2l = g.addNode({ id: getPrefixedId('i2l'), type: 'i2l' }); const resizeImageToScaledSize = g.addNode({ - id: 'resize_image_to_scaled_size', type: 'img_resize', + id: getPrefixedId('resize_image_to_scaled_size'), image: { image_name: initialImage.image_name }, ...scaledSize, }); const alphaToMask = g.addNode({ - id: 'alpha_to_mask', + id: getPrefixedId('alpha_to_mask'), type: 'tomask', image: { image_name: maskImage.image_name }, invert: true, }); const resizeMaskToScaledSize = g.addNode({ - id: 'resize_mask_to_scaled_size', + id: getPrefixedId('resize_mask_to_scaled_size'), type: 'img_resize', ...scaledSize, }); const resizeImageToOriginalSize = g.addNode({ - id: 'resize_image_to_original_size', + id: getPrefixedId('resize_image_to_original_size'), type: 'img_resize', ...originalSize, }); const resizeMaskToOriginalSize = g.addNode({ - id: 'resize_mask_to_original_size', + id: getPrefixedId('resize_mask_to_original_size'), type: 'img_resize', ...originalSize, }); const createGradientMask = g.addNode({ - id: 'create_gradient_mask', + id: getPrefixedId('create_gradient_mask'), type: 'create_gradient_mask', coherence_mode: compositing.canvasCoherenceMode, minimum_denoise: compositing.canvasCoherenceMinDenoise, @@ -63,7 +64,7 @@ export const addInpaint = async ( fp32: vaePrecision === 'fp32', }); const canvasPasteBack = g.addNode({ - id: 'canvas_v2_mask_and_crop', + id: getPrefixedId('canvas_v2_mask_and_crop'), type: 'canvas_v2_mask_and_crop', mask_blur: compositing.maskBlur, }); @@ -92,15 +93,15 @@ export const addInpaint = async ( return canvasPasteBack; } else { // No scale before processing, much simpler - const i2l = g.addNode({ id: 'i2l', type: 'i2l', image: { image_name: initialImage.image_name } }); + const i2l = g.addNode({ id: getPrefixedId('i2l'), type: 'i2l', image: { image_name: initialImage.image_name } }); const alphaToMask = g.addNode({ - id: 'alpha_to_mask', + id: getPrefixedId('alpha_to_mask'), type: 'tomask', image: { image_name: maskImage.image_name }, invert: true, }); const createGradientMask = g.addNode({ - id: 'create_gradient_mask', + id: getPrefixedId('create_gradient_mask'), type: 'create_gradient_mask', coherence_mode: compositing.canvasCoherenceMode, minimum_denoise: compositing.canvasCoherenceMinDenoise, @@ -109,7 +110,7 @@ export const addInpaint = async ( image: { image_name: initialImage.image_name }, }); const canvasPasteBack = g.addNode({ - id: 'canvas_v2_mask_and_crop', + id: getPrefixedId('canvas_v2_mask_and_crop'), type: 'canvas_v2_mask_and_crop', mask_blur: compositing.maskBlur, }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts index b078dfcdfc..92bf0cbeaa 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts @@ -1,6 +1,6 @@ import type { RootState } from 'app/store/store'; +import { getPrefixedId } from 'features/controlLayers/konva/util'; import { zModelIdentifierField } from 'features/nodes/types/common'; -import { LORA_LOADER } from 'features/nodes/util/graph/constants'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Invocation, S } from 'services/api/types'; @@ -28,12 +28,12 @@ export const addLoRAs = ( // We will collect LoRAs into a single collection node, then pass them to the LoRA collection loader, which applies // each LoRA to the UNet and CLIP. const loraCollector = g.addNode({ - id: `${LORA_LOADER}_collect`, type: 'collect', + id: getPrefixedId('lora_collector'), }); const loraCollectionLoader = g.addNode({ - id: LORA_LOADER, type: 'lora_collection_loader', + id: getPrefixedId('lora_collection_loader'), }); g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras'); @@ -50,12 +50,11 @@ export const addLoRAs = ( for (const lora of enabledLoRAs) { const { weight } = lora; - const { key } = lora.model; const parsedModel = zModelIdentifierField.parse(lora.model); const loraSelector = g.addNode({ type: 'lora_selector', - id: `${LORA_LOADER}_${key}`, + id: getPrefixedId('lora_selector'), lora: parsedModel, weight, }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addNSFWChecker.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addNSFWChecker.ts index ec9a809a9f..5c0242aa16 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addNSFWChecker.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addNSFWChecker.ts @@ -1,4 +1,4 @@ -import { NSFW_CHECKER } from 'features/nodes/util/graph/constants'; +import { getPrefixedId } from 'features/controlLayers/konva/util'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Invocation } from 'services/api/types'; @@ -13,8 +13,8 @@ export const addNSFWChecker = ( imageOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop'> ): Invocation<'img_nsfw'> => { const nsfw = g.addNode({ - id: NSFW_CHECKER, type: 'img_nsfw', + id: getPrefixedId('nsfw_checker'), }); g.addEdge(imageOutput, 'image', nsfw, 'image'); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addOutpaint.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addOutpaint.ts index 0798f82916..b2d5b4a18e 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addOutpaint.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addOutpaint.ts @@ -1,4 +1,5 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; +import { getPrefixedId } from 'features/controlLayers/konva/util'; import type { CanvasV2State, Dimensions } from 'features/controlLayers/store/types'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import { getInfill } from 'features/nodes/util/graph/graphBuilderUtils'; @@ -31,18 +32,18 @@ export const addOutpaint = async ( // Combine the inpaint mask and the initial image's alpha channel into a single mask const maskAlphaToMask = g.addNode({ - id: 'alpha_to_mask', + id: getPrefixedId('alpha_to_mask'), type: 'tomask', image: { image_name: maskImage.image_name }, invert: true, }); const initialImageAlphaToMask = g.addNode({ - id: 'image_alpha_to_mask', + id: getPrefixedId('image_alpha_to_mask'), type: 'tomask', image: { image_name: initialImage.image_name }, }); const maskCombine = g.addNode({ - id: 'mask_combine', + id: getPrefixedId('mask_combine'), type: 'mask_combine', }); g.addEdge(maskAlphaToMask, 'image', maskCombine, 'mask1'); @@ -50,7 +51,7 @@ export const addOutpaint = async ( // Resize the combined and initial image to the scaled size const resizeInputMaskToScaledSize = g.addNode({ - id: 'resize_mask_to_scaled_size', + id: getPrefixedId('resize_mask_to_scaled_size'), type: 'img_resize', ...scaledSize, }); @@ -58,7 +59,7 @@ export const addOutpaint = async ( // Resize the initial image to the scaled size and infill const resizeInputImageToScaledSize = g.addNode({ - id: 'resize_image_to_scaled_size', + id: getPrefixedId('resize_image_to_scaled_size'), type: 'img_resize', image: { image_name: initialImage.image_name }, ...scaledSize, @@ -67,7 +68,7 @@ export const addOutpaint = async ( // Create the gradient denoising mask from the combined mask const createGradientMask = g.addNode({ - id: 'create_gradient_mask', + id: getPrefixedId('create_gradient_mask'), type: 'create_gradient_mask', coherence_mode: compositing.canvasCoherenceMode, minimum_denoise: compositing.canvasCoherenceMinDenoise, @@ -81,24 +82,24 @@ export const addOutpaint = async ( g.addEdge(createGradientMask, 'denoise_mask', denoise, 'denoise_mask'); // Decode infilled image and connect to denoise - const i2l = g.addNode({ id: 'i2l', type: 'i2l' }); + const i2l = g.addNode({ id: getPrefixedId('i2l'), type: 'i2l' }); g.addEdge(infill, 'image', i2l, 'image'); g.addEdge(vaeSource, 'vae', i2l, 'vae'); g.addEdge(i2l, 'latents', denoise, 'latents'); // Resize the output image back to the original size const resizeOutputImageToOriginalSize = g.addNode({ - id: 'resize_image_to_original_size', + id: getPrefixedId('resize_image_to_original_size'), type: 'img_resize', ...originalSize, }); const resizeOutputMaskToOriginalSize = g.addNode({ - id: 'resize_mask_to_original_size', + id: getPrefixedId('resize_mask_to_original_size'), type: 'img_resize', ...originalSize, }); const canvasPasteBack = g.addNode({ - id: 'canvas_v2_mask_and_crop', + id: getPrefixedId('canvas_v2_mask_and_crop'), type: 'canvas_v2_mask_and_crop', mask_blur: compositing.maskBlur, }); @@ -117,24 +118,24 @@ export const addOutpaint = async ( } else { infill.image = { image_name: initialImage.image_name }; // No scale before processing, much simpler - const i2l = g.addNode({ id: 'i2l', type: 'i2l' }); + const i2l = g.addNode({ id: getPrefixedId('i2l'), type: 'i2l' }); const maskAlphaToMask = g.addNode({ - id: 'mask_alpha_to_mask', + id: getPrefixedId('mask_alpha_to_mask'), type: 'tomask', image: { image_name: maskImage.image_name }, invert: true, }); const initialImageAlphaToMask = g.addNode({ - id: 'image_alpha_to_mask', + id: getPrefixedId('image_alpha_to_mask'), type: 'tomask', image: { image_name: initialImage.image_name }, }); const maskCombine = g.addNode({ - id: 'mask_combine', + id: getPrefixedId('mask_combine'), type: 'mask_combine', }); const createGradientMask = g.addNode({ - id: 'create_gradient_mask', + id: getPrefixedId('create_gradient_mask'), type: 'create_gradient_mask', coherence_mode: compositing.canvasCoherenceMode, minimum_denoise: compositing.canvasCoherenceMinDenoise, @@ -143,7 +144,7 @@ export const addOutpaint = async ( image: { image_name: initialImage.image_name }, }); const canvasPasteBack = g.addNode({ - id: 'canvas_v2_mask_and_crop', + id: getPrefixedId('canvas_v2_mask_and_crop'), type: 'canvas_v2_mask_and_crop', mask_blur: compositing.maskBlur, }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts index ec4091a447..2c42d2c077 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts @@ -1,22 +1,23 @@ import { deepClone } from 'common/util/deepClone'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; +import { getPrefixedId } from 'features/controlLayers/konva/util'; import type { CanvasRegionalGuidanceState, Rect, RegionalGuidanceIPAdapterConfig, } from 'features/controlLayers/store/types'; -import { - PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX, - PROMPT_REGION_MASK_TO_TENSOR_PREFIX, - PROMPT_REGION_NEGATIVE_COND_PREFIX, - PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX, - PROMPT_REGION_POSITIVE_COND_PREFIX, -} from 'features/nodes/util/graph/constants'; -import { addIPAdapterCollectorSafe, isValidIPAdapter } from 'features/nodes/util/graph/generation/addIPAdapters'; +import { isValidIPAdapter } from 'features/nodes/util/graph/generation/addIPAdapters'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { BaseModelType, Invocation } from 'services/api/types'; import { assert } from 'tsafe'; +type AddedRegionResult = { + addedPositivePrompt: boolean; + addedNegativePrompt: boolean; + addedAutoNegativePositivePrompt: boolean; + addedIPAdapters: number; +}; + /** * Adds regional guidance to the graph * @param regions Array of regions to add @@ -27,6 +28,7 @@ import { assert } from 'tsafe'; * @param negCond The negative conditioning node * @param posCondCollect The positive conditioning collector * @param negCondCollect The negative conditioning collector + * @param ipAdapterCollect The IP adapter collector * @returns A promise that resolves to the regions that were successfully added to the graph */ @@ -40,21 +42,29 @@ export const addRegions = async ( posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>, negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>, posCondCollect: Invocation<'collect'>, - negCondCollect: Invocation<'collect'> -): Promise => { + negCondCollect: Invocation<'collect'>, + ipAdapterCollect: Invocation<'collect'> +): Promise => { const isSDXL = base === 'sdxl'; const validRegions = regions.filter((rg) => isValidRegion(rg, base)); + const results: AddedRegionResult[] = []; for (const region of validRegions) { + const result: AddedRegionResult = { + addedPositivePrompt: false, + addedNegativePrompt: false, + addedAutoNegativePositivePrompt: false, + addedIPAdapters: 0, + }; const adapter = manager.adapters.regionMasks.get(region.id); assert(adapter, 'Adapter not found'); const imageDTO = await adapter.renderer.rasterize({ rect: bbox }); // The main mask-to-tensor node const maskToTensor = g.addNode({ - id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${region.id}`, type: 'alpha_mask_to_tensor', + id: getPrefixedId('prompt_region_mask_to_tensor'), image: { image_name: imageDTO.image_name, }, @@ -62,17 +72,18 @@ export const addRegions = async ( if (region.positivePrompt) { // The main positive conditioning node + result.addedPositivePrompt = true; const regionalPosCond = g.addNode( isSDXL ? { type: 'sdxl_compel_prompt', - id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${region.id}`, + id: getPrefixedId('prompt_region_positive_cond'), prompt: region.positivePrompt, style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields? } : { type: 'compel', - id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${region.id}`, + id: getPrefixedId('prompt_region_positive_cond'), prompt: region.positivePrompt, } ); @@ -99,18 +110,19 @@ export const addRegions = async ( } if (region.negativePrompt) { + result.addedNegativePrompt = true; // The main negative conditioning node const regionalNegCond = g.addNode( isSDXL ? { type: 'sdxl_compel_prompt', - id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${region.id}`, + id: getPrefixedId('prompt_region_negative_cond'), prompt: region.negativePrompt, style: region.negativePrompt, } : { type: 'compel', - id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${region.id}`, + id: getPrefixedId('prompt_region_negative_cond'), prompt: region.negativePrompt, } ); @@ -135,10 +147,11 @@ export const addRegions = async ( } // If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node - if (region.autoNegative === 'invert' && region.positivePrompt) { + if (region.autoNegative && region.positivePrompt) { + result.addedAutoNegativePositivePrompt = true; // We re-use the mask image, but invert it when converting to tensor const invertTensorMask = g.addNode({ - id: `${PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX}_${region.id}`, + id: getPrefixedId('prompt_region_invert_tensor_mask'), type: 'invert_tensor_mask', }); // Connect the OG mask image to the inverted mask-to-tensor node @@ -148,13 +161,13 @@ export const addRegions = async ( isSDXL ? { type: 'sdxl_compel_prompt', - id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${region.id}`, + id: getPrefixedId('prompt_region_positive_cond_inverted'), prompt: region.positivePrompt, style: region.positivePrompt, } : { type: 'compel', - id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${region.id}`, + id: getPrefixedId('prompt_region_positive_cond_inverted'), prompt: region.positivePrompt, } ); @@ -183,7 +196,7 @@ export const addRegions = async ( ); for (const ipa of validRGIPAdapters) { - const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise); + result.addedIPAdapters++; const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipa; assert(model, 'IP Adapter model is required'); assert(image, 'IP Adapter image is required'); @@ -206,14 +219,18 @@ export const addRegions = async ( g.addEdge(maskToTensor, 'mask', ipAdapter, 'mask'); g.addEdge(ipAdapter, 'ip_adapter', ipAdapterCollect, 'item'); } + + results.push(result); } g.upsertMetadata({ regions: validRegions }); - return validRegions; + + return results; }; export const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) => { + const isEnabled = rg.isEnabled; const hasTextPrompt = Boolean(rg.positivePrompt || rg.negativePrompt); const hasIPAdapter = rg.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base)).length > 0; - return hasTextPrompt || hasIPAdapter; + return isEnabled && (hasTextPrompt || hasIPAdapter); }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts index f274ec9a09..ffb5268520 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts @@ -1,6 +1,6 @@ import type { RootState } from 'app/store/store'; +import { getPrefixedId } from 'features/controlLayers/konva/util'; import { zModelIdentifierField } from 'features/nodes/types/common'; -import { LORA_LOADER } from 'features/nodes/util/graph/constants'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Invocation, S } from 'services/api/types'; @@ -25,12 +25,12 @@ export const addSDXLLoRAs = ( // We will collect LoRAs into a single collection node, then pass them to the LoRA collection loader, which applies // each LoRA to the UNet and CLIP. const loraCollector = g.addNode({ - id: `${LORA_LOADER}_collect`, + id: getPrefixedId('lora_collector'), type: 'collect', }); const loraCollectionLoader = g.addNode({ - id: LORA_LOADER, type: 'sdxl_lora_collection_loader', + id: getPrefixedId('sdxl_lora_collection_loader'), }); g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras'); @@ -50,12 +50,11 @@ export const addSDXLLoRAs = ( for (const lora of enabledLoRAs) { const { weight } = lora; - const { key } = lora.model; const parsedModel = zModelIdentifierField.parse(lora.model); const loraSelector = g.addNode({ type: 'lora_selector', - id: `${LORA_LOADER}_${key}`, + id: getPrefixedId('lora_selector'), lora: parsedModel, weight, }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLRefiner.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLRefiner.ts index 7e79ffe4ff..09152d6659 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLRefiner.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLRefiner.ts @@ -1,12 +1,6 @@ import type { RootState } from 'app/store/store'; +import { getPrefixedId } from 'features/controlLayers/konva/util'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; -import { - SDXL_REFINER_DENOISE_LATENTS, - SDXL_REFINER_MODEL_LOADER, - SDXL_REFINER_NEGATIVE_CONDITIONING, - SDXL_REFINER_POSITIVE_CONDITIONING, - SDXL_REFINER_SEAMLESS, -} from 'features/nodes/util/graph/constants'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Invocation } from 'services/api/types'; import { isRefinerMainModelModelConfig } from 'services/api/types'; @@ -42,24 +36,24 @@ export const addSDXLRefiner = async ( const refinerModelLoader = g.addNode({ type: 'sdxl_refiner_model_loader', - id: SDXL_REFINER_MODEL_LOADER, + id: getPrefixedId('refiner_model_loader'), model: refinerModel, }); const refinerPosCond = g.addNode({ type: 'sdxl_refiner_compel_prompt', - id: SDXL_REFINER_POSITIVE_CONDITIONING, + id: getPrefixedId('refiner_pos_cond'), style: posCond.style, aesthetic_score: refinerPositiveAestheticScore, }); const refinerNegCond = g.addNode({ type: 'sdxl_refiner_compel_prompt', - id: SDXL_REFINER_NEGATIVE_CONDITIONING, + id: getPrefixedId('refiner_neg_cond'), style: negCond.style, aesthetic_score: refinerNegativeAestheticScore, }); const refinerDenoise = g.addNode({ type: 'denoise_latents', - id: SDXL_REFINER_DENOISE_LATENTS, + id: getPrefixedId('refiner_denoise_latents'), cfg_scale: refinerCFGScale, steps: refinerSteps, scheduler: refinerScheduler, @@ -69,8 +63,8 @@ export const addSDXLRefiner = async ( if (seamless) { const refinerSeamless = g.addNode({ - id: SDXL_REFINER_SEAMLESS, type: 'seamless', + id: getPrefixedId('refiner_seamless'), seamless_x: seamless.seamless_x, seamless_y: seamless.seamless_y, }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSeamless.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSeamless.ts index 3fdcfbe28e..8a48a6e9fd 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSeamless.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSeamless.ts @@ -1,5 +1,5 @@ import type { RootState } from 'app/store/store'; -import { SEAMLESS } from 'features/nodes/util/graph/constants'; +import { getPrefixedId } from 'features/controlLayers/konva/util'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Invocation } from 'services/api/types'; @@ -28,8 +28,8 @@ export const addSeamless = ( } const seamless = g.addNode({ - id: SEAMLESS, type: 'seamless', + id: getPrefixedId('seamless'), seamless_x, seamless_y, }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addTextToImage.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addTextToImage.ts index bc11f76be2..e98da9bb9f 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addTextToImage.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addTextToImage.ts @@ -1,3 +1,4 @@ +import { getPrefixedId } from 'features/controlLayers/konva/util'; import type { Dimensions } from 'features/controlLayers/store/types'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import { isEqual } from 'lodash-es'; @@ -12,7 +13,7 @@ export const addTextToImage = ( if (!isEqual(scaledSize, originalSize)) { // We need to resize the output image back to the original size const resizeImageToOriginalSize = g.addNode({ - id: 'resize_image_to_original_size', + id: getPrefixedId('resize_image_to_original_size'), type: 'img_resize', ...originalSize, }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addWatermarker.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addWatermarker.ts index b0f0f14008..e5cedb0040 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addWatermarker.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addWatermarker.ts @@ -1,4 +1,4 @@ -import { WATERMARKER } from 'features/nodes/util/graph/constants'; +import { getPrefixedId } from 'features/controlLayers/konva/util'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Invocation } from 'services/api/types'; @@ -13,8 +13,8 @@ export const addWatermarker = ( imageOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop'> ): Invocation<'img_watermark'> => { const watermark = g.addNode({ - id: WATERMARKER, type: 'img_watermark', + id: getPrefixedId('watermarker'), }); g.addEdge(imageOutput, 'image', watermark, 'image'); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts index 53f8875152..b0aad315ca 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts @@ -1,22 +1,9 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; +import { getPrefixedId } from 'features/controlLayers/konva/util'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; -import { - CANVAS_OUTPUT, - CLIP_SKIP, - CONTROL_LAYERS_GRAPH, - DENOISE_LATENTS, - LATENTS_TO_IMAGE, - MAIN_MODEL_LOADER, - NEGATIVE_CONDITIONING, - NEGATIVE_CONDITIONING_COLLECT, - NOISE, - POSITIVE_CONDITIONING, - POSITIVE_CONDITIONING_COLLECT, - VAE_LOADER, -} from 'features/nodes/util/graph/constants'; -import { addControlAdapters } from 'features/nodes/util/graph/generation/addControlAdapters'; +import { addControlNets, addT2IAdapters } from 'features/nodes/util/graph/generation/addControlAdapters'; import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage'; import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint'; // import { addHRF } from 'features/nodes/util/graph/generation/addHRF'; @@ -37,7 +24,10 @@ import { addRegions } from './addRegions'; const log = logger('system'); -export const buildSD1Graph = async (state: RootState, manager: CanvasManager): Promise => { +export const buildSD1Graph = async ( + state: RootState, + manager: CanvasManager +): Promise<{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'compel'> }> => { const generationMode = manager.compositor.getGenerationMode(); log.debug({ generationMode }, 'Building SD1/SD2 graph'); @@ -62,38 +52,38 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager): P const { originalSize, scaledSize } = getSizes(bbox); - const g = new Graph(CONTROL_LAYERS_GRAPH); + const g = new Graph(getPrefixedId('sd1_graph')); const modelLoader = g.addNode({ type: 'main_model_loader', - id: MAIN_MODEL_LOADER, + id: getPrefixedId('sd1_model_loader'), model, }); const clipSkip = g.addNode({ type: 'clip_skip', - id: CLIP_SKIP, + id: getPrefixedId('clip_skip'), skipped_layers, }); const posCond = g.addNode({ type: 'compel', - id: POSITIVE_CONDITIONING, + id: getPrefixedId('pos_cond'), prompt: positivePrompt, }); const posCondCollect = g.addNode({ type: 'collect', - id: POSITIVE_CONDITIONING_COLLECT, + id: getPrefixedId('pos_cond_collect'), }); const negCond = g.addNode({ type: 'compel', - id: NEGATIVE_CONDITIONING, + id: getPrefixedId('neg_cond'), prompt: negativePrompt, }); const negCondCollect = g.addNode({ type: 'collect', - id: NEGATIVE_CONDITIONING_COLLECT, + id: getPrefixedId('neg_cond_collect'), }); const noise = g.addNode({ type: 'noise', - id: NOISE, + id: getPrefixedId('noise'), seed, width: scaledSize.width, height: scaledSize.height, @@ -101,7 +91,7 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager): P }); const denoise = g.addNode({ type: 'denoise_latents', - id: DENOISE_LATENTS, + id: getPrefixedId('denoise_latents'), cfg_scale, cfg_rescale_multiplier, scheduler, @@ -111,14 +101,14 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager): P }); const l2i = g.addNode({ type: 'l2i', - id: LATENTS_TO_IMAGE, + id: getPrefixedId('l2i'), fp32: vaePrecision === 'fp32', }); const vaeLoader = vae?.base === model.base ? g.addNode({ type: 'vae_loader', - id: VAE_LOADER, + id: getPrefixedId('vae'), vae_model: vae, }) : null; @@ -214,16 +204,49 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager): P ); } - const _addedCAs = await addControlAdapters( + const controlNetCollector = g.addNode({ + type: 'collect', + id: getPrefixedId('control_net_collector'), + }); + const controlNetResult = await addControlNets( manager, state.canvasV2.controlLayers.entities, g, state.canvasV2.bbox.rect, - denoise, + controlNetCollector, modelConfig.base ); - const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base); - const _addedRegions = await addRegions( + if (controlNetResult.addedControlNets > 0) { + g.addEdge(controlNetCollector, 'collection', denoise, 'control'); + } else { + g.deleteNode(controlNetCollector.id); + } + + const t2iAdapterCollector = g.addNode({ + type: 'collect', + id: getPrefixedId('t2i_adapter_collector'), + }); + const t2iAdapterResult = await addT2IAdapters( + manager, + state.canvasV2.controlLayers.entities, + g, + state.canvasV2.bbox.rect, + controlNetCollector, + modelConfig.base + ); + if (t2iAdapterResult.addedT2IAdapters > 0) { + g.addEdge(t2iAdapterCollector, 'collection', denoise, 't2i_adapter'); + } else { + g.deleteNode(t2iAdapterCollector.id); + } + + const ipAdapterCollector = g.addNode({ + type: 'collect', + id: getPrefixedId('ip_adapter_collector'), + }); + const ipAdapterResult = addIPAdapters(state.canvasV2.ipAdapters.entities, g, ipAdapterCollector, modelConfig.base); + + const regionsResult = await addRegions( manager, state.canvasV2.regions.entities, g, @@ -233,13 +256,17 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager): P posCond, negCond, posCondCollect, - negCondCollect + negCondCollect, + ipAdapterCollector ); - // const isHRFAllowed = !addedLayers.some((l) => isInitialImageLayer(l) || isRegionalGuidanceLayer(l)); - // if (isHRFAllowed && state.hrf.hrfEnabled) { - // imageOutput = addHRF(state, g, denoise, noise, l2i, vaeSource); - // } + const totalIPAdaptersAdded = + ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0); + if (totalIPAdaptersAdded > 0) { + g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter'); + } else { + g.deleteNode(ipAdapterCollector.id); + } if (state.system.shouldUseNSFWChecker) { canvasOutput = addNSFWChecker(g, canvasOutput); @@ -252,12 +279,12 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager): P const shouldSaveToGallery = session.mode === 'generate' || settings.autoSave; g.updateNode(canvasOutput, { - id: CANVAS_OUTPUT, + id: getPrefixedId('canvas_output'), is_intermediate: !shouldSaveToGallery, use_cache: false, board: getBoardField(state), }); g.setMetadataReceivingNode(canvasOutput); - return g; + return { g, noise, posCond }; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts index d55d632c6a..e1384a8826 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts @@ -1,21 +1,9 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; +import { getPrefixedId } from 'features/controlLayers/konva/util'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; -import { - CANVAS_OUTPUT, - LATENTS_TO_IMAGE, - NEGATIVE_CONDITIONING, - NEGATIVE_CONDITIONING_COLLECT, - NOISE, - POSITIVE_CONDITIONING, - POSITIVE_CONDITIONING_COLLECT, - SDXL_CONTROL_LAYERS_GRAPH, - SDXL_DENOISE_LATENTS, - SDXL_MODEL_LOADER, - VAE_LOADER, -} from 'features/nodes/util/graph/constants'; -import { addControlAdapters } from 'features/nodes/util/graph/generation/addControlAdapters'; +import { addControlNets, addT2IAdapters } from 'features/nodes/util/graph/generation/addControlAdapters'; import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage'; import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint'; import { addIPAdapters } from 'features/nodes/util/graph/generation/addIPAdapters'; @@ -36,7 +24,10 @@ import { addRegions } from './addRegions'; const log = logger('system'); -export const buildSDXLGraph = async (state: RootState, manager: CanvasManager): Promise => { +export const buildSDXLGraph = async ( + state: RootState, + manager: CanvasManager +): Promise<{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'sdxl_compel_prompt'> }> => { const generationMode = manager.compositor.getGenerationMode(); log.debug({ generationMode }, 'Building SDXL graph'); @@ -62,35 +53,35 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager): const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } = getPresetModifiedPrompts(state); - const g = new Graph(SDXL_CONTROL_LAYERS_GRAPH); + const g = new Graph(getPrefixedId('sdxl_graph')); const modelLoader = g.addNode({ type: 'sdxl_model_loader', - id: SDXL_MODEL_LOADER, + id: getPrefixedId('sdxl_model_loader'), model, }); const posCond = g.addNode({ type: 'sdxl_compel_prompt', - id: POSITIVE_CONDITIONING, + id: getPrefixedId('pos_cond'), prompt: positivePrompt, style: positiveStylePrompt, }); const posCondCollect = g.addNode({ type: 'collect', - id: POSITIVE_CONDITIONING_COLLECT, + id: getPrefixedId('pos_cond_collect'), }); const negCond = g.addNode({ type: 'sdxl_compel_prompt', - id: NEGATIVE_CONDITIONING, + id: getPrefixedId('neg_cond'), prompt: negativePrompt, style: negativeStylePrompt, }); const negCondCollect = g.addNode({ type: 'collect', - id: NEGATIVE_CONDITIONING_COLLECT, + id: getPrefixedId('neg_cond_collect'), }); const noise = g.addNode({ type: 'noise', - id: NOISE, + id: getPrefixedId('noise'), seed, width: scaledSize.width, height: scaledSize.height, @@ -98,7 +89,7 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager): }); const denoise = g.addNode({ type: 'denoise_latents', - id: SDXL_DENOISE_LATENTS, + id: getPrefixedId('denoise_latents'), cfg_scale, cfg_rescale_multiplier, scheduler, @@ -108,14 +99,14 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager): }); const l2i = g.addNode({ type: 'l2i', - id: LATENTS_TO_IMAGE, + id: getPrefixedId('l2i'), fp32: vaePrecision === 'fp32', }); const vaeLoader = vae?.base === model.base ? g.addNode({ type: 'vae_loader', - id: VAE_LOADER, + id: getPrefixedId('vae'), vae_model: vae, }) : null; @@ -216,16 +207,47 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager): ); } - const _addedCAs = await addControlAdapters( + const controlNetCollector = g.createNode({ + type: 'collect', + id: getPrefixedId('control_net_collector'), + }); + const controlNetResult = await addControlNets( manager, state.canvasV2.controlLayers.entities, g, state.canvasV2.bbox.rect, - denoise, + controlNetCollector, modelConfig.base ); - const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base); - const _addedRegions = await addRegions( + if (controlNetResult.addedControlNets > 0) { + g.addNode(controlNetCollector); + g.addEdge(controlNetCollector, 'collection', denoise, 'control'); + } + + const t2iAdapterCollector = g.createNode({ + type: 'collect', + id: getPrefixedId('t2i_adapter_collector'), + }); + const t2iAdapterResult = await addT2IAdapters( + manager, + state.canvasV2.controlLayers.entities, + g, + state.canvasV2.bbox.rect, + controlNetCollector, + modelConfig.base + ); + if (t2iAdapterResult.addedT2IAdapters > 0) { + g.addNode(t2iAdapterCollector); + g.addEdge(t2iAdapterCollector, 'collection', denoise, 't2i_adapter'); + } + + const ipAdapterCollector = g.createNode({ + type: 'collect', + id: getPrefixedId('ip_adapter_collector'), + }); + const ipAdapterResult = addIPAdapters(state.canvasV2.ipAdapters.entities, g, ipAdapterCollector, modelConfig.base); + + const regionsResult = await addRegions( manager, state.canvasV2.regions.entities, g, @@ -235,9 +257,17 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager): posCond, negCond, posCondCollect, - negCondCollect + negCondCollect, + ipAdapterCollector ); + const totalIPAdaptersAdded = + ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0); + if (totalIPAdaptersAdded > 0) { + g.addNode(ipAdapterCollector); + g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter'); + } + if (state.system.shouldUseNSFWChecker) { canvasOutput = addNSFWChecker(g, canvasOutput); } @@ -249,12 +279,12 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager): const shouldSaveToGallery = session.mode === 'generate' || settings.autoSave; g.updateNode(canvasOutput, { - id: CANVAS_OUTPUT, + id: getPrefixedId('canvas_output'), is_intermediate: !shouldSaveToGallery, use_cache: false, board: getBoardField(state), }); g.setMetadataReceivingNode(canvasOutput); - return g; + return { g, noise, posCond }; }; diff --git a/invokeai/frontend/web/src/services/events/onInvocationComplete.ts b/invokeai/frontend/web/src/services/events/onInvocationComplete.ts index b87713379e..5b737b41ff 100644 --- a/invokeai/frontend/web/src/services/events/onInvocationComplete.ts +++ b/invokeai/frontend/web/src/services/events/onInvocationComplete.ts @@ -8,18 +8,21 @@ import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks import { zNodeStatus } from 'features/nodes/types/invocation'; import { boardsApi } from 'services/api/endpoints/boards'; import { getImageDTO, imagesApi } from 'services/api/endpoints/images'; -import type { ImageDTO } from 'services/api/types'; +import type { ImageDTO, S } from 'services/api/types'; import { getCategories, getListImagesUrl } from 'services/api/util'; -import type { InvocationCompleteEvent, InvocationDenoiseProgressEvent } from 'services/events/types'; const log = logger('events'); +const isCanvasOutput = (data: S['InvocationCompleteEvent']) => { + return data.invocation_source_id.split(':')[0] === 'canvas_output'; +}; + export const buildOnInvocationComplete = ( getState: () => RootState, dispatch: AppDispatch, nodeTypeDenylist: string[], - setLastProgressEvent: (event: InvocationDenoiseProgressEvent | null) => void, - setLastCanvasProgressEvent: (event: InvocationDenoiseProgressEvent | null) => void + setLastProgressEvent: (event: S['InvocationDenoiseProgressEvent'] | null) => void, + setLastCanvasProgressEvent: (event: S['InvocationDenoiseProgressEvent'] | null) => void ) => { const addImageToGallery = (imageDTO: ImageDTO) => { if (imageDTO.is_intermediate) { @@ -80,64 +83,87 @@ export const buildOnInvocationComplete = ( } }; - return async (data: InvocationCompleteEvent) => { + const getResultImageDTO = (data: S['InvocationCompleteEvent']) => { + const { result } = data; + if (result.type === 'image_output') { + return getImageDTO(result.image.image_name); + } else if (result.type === 'canvas_v2_mask_and_crop_output') { + return getImageDTO(result.image.image_name); + } + return null; + }; + + const handleOriginWorkflows = async (data: S['InvocationCompleteEvent']) => { + const { result, invocation_source_id } = data; + + const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); + if (nes) { + nes.status = zNodeStatus.enum.COMPLETED; + if (nes.progress !== null) { + nes.progress = 1; + } + nes.outputs.push(result); + upsertExecutionState(nes.nodeId, nes); + } + + const imageDTO = await getResultImageDTO(data); + + if (imageDTO) { + addImageToGallery(imageDTO); + } + }; + + const handleOriginCanvas = async (data: S['InvocationCompleteEvent']) => { + const session = getState().canvasV2.session; + + const imageDTO = await getResultImageDTO(data); + + if (!imageDTO) { + return; + } + + if (session.mode === 'compose') { + if (session.isStaging && isCanvasOutput(data)) { + if (data.result.type === 'canvas_v2_mask_and_crop_output') { + const { offset_x, offset_y } = data.result; + if (session.isStaging) { + dispatch(sessionImageStaged({ stagingAreaImage: { imageDTO, offsetX: offset_x, offsetY: offset_y } })); + } + } else if (data.result.type === 'image_output') { + if (session.isStaging) { + dispatch(sessionImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } })); + } + } + } + } else { + // session.mode === 'generate' + setLastCanvasProgressEvent(null); + } + + addImageToGallery(imageDTO); + }; + + const handleOriginOther = async (data: S['InvocationCompleteEvent']) => { + const imageDTO = await getResultImageDTO(data); + + if (imageDTO) { + addImageToGallery(imageDTO); + } + }; + + return async (data: S['InvocationCompleteEvent']) => { log.debug( { data } as SerializableObject, `Invocation complete (${data.invocation.type}, ${data.invocation_source_id})` ); - const { result, invocation_source_id } = data; - // Update the node execution states - the image output is handled below if (data.origin === 'workflows') { - const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); - if (nes) { - nes.status = zNodeStatus.enum.COMPLETED; - if (nes.progress !== null) { - nes.progress = 1; - } - nes.outputs.push(result); - upsertExecutionState(nes.nodeId, nes); - } - } - - // This complete event has an associated image output - if ( - (data.result.type === 'image_output' || data.result.type === 'canvas_v2_mask_and_crop_output') && - !nodeTypeDenylist.includes(data.invocation.type) - ) { - const { image_name } = data.result.image; - const { session } = getState().canvasV2; - - const imageDTO = await getImageDTO(image_name); - - if (!imageDTO) { - log.error({ data } as SerializableObject, 'Failed to fetch image DTO after generation'); - return; - } - - if (data.origin === 'canvas') { - if (data.invocation_source_id !== 'canvas_output') { - // Not a canvas output image - ignore - return; - } - if (session.mode === 'compose' && session.isStaging) { - if (data.result.type === 'canvas_v2_mask_and_crop_output') { - const { offset_x, offset_y } = data.result; - if (session.isStaging) { - dispatch(sessionImageStaged({ stagingAreaImage: { imageDTO, offsetX: offset_x, offsetY: offset_y } })); - } - } else if (data.result.type === 'image_output') { - if (session.isStaging) { - dispatch(sessionImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } })); - } - } - addImageToGallery(imageDTO); - } else { - addImageToGallery(imageDTO); - setLastCanvasProgressEvent(null); - } - } + await handleOriginWorkflows(data); + } else if (data.origin === 'canvas') { + await handleOriginCanvas(data); + } else { + await handleOriginOther(data); } setLastProgressEvent(null); diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx index 9468dd707f..9af12891c7 100644 --- a/invokeai/frontend/web/src/services/events/setEventListeners.tsx +++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx @@ -17,8 +17,9 @@ import { atom, computed } from 'nanostores'; import { api, LIST_TAG } from 'services/api'; import { modelsApi } from 'services/api/endpoints/models'; import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue'; +import type { S } from 'services/api/types'; import { buildOnInvocationComplete } from 'services/events/onInvocationComplete'; -import type { ClientToServerEvents, InvocationDenoiseProgressEvent, ServerToClientEvents } from 'services/events/types'; +import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types'; import type { Socket } from 'socket.io-client'; export const socketConnected = createAction('socket/connected'); @@ -34,8 +35,8 @@ type SetEventListenersArg = { const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select(); const nodeTypeDenylist = ['load_image', 'image']; -export const $lastProgressEvent = atom(null); -export const $lastCanvasProgressEvent = atom(null); +export const $lastProgressEvent = atom(null); +export const $lastCanvasProgressEvent = atom(null); export const $hasProgress = computed($lastProgressEvent, (val) => Boolean(val)); export const $progressImage = computed($lastProgressEvent, (val) => val?.progress_image ?? null); const cancellations = new Set(); diff --git a/invokeai/frontend/web/src/services/events/types.ts b/invokeai/frontend/web/src/services/events/types.ts index 77631a3606..4d11490a8a 100644 --- a/invokeai/frontend/web/src/services/events/types.ts +++ b/invokeai/frontend/web/src/services/events/types.ts @@ -1,69 +1,35 @@ import type { S } from 'services/api/types'; -type ModelLoadStartedEvent = S['ModelLoadStartedEvent']; -type ModelLoadCompleteEvent = S['ModelLoadCompleteEvent']; - -type InvocationStartedEvent = S['InvocationStartedEvent']; -type InvocationDenoiseProgressEvent = S['InvocationDenoiseProgressEvent']; -type InvocationCompleteEvent = S['InvocationCompleteEvent']; -type InvocationErrorEvent = S['InvocationErrorEvent']; - -type ModelInstallDownloadStartedEvent = S['ModelInstallDownloadStartedEvent']; -type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent']; -type ModelInstallDownloadsCompleteEvent = S['ModelInstallDownloadsCompleteEvent']; -type ModelInstallCompleteEvent = S['ModelInstallCompleteEvent']; -type ModelInstallErrorEvent = S['ModelInstallErrorEvent']; -type ModelInstallStartedEvent = S['ModelInstallStartedEvent']; -type ModelInstallCancelledEvent = S['ModelInstallCancelledEvent']; - -type DownloadStartedEvent = S['DownloadStartedEvent']; -type DownloadProgressEvent = S['DownloadProgressEvent']; -type DownloadCompleteEvent = S['DownloadCompleteEvent']; -type DownloadCancelledEvent = S['DownloadCancelledEvent']; -type DownloadErrorEvent = S['DownloadErrorEvent']; - -type QueueItemStatusChangedEvent = S['QueueItemStatusChangedEvent']; -type QueueClearedEvent = S['QueueClearedEvent']; -type BatchEnqueuedEvent = S['BatchEnqueuedEvent']; - -type BulkDownloadStartedEvent = S['BulkDownloadStartedEvent']; -type BulkDownloadCompleteEvent = S['BulkDownloadCompleteEvent']; -type BulkDownloadFailedEvent = S['BulkDownloadErrorEvent']; - -type ClientEmitSubscribeQueue = { - queue_id: string; -}; +type ClientEmitSubscribeQueue = { queue_id: string }; type ClientEmitUnsubscribeQueue = ClientEmitSubscribeQueue; -type ClientEmitSubscribeBulkDownload = { - bulk_download_id: string; -}; +type ClientEmitSubscribeBulkDownload = { bulk_download_id: string }; type ClientEmitUnsubscribeBulkDownload = ClientEmitSubscribeBulkDownload; export type ServerToClientEvents = { - invocation_denoise_progress: (payload: InvocationDenoiseProgressEvent) => void; - invocation_complete: (payload: InvocationCompleteEvent) => void; - invocation_error: (payload: InvocationErrorEvent) => void; - invocation_started: (payload: InvocationStartedEvent) => void; - download_started: (payload: DownloadStartedEvent) => void; - download_progress: (payload: DownloadProgressEvent) => void; - download_complete: (payload: DownloadCompleteEvent) => void; - download_cancelled: (payload: DownloadCancelledEvent) => void; - download_error: (payload: DownloadErrorEvent) => void; - model_load_started: (payload: ModelLoadStartedEvent) => void; - model_install_started: (payload: ModelInstallStartedEvent) => void; - model_install_download_started: (payload: ModelInstallDownloadStartedEvent) => void; - model_install_download_progress: (payload: ModelInstallDownloadProgressEvent) => void; - model_install_downloads_complete: (payload: ModelInstallDownloadsCompleteEvent) => void; - model_install_complete: (payload: ModelInstallCompleteEvent) => void; - model_install_error: (payload: ModelInstallErrorEvent) => void; - model_install_cancelled: (payload: ModelInstallCancelledEvent) => void; - model_load_complete: (payload: ModelLoadCompleteEvent) => void; - queue_item_status_changed: (payload: QueueItemStatusChangedEvent) => void; - queue_cleared: (payload: QueueClearedEvent) => void; - batch_enqueued: (payload: BatchEnqueuedEvent) => void; - bulk_download_started: (payload: BulkDownloadStartedEvent) => void; - bulk_download_complete: (payload: BulkDownloadCompleteEvent) => void; - bulk_download_error: (payload: BulkDownloadFailedEvent) => void; + invocation_denoise_progress: (payload: S['InvocationDenoiseProgressEvent']) => void; + invocation_complete: (payload: S['InvocationCompleteEvent']) => void; + invocation_error: (payload: S['InvocationErrorEvent']) => void; + invocation_started: (payload: S['InvocationStartedEvent']) => void; + download_started: (payload: S['DownloadStartedEvent']) => void; + download_progress: (payload: S['DownloadProgressEvent']) => void; + download_complete: (payload: S['DownloadCompleteEvent']) => void; + download_cancelled: (payload: S['DownloadCancelledEvent']) => void; + download_error: (payload: S['DownloadErrorEvent']) => void; + model_load_started: (payload: S['ModelLoadStartedEvent']) => void; + model_install_started: (payload: S['ModelInstallStartedEvent']) => void; + model_install_download_started: (payload: S['ModelInstallDownloadStartedEvent']) => void; + model_install_download_progress: (payload: S['ModelInstallDownloadProgressEvent']) => void; + model_install_downloads_complete: (payload: S['ModelInstallDownloadsCompleteEvent']) => void; + model_install_complete: (payload: S['ModelInstallCompleteEvent']) => void; + model_install_error: (payload: S['ModelInstallErrorEvent']) => void; + model_install_cancelled: (payload: S['ModelInstallCancelledEvent']) => void; + model_load_complete: (payload: S['ModelLoadCompleteEvent']) => void; + queue_item_status_changed: (payload: S['QueueItemStatusChangedEvent']) => void; + queue_cleared: (payload: S['QueueClearedEvent']) => void; + batch_enqueued: (payload: S['BatchEnqueuedEvent']) => void; + bulk_download_started: (payload: S['BulkDownloadStartedEvent']) => void; + bulk_download_complete: (payload: S['BulkDownloadCompleteEvent']) => void; + bulk_download_error: (payload: S['BulkDownloadErrorEvent']) => void; }; export type ClientToServerEvents = {