diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/bbox.ts b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/bbox.ts index c14d643657..5c226d017a 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/bbox.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/bbox.ts @@ -216,7 +216,7 @@ export const updateBboxes = ( onBboxChanged({ id: entityState.id, bbox: getLayerBboxPixels(konvaLayer, filterLayerChildren) }, 'layer'); } } else if (entityState.type === 'control_adapter') { - if (!entityState.image && !entityState.processedImage) { + if (!entityState.imageObject && !entityState.processedImageObject) { // No objects - no bbox to calculate onBboxChanged({ id: entityState.id, bbox: null }, 'control_adapter'); } else { diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/objects.ts b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/objects.ts index d1f46db6ce..5e5df4e96a 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/objects.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/objects.ts @@ -8,9 +8,9 @@ import { } from 'features/controlLayers/konva/naming'; import type { BrushLineObjectRecord, - KonvaEntityAdapter, EraserLineObjectRecord, ImageObjectRecord, + KonvaEntityAdapter, RectShapeObjectRecord, } from 'features/controlLayers/konva/nodeManager'; import type { diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/renderer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/renderer.ts index a2533f78a9..dcf283984d 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/renderer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/renderer.ts @@ -53,8 +53,11 @@ import type { import type Konva from 'konva'; import type { IRect, Vector2d } from 'konva/lib/types'; import { debounce } from 'lodash-es'; +import { atom } from 'nanostores'; import type { RgbaColor } from 'react-colorful'; +export const $nodeManager = atom(null); + /** * Initializes the canvas renderer. It subscribes to the redux store and listens for changes directly, bypassing the * react rendering cycle entirely, improving canvas performance. @@ -249,6 +252,8 @@ export const initializeRenderer = ( }; const manager = new KonvaNodeManager(stage, getBbox, onBboxTransformed, $shift.get, $ctrl.get, $meta.get, $alt.get); + console.log(manager); + $nodeManager.set(manager); const cleanupListeners = setStageEventHandlers({ manager, @@ -344,7 +349,7 @@ export const initializeRenderer = ( canvasV2.controlAdapters !== prevCanvasV2.controlAdapters || canvasV2.regions !== prevCanvasV2.regions ) { - logIfDebugging('Updating entity bboxes'); + // logIfDebugging('Updating entity bboxes'); // debouncedUpdateBboxes(stage, canvasV2.layers, canvasV2.controlAdapters, canvasV2.regions, onBboxChanged); } diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts index 89f26e224f..3de6e777de 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts @@ -16,9 +16,10 @@ import { toolReducers } from 'features/controlLayers/store/toolReducers'; import { initialAspectRatioState } from 'features/parameters/components/ImageSize/constants'; import type { AspectRatioState } from 'features/parameters/components/ImageSize/types'; import { atom } from 'nanostores'; +import type { ImageDTO } from 'services/api/types'; import type { CanvasEntityIdentifier, CanvasV2State, StageAttrs } from './types'; -import { DEFAULT_RGBA_COLOR } from './types'; +import { DEFAULT_RGBA_COLOR, imageDTOToImageWithDims } from './types'; const initialState: CanvasV2State = { _version: 3, @@ -119,6 +120,7 @@ const initialState: CanvasV2State = { refinerNegativeAestheticScore: 2.5, refinerStart: 0.8, }, + baseLayerImageCache: null, }; export const canvasV2Slice = createSlice({ @@ -164,6 +166,10 @@ export const canvasV2Slice = createSlice({ state.layers = []; state.ipAdapters = []; state.controlAdapters = []; + state.baseLayerImageCache = null; + }, + baseLayerImageCacheChanged: (state, action: PayloadAction) => { + state.baseLayerImageCache = action.payload ? imageDTOToImageWithDims(action.payload) : null; }, }, }); @@ -185,6 +191,7 @@ export const { scaledBboxChanged, bboxScaleMethodChanged, clipToBboxChanged, + baseLayerImageCacheChanged, // layers layerAdded, layerRecalled, diff --git a/invokeai/frontend/web/src/features/controlLayers/store/layersReducers.ts b/invokeai/frontend/web/src/features/controlLayers/store/layersReducers.ts index 980b655232..1f4dbeedc1 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/layersReducers.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/layersReducers.ts @@ -39,6 +39,7 @@ export const layersReducers = { y: 0, }); state.selectedEntityIdentifier = { type: 'layer', id }; + state.baseLayerImageCache = null; }, prepare: () => ({ payload: { id: uuidv4() } }), }, @@ -46,6 +47,7 @@ export const layersReducers = { const { data } = action.payload; state.layers.push(data); state.selectedEntityIdentifier = { type: 'layer', id: data.id }; + state.baseLayerImageCache = null; }, layerIsEnabledToggled: (state, action: PayloadAction<{ id: string }>) => { const { id } = action.payload; @@ -54,6 +56,7 @@ export const layersReducers = { return; } layer.isEnabled = !layer.isEnabled; + state.baseLayerImageCache = null; }, layerTranslated: (state, action: PayloadAction<{ id: string; x: number; y: number }>) => { const { id, x, y } = action.payload; @@ -63,6 +66,7 @@ export const layersReducers = { } layer.x = x; layer.y = y; + state.baseLayerImageCache = null; }, layerBboxChanged: (state, action: PayloadAction<{ id: string; bbox: IRect | null }>) => { const { id, bbox } = action.payload; @@ -88,13 +92,16 @@ export const layersReducers = { layer.objects = []; layer.bbox = null; layer.bboxNeedsUpdate = false; + state.baseLayerImageCache = null; }, layerDeleted: (state, action: PayloadAction<{ id: string }>) => { const { id } = action.payload; state.layers = state.layers.filter((l) => l.id !== id); + state.baseLayerImageCache = null; }, layerAllDeleted: (state) => { state.layers = []; + state.baseLayerImageCache = null; }, layerOpacityChanged: (state, action: PayloadAction<{ id: string; opacity: number }>) => { const { id, opacity } = action.payload; @@ -103,6 +110,7 @@ export const layersReducers = { return; } layer.opacity = opacity; + state.baseLayerImageCache = null; }, layerMovedForwardOne: (state, action: PayloadAction<{ id: string }>) => { const { id } = action.payload; @@ -111,6 +119,7 @@ export const layersReducers = { return; } moveOneToEnd(state.layers, layer); + state.baseLayerImageCache = null; }, layerMovedToFront: (state, action: PayloadAction<{ id: string }>) => { const { id } = action.payload; @@ -119,6 +128,7 @@ export const layersReducers = { return; } moveToEnd(state.layers, layer); + state.baseLayerImageCache = null; }, layerMovedBackwardOne: (state, action: PayloadAction<{ id: string }>) => { const { id } = action.payload; @@ -127,6 +137,7 @@ export const layersReducers = { return; } moveOneToStart(state.layers, layer); + state.baseLayerImageCache = null; }, layerMovedToBack: (state, action: PayloadAction<{ id: string }>) => { const { id } = action.payload; @@ -135,6 +146,7 @@ export const layersReducers = { return; } moveToStart(state.layers, layer); + state.baseLayerImageCache = null; }, layerBrushLineAdded: { reducer: (state, action: PayloadAction) => { @@ -153,6 +165,7 @@ export const layersReducers = { clip, }); layer.bboxNeedsUpdate = true; + state.baseLayerImageCache = null; }, prepare: (payload: BrushLineAddedArg) => ({ payload: { ...payload, lineId: uuidv4() }, @@ -174,6 +187,7 @@ export const layersReducers = { clip, }); layer.bboxNeedsUpdate = true; + state.baseLayerImageCache = null; }, prepare: (payload: EraserLineAddedArg) => ({ payload: { ...payload, lineId: uuidv4() }, @@ -191,6 +205,7 @@ export const layersReducers = { } lastObject.points.push(...point); layer.bboxNeedsUpdate = true; + state.baseLayerImageCache = null; }, layerRectAdded: { reducer: (state, action: PayloadAction) => { @@ -210,6 +225,7 @@ export const layersReducers = { color, }); layer.bboxNeedsUpdate = true; + state.baseLayerImageCache = null; }, prepare: (payload: RectShapeAddedArg) => ({ payload: { ...payload, rectId: uuidv4() } }), }, @@ -222,6 +238,7 @@ export const layersReducers = { } layer.objects.push(imageDTOToImageObject(id, objectId, imageDTO)); layer.bboxNeedsUpdate = true; + state.baseLayerImageCache = null; }, prepare: (payload: ImageObjectAddedArg) => ({ payload: { ...payload, objectId: uuidv4() } }), }, diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 5debf0da7a..f39f097920 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -872,6 +872,7 @@ export type CanvasV2State = { refinerNegativeAestheticScore: number; refinerStart: number; }; + baseLayerImageCache: ImageWithDims | null; }; export type StageAttrs = { x: number; y: number; width: number; height: number; scale: number }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLayers.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLayers.ts new file mode 100644 index 0000000000..8d4b74e76c --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLayers.ts @@ -0,0 +1,96 @@ +import { getStore } from 'app/store/nanostores/store'; +import openBase64ImageInTab from 'common/util/openBase64ImageInTab'; +import { $nodeManager } from 'features/controlLayers/konva/renderers/renderer'; +import { blobToDataURL } from 'features/controlLayers/konva/util'; +import { baseLayerImageCacheChanged } from 'features/controlLayers/store/canvasV2Slice'; +import type { LayerEntity } from 'features/controlLayers/store/types'; +import type Konva from 'konva'; +import type { IRect } from 'konva/lib/types'; +import { getImageDTO, imagesApi } from 'services/api/endpoints/images'; +import type { ImageDTO } from 'services/api/types'; +import { assert } from 'tsafe'; + +const isValidLayer = (entity: LayerEntity) => { + return ( + entity.isEnabled && + // Boolean(entity.bbox) && TODO(psyche): Re-enable this check when we have a way to calculate bbox for all layers + entity.objects.length > 0 + ); +}; + +/** + * Get the blobs of all regional prompt layers. Only visible layers are returned. + * @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used. + * @param preview Whether to open a new tab displaying each layer. + * @returns A map of layer IDs to blobs. + */ + +const getBaseLayer = async (layers: LayerEntity[], bbox: IRect, preview: boolean = false): Promise => { + const manager = $nodeManager.get(); + assert(manager, 'Node manager is null'); + + const stage = manager.stage.clone(); + + stage.scaleX(1); + stage.scaleY(1); + stage.x(0); + stage.y(0); + + const validLayers = layers.filter(isValidLayer); + + // Konva bug (?) - when iterating over the array returned from `stage.getLayers()`, if you destroy a layer, the array + // is mutated in-place and the next iteration will skip the next layer. To avoid this, we first collect the layers + // to delete in a separate array and then destroy them. + // TODO(psyche): Maybe report this? + const toDelete: Konva.Layer[] = []; + + for (const konvaLayer of stage.getLayers()) { + const layer = validLayers.find((l) => l.id === konvaLayer.id()); + if (!layer) { + toDelete.push(konvaLayer); + } + } + + for (const konvaLayer of toDelete) { + konvaLayer.destroy(); + } + + const blob = await new Promise((resolve) => { + stage.toBlob({ + callback: (blob) => { + assert(blob, 'Blob is null'); + resolve(blob); + }, + ...bbox, + }); + }); + + if (preview) { + const base64 = await blobToDataURL(blob); + openBase64ImageInTab([{ base64, caption: 'base layer' }]); + } + + stage.destroy(); + + return blob; +}; + +export const getBaseLayerImage = async (): Promise => { + const { dispatch, getState } = getStore(); + const state = getState(); + if (state.canvasV2.baseLayerImageCache) { + const imageDTO = await getImageDTO(state.canvasV2.baseLayerImageCache.name); + if (imageDTO) { + return imageDTO; + } + } + const blob = await getBaseLayer(state.canvasV2.layers, state.canvasV2.bbox, true); + const file = new File([blob], 'image.png', { type: 'image/png' }); + const req = dispatch( + imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'general', is_intermediate: true }) + ); + req.reset(); + const imageDTO = await req.unwrap(); + dispatch(baseLayerImageCacheChanged(imageDTO)); + return imageDTO; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts index 7d47ac3331..ae99dd6e32 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts @@ -1,8 +1,8 @@ import { getStore } from 'app/store/nanostores/store'; import { deepClone } from 'common/util/deepClone'; import openBase64ImageInTab from 'common/util/openBase64ImageInTab'; -import { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager'; -import { renderRegions } from 'features/controlLayers/konva/renderers/regions'; +import type { KonvaEntityAdapter } from 'features/controlLayers/konva/nodeManager'; +import { $nodeManager } from 'features/controlLayers/konva/renderers/renderer'; import { blobToDataURL } from 'features/controlLayers/konva/util'; import { rgMaskImageUploaded } from 'features/controlLayers/store/canvasV2Slice'; import type { Dimensions, IPAdapterEntity, RegionEntity } from 'features/controlLayers/store/types'; @@ -15,9 +15,7 @@ import { } from 'features/nodes/util/graph/constants'; import { addIPAdapterCollectorSafe, isValidIPAdapter } from 'features/nodes/util/graph/generation/addIPAdapters'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; -import Konva from 'konva'; import type { IRect } from 'konva/lib/types'; -import { size } from 'lodash-es'; import { getImageDTO, imagesApi } from 'services/api/endpoints/images'; import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types'; import { assert } from 'tsafe'; @@ -50,38 +48,34 @@ export const addRegions = async ( const isSDXL = base === 'sdxl'; const validRegions = regions.filter((rg) => isValidRegion(rg, base)); - const blobs = await getRGMaskBlobs(validRegions, documentSize, bbox); - assert(size(blobs) === size(validRegions), 'Mismatch between layer IDs and blobs'); - for (const rg of validRegions) { - const blob = blobs[rg.id]; - assert(blob, `Blob for layer ${rg.id} not found`); + for (const region of validRegions) { // Upload the mask image, or get the cached image if it exists - const { image_name } = await getMaskImage(rg, blob); + const { image_name } = await getRegionMaskImage(region, bbox, true); // The main mask-to-tensor node const maskToTensor = g.addNode({ - id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${rg.id}`, + id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${region.id}`, type: 'alpha_mask_to_tensor', image: { image_name, }, }); - if (rg.positivePrompt) { + if (region.positivePrompt) { // The main positive conditioning node const regionalPosCond = g.addNode( isSDXL ? { type: 'sdxl_compel_prompt', - id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${rg.id}`, - prompt: rg.positivePrompt, - style: rg.positivePrompt, // TODO: Should we put the positive prompt in both fields? + id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${region.id}`, + prompt: region.positivePrompt, + style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields? } : { type: 'compel', - id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${rg.id}`, - prompt: rg.positivePrompt, + id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${region.id}`, + prompt: region.positivePrompt, } ); // Connect the mask to the conditioning @@ -106,20 +100,20 @@ export const addRegions = async ( } } - if (rg.negativePrompt) { + if (region.negativePrompt) { // The main negative conditioning node const regionalNegCond = g.addNode( isSDXL ? { type: 'sdxl_compel_prompt', - id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${rg.id}`, - prompt: rg.negativePrompt, - style: rg.negativePrompt, + id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${region.id}`, + prompt: region.negativePrompt, + style: region.negativePrompt, } : { type: 'compel', - id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${rg.id}`, - prompt: rg.negativePrompt, + id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${region.id}`, + prompt: region.negativePrompt, } ); // Connect the mask to the conditioning @@ -143,10 +137,10 @@ export const addRegions = async ( } // If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node - if (rg.autoNegative === 'invert' && rg.positivePrompt) { + if (region.autoNegative === 'invert' && region.positivePrompt) { // We re-use the mask image, but invert it when converting to tensor const invertTensorMask = g.addNode({ - id: `${PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX}_${rg.id}`, + id: `${PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX}_${region.id}`, type: 'invert_tensor_mask', }); // Connect the OG mask image to the inverted mask-to-tensor node @@ -156,14 +150,14 @@ export const addRegions = async ( isSDXL ? { type: 'sdxl_compel_prompt', - id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${rg.id}`, - prompt: rg.positivePrompt, - style: rg.positivePrompt, + id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${region.id}`, + prompt: region.positivePrompt, + style: region.positivePrompt, } : { type: 'compel', - id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${rg.id}`, - prompt: rg.positivePrompt, + id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${region.id}`, + prompt: region.positivePrompt, } ); // Connect the inverted mask to the conditioning @@ -186,7 +180,7 @@ export const addRegions = async ( } } - const validRGIPAdapters: IPAdapterEntity[] = rg.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base)); + const validRGIPAdapters: IPAdapterEntity[] = region.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base)); for (const ipa of validRGIPAdapters) { const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise); @@ -245,6 +239,20 @@ export const getMaskImage = async (rg: RegionEntity, blob: Blob): Promise => { + const { dispatch } = getStore(); + // No cached mask, or the cached image no longer exists - we need to upload the mask image + const file = new File([blob], `${id}_mask.png`, { type: 'image/png' }); + const req = dispatch( + imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true }) + ); + req.reset(); + + const imageDTO = await req.unwrap(); + dispatch(rgMaskImageUploaded({ id, imageDTO })); + return imageDTO; +}; + /** * Get the blobs of all regional prompt layers. Only visible layers are returned. * @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used. @@ -252,53 +260,47 @@ export const getMaskImage = async (rg: RegionEntity, blob: Blob): Promise> => { - const container = document.createElement('div'); - const stage = new Konva.Stage({ container, ...documentSize }); - const manager = new KonvaNodeManager(stage); - renderRegions(manager, regions, 1, 'brush', null); - const adapters = manager.getAll(); - const blobs: Record = {}; +): Promise => { + const manager = $nodeManager.get(); + assert(manager, 'Node manager is null'); - // First remove all layers - for (const adapter of adapters) { - adapter.konvaLayer.remove(); - } - - // Next render each layer to a blob - for (const adapter of adapters) { - const region = regions.find((l) => l.id === adapter.id); - if (!region) { - continue; + // TODO(psyche): Why do I need to annotate this? TS must have some kind of circular ref w/ this type but I can't figure it out... + const adapter: KonvaEntityAdapter | undefined = manager.get(region.id); + assert(adapter, `Adapter for region ${region.id} not found`); + if (region.imageCache) { + const imageDTO = await getImageDTO(region.imageCache.name); + if (imageDTO) { + return imageDTO; } - stage.add(adapter.konvaLayer); - const blob = await new Promise((resolve) => { - stage.toBlob({ - callback: (blob) => { - assert(blob, 'Blob is null'); - resolve(blob); - }, - ...bbox, - }); + } + const layer = adapter.konvaLayer.clone(); + const objectGroup = adapter.konvaObjectGroup.clone(); + layer.destroyChildren(); + layer.add(objectGroup); + objectGroup.opacity(1); + objectGroup.cache(); + + const blob = await new Promise((resolve) => { + layer.toBlob({ + callback: (blob) => { + assert(blob, 'Blob is null'); + resolve(blob); + }, + ...bbox, }); + }); - if (preview) { - const base64 = await blobToDataURL(blob); - openBase64ImageInTab([ - { - base64, - caption: `${region.id}: ${region.positivePrompt} / ${region.negativePrompt}`, - }, - ]); - } - adapter.konvaLayer.remove(); - blobs[adapter.id] = blob; + if (preview) { + const base64 = await blobToDataURL(blob); + const caption = `${region.id}: ${region.positivePrompt} / ${region.negativePrompt}`; + openBase64ImageInTab([{ base64, caption }]); } - return blobs; + layer.destroy(); + + return await uploadMaskImage(region, blob); };