diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasManager.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasManager.ts index 22d5413c8a..7f42b6d566 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasManager.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasManager.ts @@ -2,6 +2,7 @@ import type { Store } from '@reduxjs/toolkit'; import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; import { + getControlAdapterImage, getGenerationMode, getImageSourceImage, getInpaintMaskImage, @@ -369,6 +370,10 @@ export class CanvasManager { return getGenerationMode({ manager: this }); } + getControlAdapterImage(arg: Omit[0], 'manager'>) { + return getControlAdapterImage({ ...arg, manager: this }); + } + getRegionMaskImage(arg: Omit[0], 'manager'>) { return getRegionMaskImage({ ...arg, manager: this }); } diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/util.ts b/invokeai/frontend/web/src/features/controlLayers/konva/util.ts index 912cc168c2..236a5e864a 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/util.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/util.ts @@ -319,6 +319,24 @@ export function getRegionMaskLayerClone(arg: { manager: CanvasManager; id: strin return layerClone; } +export function getControlAdapterLayerClone(arg: { manager: CanvasManager; id: string }): Konva.Layer { + const { id, manager } = arg; + + const controlAdapter = manager.controlAdapters.get(id); + assert(controlAdapter, `Canvas region with id ${id} not found`); + + const controlAdapterClone = controlAdapter.layer.clone(); + const objectGroupClone = controlAdapter.group.clone(); + + controlAdapterClone.destroyChildren(); + controlAdapterClone.add(objectGroupClone); + + objectGroupClone.opacity(1); + objectGroupClone.cache(); + + return controlAdapterClone; +} + export function getCompositeLayerStageClone(arg: { manager: CanvasManager }): Konva.Stage { const { manager } = arg; @@ -406,6 +424,37 @@ export async function getRegionMaskImage(arg: { return imageDTO; } +export async function getControlAdapterImage(arg: { + manager: CanvasManager; + id: string; + bbox?: Rect; + preview?: boolean; +}): Promise { + const { manager, id, bbox, preview = false } = arg; + const ca = manager.stateApi.getControlAdaptersState().entities.find((entity) => entity.id === id); + assert(ca, `Control adapter entity state with id ${id} not found`); + + // if (region.imageCache) { + // const imageDTO = await this.util.getImageDTO(region.imageCache.name); + // if (imageDTO) { + // return imageDTO; + // } + // } + + const layerClone = getControlAdapterLayerClone({ id, manager }); + const blob = await konvaNodeToBlob(layerClone, bbox); + + if (preview) { + previewBlob(blob, `region ${ca.id} mask`); + } + + layerClone.destroy(); + + const imageDTO = await manager.util.uploadImage(blob, `${ca.id}_control_image.png`, 'control', true); + // manager.stateApi.onRegionMaskImageCached(ca.id, imageDTO); + return imageDTO; +} + export async function getInpaintMaskImage(arg: { manager: CanvasManager; bbox?: Rect; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts index 3759a0822b..7e55beca0a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts @@ -1,8 +1,10 @@ +import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import type { ControlAdapterEntity, ControlNetData, ImageWithDims, ProcessorConfig, + Rect, T2IAdapterData, } from 'features/controlLayers/store/types'; import type { ImageField } from 'features/nodes/types/common'; @@ -11,18 +13,20 @@ import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { BaseModelType, Invocation } from 'services/api/types'; import { assert } from 'tsafe'; -export const addControlAdapters = ( +export const addControlAdapters = async ( + manager: CanvasManager, controlAdapters: ControlAdapterEntity[], g: Graph, + bbox: Rect, denoise: Invocation<'denoise_latents'>, base: BaseModelType -): ControlAdapterEntity[] => { +): Promise => { const validControlAdapters = controlAdapters.filter((ca) => isValidControlAdapter(ca, base)); for (const ca of validControlAdapters) { if (ca.adapterType === 'controlnet') { - addControlNetToGraph(ca, g, denoise); + await addControlNetToGraph(manager, ca, g, bbox, denoise); } else { - addT2IAdapterToGraph(ca, g, denoise); + await addT2IAdapterToGraph(manager, ca, g, bbox, denoise); } } return validControlAdapters; @@ -45,14 +49,17 @@ const addControlNetCollectorSafe = (g: Graph, denoise: Invocation<'denoise_laten } }; -const addControlNetToGraph = (ca: ControlNetData, g: Graph, denoise: Invocation<'denoise_latents'>) => { - const { id, beginEndStepPct, controlMode, imageObject, model, processedImageObject, processorConfig, weight } = ca; +const addControlNetToGraph = async ( + manager: CanvasManager, + ca: ControlNetData, + g: Graph, + bbox: Rect, + denoise: Invocation<'denoise_latents'> +) => { + const { id, beginEndStepPct, controlMode, model, weight } = ca; assert(model, 'ControlNet model is required'); - const controlImage = buildControlImage( - imageObject?.image ?? null, - processedImageObject?.image ?? null, - processorConfig - ); + const { image_name } = await manager.getControlAdapterImage({ id: ca.id, bbox, preview: true }); + const controlNetCollect = addControlNetCollectorSafe(g, denoise); const controlNet = g.addNode({ @@ -64,7 +71,7 @@ const addControlNetToGraph = (ca: ControlNetData, g: Graph, denoise: Invocation< resize_mode: 'just_resize', control_model: model, control_weight: weight, - image: controlImage, + image: { image_name }, }); g.addEdge(controlNet, 'control', controlNetCollect, 'item'); }; @@ -87,14 +94,17 @@ const addT2IAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_laten } }; -const addT2IAdapterToGraph = (ca: T2IAdapterData, g: Graph, denoise: Invocation<'denoise_latents'>) => { - const { id, beginEndStepPct, imageObject, model, processedImageObject, processorConfig, weight } = ca; +const addT2IAdapterToGraph = async ( + manager: CanvasManager, + ca: T2IAdapterData, + g: Graph, + bbox: Rect, + denoise: Invocation<'denoise_latents'> +) => { + const { id, beginEndStepPct, model, weight } = ca; assert(model, 'T2I Adapter model is required'); - const controlImage = buildControlImage( - imageObject?.image ?? null, - processedImageObject?.image ?? null, - processorConfig - ); + const { image_name } = await manager.getControlAdapterImage({ id: ca.id, bbox, preview: true }); + const t2iAdapterCollect = addT2IAdapterCollectorSafe(g, denoise); const t2iAdapter = g.addNode({ @@ -105,7 +115,7 @@ const addT2IAdapterToGraph = (ca: T2IAdapterData, g: Graph, denoise: Invocation< resize_mode: 'just_resize', t2i_adapter_model: model, weight: weight, - image: controlImage, + image: { image_name }, }); g.addEdge(t2iAdapter, 't2i_adapter', t2iAdapterCollect, 'item'); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts index 6966feef9e..30cbd48f9e 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts @@ -210,7 +210,14 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager): P ); } - const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base); + const _addedCAs = await addControlAdapters( + manager, + state.canvasV2.controlAdapters.entities, + g, + state.canvasV2.bbox, + denoise, + modelConfig.base + ); const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base); const _addedRegions = await addRegions( manager, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts index 9177e9e745..2233445a25 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts @@ -214,7 +214,14 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager): ); } - const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base); + const _addedCAs = await addControlAdapters( + manager, + state.canvasV2.controlAdapters.entities, + g, + state.canvasV2.bbox, + denoise, + modelConfig.base + ); const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base); const _addedRegions = await addRegions( manager,