From b703884763726d8a83424c8dcec9ff26ce0d01a9 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 21 Jun 2024 22:30:57 +1000 Subject: [PATCH] feat(ui): generation mode calculation, fudged graphs --- .../listeners/enqueueRequestedLinear.ts | 2 + .../web/src/common/util/arrayBuffer.ts | 13 +- .../controlLayers/konva/nodeManager.ts | 231 +++++++++--------- .../src/features/controlLayers/konva/util.ts | 43 +++- .../src/features/controlLayers/store/types.ts | 2 + .../generation/buildGenerationTabGraph.ts | 4 +- .../generation/buildGenerationTabSDXLGraph.ts | 16 +- 7 files changed, 179 insertions(+), 132 deletions(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts index b76d8acf85..29ff4b2224 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts @@ -23,6 +23,8 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) const manager = $nodeManager.get(); assert(manager, 'Konva node manager not initialized'); + console.log('generation mode', manager.util.getGenerationMode()); + if (model?.base === 'sdxl') { graph = await buildGenerationTabSDXLGraph(state, manager); } else { diff --git a/invokeai/frontend/web/src/common/util/arrayBuffer.ts b/invokeai/frontend/web/src/common/util/arrayBuffer.ts index a21c9d8a47..c3b13dac26 100644 --- a/invokeai/frontend/web/src/common/util/arrayBuffer.ts +++ b/invokeai/frontend/web/src/common/util/arrayBuffer.ts @@ -1,10 +1,9 @@ -export const getImageDataTransparency = (pixels: Uint8ClampedArray) => { +export const getImageDataTransparency = (imageData: ImageData) => { let isFullyTransparent = true; let isPartiallyTransparent = false; - const len = pixels.length; - let i = 3; - for (i; i < len; i += 4) { - if (pixels[i] === 255) { + const len = imageData.data.length; + for (let i = 3; i < len; i += 4) { + if (imageData.data[i] === 255) { isFullyTransparent = false; } else { isPartiallyTransparent = true; @@ -18,8 +17,8 @@ export const getImageDataTransparency = (pixels: Uint8ClampedArray) => { export const areAnyPixelsBlack = (pixels: Uint8ClampedArray) => { const len = pixels.length; - let i = 0; - for (i; i < len; ) { + const i = 0; + for (let i = 0; i < len; i) { if (pixels[i++] === 0 && pixels[i++] === 0 && pixels[i++] === 0 && pixels[i++] === 255) { return true; } diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/nodeManager.ts b/invokeai/frontend/web/src/features/controlLayers/konva/nodeManager.ts index 19e095703c..f31a74b1f3 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/nodeManager.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/nodeManager.ts @@ -1,5 +1,5 @@ -import openBase64ImageInTab from 'common/util/openBase64ImageInTab'; -import { blobToDataURL } from 'features/controlLayers/konva/util'; +import { getImageDataTransparency } from 'common/util/arrayBuffer'; +import { konvaNodeToBlob, konvaNodeToImageData, previewBlob } from 'features/controlLayers/konva/util'; import type { BrushLine, BrushLineAddedArg, @@ -7,6 +7,7 @@ import type { CanvasV2State, EraserLine, EraserLineAddedArg, + GenerationMode, ImageObject, PointAddedToLineArg, PosChangedArg, @@ -157,6 +158,9 @@ type Util = { getRegionMaskImage: (arg: { id: string; bbox?: Rect; preview?: boolean }) => Promise; getInpaintMaskImage: (arg: { bbox?: Rect; preview?: boolean }) => Promise; getImageSourceImage: (arg: { bbox?: Rect; preview?: boolean }) => Promise; + getMaskLayerClone: (arg: { id: string }) => Konva.Layer; + getCompositeLayerStageClone: () => Konva.Stage; + getGenerationMode: () => GenerationMode; }; export class KonvaNodeManager { @@ -183,6 +187,9 @@ export class KonvaNodeManager { getRegionMaskImage: this._getRegionMaskImage.bind(this), getInpaintMaskImage: this._getInpaintMaskImage.bind(this), getImageSourceImage: this._getImageSourceImage.bind(this), + getMaskLayerClone: this._getMaskLayerClone.bind(this), + getCompositeLayerStageClone: this._getCompositeLayerStageClone.bind(this), + getGenerationMode: this._getGenerationMode.bind(this), }; this._konvaApi = null; this._preview = null; @@ -254,112 +261,34 @@ export class KonvaNodeManager { return this._stateApi; } - async _getRegionMaskImage(arg: { id: string; bbox?: Rect; preview?: boolean }): Promise { - const { id, bbox, preview = false } = arg; - const region = this.stateApi.getRegionsState().entities.find((entity) => entity.id === id); - assert(region, `Region entity state with id ${id} not found`); - const adapter = this.get(region.id); - assert(adapter, `Adapter for region ${region.id} not found`); + _getMaskLayerClone(arg: { id: string }): Konva.Layer { + const { id } = arg; + const adapter = this.get(id); + assert(adapter, `Adapter for entity ${id} not found`); - if (region.imageCache) { - const imageDTO = await this.util.getImageDTO(region.imageCache.name); - if (imageDTO) { - return imageDTO; - } - } + const layerClone = adapter.konvaLayer.clone(); + const objectGroupClone = adapter.konvaObjectGroup.clone(); - const layer = adapter.konvaLayer.clone(); - const objectGroup = adapter.konvaObjectGroup.clone(); - layer.destroyChildren(); - layer.add(objectGroup); - objectGroup.opacity(1); - objectGroup.cache(); + layerClone.destroyChildren(); + layerClone.add(objectGroupClone); - const blob = await new Promise((resolve) => { - layer.toBlob({ - callback: (blob) => { - assert(blob, 'Blob is null'); - resolve(blob); - }, - ...bbox, - }); - }); + objectGroupClone.opacity(1); + objectGroupClone.cache(); - if (preview) { - const base64 = await blobToDataURL(blob); - const caption = `${region.id}: ${region.positivePrompt} / ${region.negativePrompt}`; - openBase64ImageInTab([{ base64, caption }]); - } - - layer.destroy(); - - const imageDTO = await this.util.uploadImage(blob, `${region.id}_mask.png`, 'mask', true); - this.stateApi.onRegionMaskImageCached(region.id, imageDTO); - return imageDTO; + return layerClone; } - async _getInpaintMaskImage(arg: { bbox?: Rect; preview?: boolean }): Promise { - const { bbox, preview = false } = arg; - const inpaintMask = this.stateApi.getInpaintMaskState(); - const adapter = this.get(inpaintMask.id); - assert(adapter, `Adapter for ${inpaintMask.id} not found`); - - if (inpaintMask.imageCache) { - const imageDTO = await this.util.getImageDTO(inpaintMask.imageCache.name); - if (imageDTO) { - return imageDTO; - } - } - - const layer = adapter.konvaLayer.clone(); - const objectGroup = adapter.konvaObjectGroup.clone(); - layer.destroyChildren(); - layer.add(objectGroup); - objectGroup.opacity(1); - objectGroup.cache(); - - const blob = await new Promise((resolve) => { - layer.toBlob({ - callback: (blob) => { - assert(blob, 'Blob is null'); - resolve(blob); - }, - ...bbox, - }); - }); - - if (preview) { - const base64 = await blobToDataURL(blob); - const caption = 'inpaint mask'; - openBase64ImageInTab([{ base64, caption }]); - } - - layer.destroy(); - - const imageDTO = await this.util.uploadImage(blob, 'inpaint_mask.png', 'mask', true); - this.stateApi.onInpaintMaskImageCached(imageDTO); - return imageDTO; - } - - async _getImageSourceImage(arg: { bbox?: Rect; preview?: boolean }): Promise { - const { bbox, preview = false } = arg; + _getCompositeLayerStageClone(): Konva.Stage { const layersState = this.stateApi.getLayersState(); - const { entities, imageCache } = layersState; - if (imageCache) { - const imageDTO = await this.util.getImageDTO(imageCache.name); - if (imageDTO) { - return imageDTO; - } - } - const stage = this.stage.clone(); + const stageClone = this.stage.clone(); - stage.scaleX(1); - stage.scaleY(1); - stage.x(0); - stage.y(0); + stageClone.scaleX(1); + stageClone.scaleY(1); + stageClone.x(0); + stageClone.y(0); - const validLayers = entities.filter(isValidLayer); + const validLayers = layersState.entities.filter(isValidLayer); // Konva bug (?) - when iterating over the array returned from `stage.getLayers()`, if you destroy a layer, the array // is mutated in-place and the next iteration will skip the next layer. To avoid this, we first collect the layers @@ -367,7 +296,7 @@ export class KonvaNodeManager { // TODO(psyche): Maybe report this? const toDelete: Konva.Layer[] = []; - for (const konvaLayer of stage.getLayers()) { + for (const konvaLayer of stageClone.getLayers()) { const layer = validLayers.find((l) => l.id === konvaLayer.id()); if (!layer) { toDelete.push(konvaLayer); @@ -378,22 +307,100 @@ export class KonvaNodeManager { konvaLayer.destroy(); } - const blob = await new Promise((resolve) => { - stage.toBlob({ - callback: (blob) => { - assert(blob, 'Blob is null'); - resolve(blob); - }, - ...bbox, - }); - }); + return stageClone; + } - if (preview) { - const base64 = await blobToDataURL(blob); - openBase64ImageInTab([{ base64, caption: 'base layer' }]); + _getGenerationMode(): GenerationMode { + const { x, y, width, height } = this.stateApi.getBbox(); + const inpaintMaskLayer = this.util.getMaskLayerClone({ id: 'inpaint_mask' }); + const inpaintMaskImageData = konvaNodeToImageData(inpaintMaskLayer, { x, y, width, height }); + const inpaintMaskTransparency = getImageDataTransparency(inpaintMaskImageData); + const compositeLayer = this.util.getCompositeLayerStageClone(); + const compositeLayerImageData = konvaNodeToImageData(compositeLayer, { x, y, width, height }); + const compositeLayerTransparency = getImageDataTransparency(compositeLayerImageData); + if (compositeLayerTransparency.isPartiallyTransparent) { + if (compositeLayerTransparency.isFullyTransparent) { + return 'txt2img'; + } + return 'outpaint'; + } else { + if (!inpaintMaskTransparency.isFullyTransparent) { + return 'inpaint'; + } + return 'img2img'; + } + } + + async _getRegionMaskImage(arg: { id: string; bbox?: Rect; preview?: boolean }): Promise { + const { id, bbox, preview = false } = arg; + const region = this.stateApi.getRegionsState().entities.find((entity) => entity.id === id); + assert(region, `Region entity state with id ${id} not found`); + + if (region.imageCache) { + const imageDTO = await this.util.getImageDTO(region.imageCache.name); + if (imageDTO) { + return imageDTO; + } } - stage.destroy(); + const layerClone = this.util.getMaskLayerClone({ id }); + const blob = await konvaNodeToBlob(layerClone, bbox); + + if (preview) { + previewBlob(blob, `region ${region.id} mask`); + } + + layerClone.destroy(); + + const imageDTO = await this.util.uploadImage(blob, `${region.id}_mask.png`, 'mask', true); + this.stateApi.onRegionMaskImageCached(region.id, imageDTO); + return imageDTO; + } + + async _getInpaintMaskImage(arg: { bbox?: Rect; preview?: boolean }): Promise { + const { bbox, preview = false } = arg; + const inpaintMask = this.stateApi.getInpaintMaskState(); + + if (inpaintMask.imageCache) { + const imageDTO = await this.util.getImageDTO(inpaintMask.imageCache.name); + if (imageDTO) { + return imageDTO; + } + } + + const layerClone = this.util.getMaskLayerClone({ id: inpaintMask.id }); + const blob = await konvaNodeToBlob(layerClone, bbox); + + if (preview) { + previewBlob(blob, 'inpaint mask'); + } + + layerClone.destroy(); + + const imageDTO = await this.util.uploadImage(blob, 'inpaint_mask.png', 'mask', true); + this.stateApi.onInpaintMaskImageCached(imageDTO); + return imageDTO; + } + + async _getImageSourceImage(arg: { bbox?: Rect; preview?: boolean }): Promise { + const { bbox, preview = false } = arg; + const { imageCache } = this.stateApi.getLayersState(); + if (imageCache) { + const imageDTO = await this.util.getImageDTO(imageCache.name); + if (imageDTO) { + return imageDTO; + } + } + + const stageClone = this.util.getCompositeLayerStageClone(); + + const blob = await konvaNodeToBlob(stageClone, bbox); + + if (preview) { + previewBlob(blob, 'image source'); + } + + stageClone.destroy(); const imageDTO = await this.util.uploadImage(blob, 'base_layer.png', 'general', true); this.stateApi.onLayerImageCached(imageDTO); diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/util.ts b/invokeai/frontend/web/src/features/controlLayers/konva/util.ts index a4dd38627e..d6ece23dd2 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/util.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/util.ts @@ -11,10 +11,11 @@ import { RG_LAYER_NAME, RG_LAYER_RECT_SHAPE_NAME, } from 'features/controlLayers/konva/naming'; -import type { RgbaColor } from 'features/controlLayers/store/types'; +import type { Rect, RgbaColor } from 'features/controlLayers/store/types'; import Konva from 'konva'; import type { KonvaEventObject } from 'konva/lib/Node'; -import type { IRect, Vector2d } from 'konva/lib/types'; +import type { Vector2d } from 'konva/lib/types'; +import { assert } from 'tsafe'; /** * Gets the scaled and floored cursor position on the stage. If the cursor is not currently over the stage, returns null. @@ -203,24 +204,33 @@ export const dataURLToImageData = async (dataURL: string, width: number, height: /** * Converts a Konva node to a Blob * @param node - The Konva node to convert to a Blob - * @param boundingBox - The bounding box to crop to + * @param bbox - The bounding box to crop to * @returns A Promise that resolves with Blob of the node cropped to the bounding box */ -export const konvaNodeToBlob = async (node: Konva.Node, boundingBox: IRect): Promise => { - return await canvasToBlob(node.toCanvas(boundingBox)); +export const konvaNodeToBlob = async (node: Konva.Node, bbox?: Rect): Promise => { + return await new Promise((resolve) => { + node.toBlob({ + callback: (blob) => { + assert(blob, 'Blob is null'); + resolve(blob); + }, + ...(bbox ?? {}), + }); + }); }; /** * Converts a Konva node to an ImageData object * @param node - The Konva node to convert to an ImageData object - * @param boundingBox - The bounding box to crop to + * @param bbox - The bounding box to crop to * @returns A Promise that resolves with ImageData object of the node cropped to the bounding box */ -export const konvaNodeToImageData = async (node: Konva.Node, boundingBox: IRect): Promise => { +export const konvaNodeToImageData = (node: Konva.Node, bbox?: Rect): ImageData => { // get a dataURL of the bbox'd region - const dataURL = node.toDataURL(boundingBox); - - return await dataURLToImageData(dataURL, boundingBox.width, boundingBox.height); + const canvas = node.toCanvas({ ...(bbox ?? {}) }); + const ctx = canvas.getContext('2d'); + assert(ctx, 'ctx is null'); + return ctx.getImageData(0, 0, canvas.width, canvas.height); }; /** @@ -246,3 +256,16 @@ export const getPixelUnderCursor = (stage: Konva.Stage): RgbaColor | null => { return { r, g, b, a }; }; + +export const previewBlob = async (blob: Blob, label?: string) => { + const url = URL.createObjectURL(blob); + const w = window.open(''); + if (!w) { + return; + } + if (label) { + w.document.write(label); + w.document.write('
'); + } + w.document.write(``); +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 1964216489..482ffe1a5f 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -906,3 +906,5 @@ export const isLine = (obj: RenderableObject): obj is BrushLine | EraserLine => export type RemoveIndexString = { [K in keyof T as string extends K ? never : K]: T[K]; }; + +export type GenerationMode = 'txt2img' | 'img2img' | 'inpaint' | 'outpaint'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildGenerationTabGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildGenerationTabGraph.ts index d9adb21cd5..cc859780f2 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildGenerationTabGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildGenerationTabGraph.ts @@ -1,4 +1,5 @@ import type { RootState } from 'app/store/store'; +import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { CLIP_SKIP, @@ -29,7 +30,7 @@ import { assert } from 'tsafe'; import { addRegions } from './addRegions'; -export const buildGenerationTabGraph = async (state: RootState): Promise => { +export const buildGenerationTabGraph = async (state: RootState, manager: KonvaNodeManager): Promise => { const { model, cfgScale: cfg_scale, @@ -159,6 +160,7 @@ export const buildGenerationTabGraph = async (state: RootState): Promise => { +export const buildGenerationTabSDXLGraph = async ( + state: RootState, + manager: KonvaNodeManager +): Promise => { const { model, cfgScale: cfg_scale, @@ -42,6 +46,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise