From 6a8ceef4042088dd9e6f9a6e62eb05e204dbaade Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 22 Aug 2024 14:57:11 +1000 Subject: [PATCH] tidy(ui): abstract compositing logic to module --- .../konva/CanvasCompositorModule.ts | 243 ++++++++++++++++++ .../controlLayers/konva/CanvasManager.ts | 222 +--------------- .../util/graph/generation/addImageToImage.ts | 2 +- .../nodes/util/graph/generation/addInpaint.ts | 4 +- .../util/graph/generation/addOutpaint.ts | 4 +- .../util/graph/generation/buildSD1Graph.ts | 2 +- .../util/graph/generation/buildSDXLGraph.ts | 2 +- 7 files changed, 254 insertions(+), 225 deletions(-) create mode 100644 invokeai/frontend/web/src/features/controlLayers/konva/CanvasCompositorModule.ts diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasCompositorModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasCompositorModule.ts new file mode 100644 index 0000000000..c7df665a46 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasCompositorModule.ts @@ -0,0 +1,243 @@ +import type { SerializableObject } from 'common/types'; +import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; +import { + canvasToBlob, + canvasToImageData, + getImageDataTransparency, + getPrefixedId, + previewBlob, +} from 'features/controlLayers/konva/util'; +import type { GenerationMode, Rect } from 'features/controlLayers/store/types'; +import type { Logger } from 'roarr'; +import { getImageDTO, uploadImage } from 'services/api/endpoints/images'; +import type { ImageDTO } from 'services/api/types'; +import stableHash from 'stable-hash'; +import { assert } from 'tsafe'; + +export class CanvasCompositorModule { + id: string; + path: string[]; + log: Logger; + manager: CanvasManager; + + constructor(manager: CanvasManager) { + this.id = getPrefixedId('canvas_compositor'); + this.manager = manager; + this.path = this.manager.path.concat(this.id); + this.log = this.manager.buildLogger(this.getLoggingContext); + this.log.debug('Creating canvas compositor'); + } + + getCompositeRasterLayerEntityIds = (): string[] => { + const ids = []; + for (const adapter of this.manager.rasterLayerAdapters.values()) { + if (adapter.state.isEnabled && adapter.renderer.hasObjects()) { + ids.push(adapter.id); + } + } + return ids; + }; + + getCompositeInpaintMaskEntityIds = (): string[] => { + const ids = []; + for (const adapter of this.manager.inpaintMaskAdapters.values()) { + if (adapter.state.isEnabled && adapter.renderer.hasObjects()) { + ids.push(adapter.id); + } + } + return ids; + }; + + getCompositeRasterLayerCanvas = (rect: Rect): HTMLCanvasElement => { + const hash = this.getCompositeRasterLayerHash({ rect }); + const cachedCanvas = this.manager.cache.canvasElementCache.get(hash); + + if (cachedCanvas) { + this.log.trace({ rect }, 'Using cached composite inpaint mask canvas'); + return cachedCanvas; + } + + this.log.trace({ rect }, 'Building composite raster layer canvas'); + + const canvas = document.createElement('canvas'); + canvas.width = rect.width; + canvas.height = rect.height; + + const ctx = canvas.getContext('2d'); + assert(ctx !== null, 'Canvas 2D context is null'); + + for (const id of this.getCompositeRasterLayerEntityIds()) { + const adapter = this.manager.rasterLayerAdapters.get(id); + if (!adapter) { + this.log.warn({ id }, 'Raster layer adapter not found'); + continue; + } + this.log.trace({ id }, 'Drawing raster layer to composite canvas'); + const adapterCanvas = adapter.getCanvas(rect); + ctx.drawImage(adapterCanvas, 0, 0); + } + this.manager.cache.canvasElementCache.set(hash, canvas); + return canvas; + }; + + getCompositeInpaintMaskCanvas = (rect: Rect): HTMLCanvasElement => { + const hash = this.getCompositeInpaintMaskHash({ rect }); + const cachedCanvas = this.manager.cache.canvasElementCache.get(hash); + + if (cachedCanvas) { + this.log.trace({ rect }, 'Using cached composite inpaint mask canvas'); + return cachedCanvas; + } + + this.log.trace({ rect }, 'Building composite inpaint mask canvas'); + + const canvas = document.createElement('canvas'); + canvas.width = rect.width; + canvas.height = rect.height; + + const ctx = canvas.getContext('2d'); + assert(ctx !== null); + + for (const id of this.getCompositeInpaintMaskEntityIds()) { + const adapter = this.manager.inpaintMaskAdapters.get(id); + if (!adapter) { + this.log.warn({ id }, 'Inpaint mask adapter not found'); + continue; + } + this.log.trace({ id }, 'Drawing inpaint mask to composite canvas'); + const adapterCanvas = adapter.getCanvas(rect); + ctx.drawImage(adapterCanvas, 0, 0); + } + this.manager.cache.canvasElementCache.set(hash, canvas); + return canvas; + }; + + getCompositeRasterLayerHash = (extra: SerializableObject): string => { + const data: Record = { + extra, + }; + for (const id of this.getCompositeRasterLayerEntityIds()) { + const adapter = this.manager.rasterLayerAdapters.get(id); + if (!adapter) { + this.log.warn({ id }, 'Raster layer adapter not found'); + continue; + } + data[id] = adapter.getHashableState(); + } + return stableHash(data); + }; + + getCompositeInpaintMaskHash = (extra: SerializableObject): string => { + const data: Record = { + extra, + }; + for (const id of this.getCompositeInpaintMaskEntityIds()) { + const adapter = this.manager.inpaintMaskAdapters.get(id); + if (!adapter) { + this.log.warn({ id }, 'Inpaint mask adapter not found'); + continue; + } + data[id] = adapter.getHashableState(); + } + return stableHash(data); + }; + + getCompositeRasterLayerImageDTO = async (rect: Rect): Promise => { + let imageDTO: ImageDTO | null = null; + + const hash = this.getCompositeRasterLayerHash({ rect }); + const cachedImageName = this.manager.cache.imageNameCache.get(hash); + + if (cachedImageName) { + imageDTO = await getImageDTO(cachedImageName); + if (imageDTO) { + this.log.trace({ rect, imageName: cachedImageName, imageDTO }, 'Using cached composite raster layer image'); + return imageDTO; + } + } + + this.log.trace({ rect }, 'Rasterizing composite raster layer'); + + const canvas = this.getCompositeRasterLayerCanvas(rect); + const blob = await canvasToBlob(canvas); + if (this.manager._isDebugging) { + previewBlob(blob, 'Composite raster layer canvas'); + } + + imageDTO = await uploadImage(blob, 'composite-raster-layer.png', 'general', true); + this.manager.cache.imageNameCache.set(hash, imageDTO.image_name); + return imageDTO; + }; + + getCompositeInpaintMaskImageDTO = async (rect: Rect): Promise => { + let imageDTO: ImageDTO | null = null; + + const hash = this.getCompositeInpaintMaskHash({ rect }); + const cachedImageName = this.manager.cache.imageNameCache.get(hash); + + if (cachedImageName) { + imageDTO = await getImageDTO(cachedImageName); + if (imageDTO) { + this.log.trace({ rect, cachedImageName, imageDTO }, 'Using cached composite inpaint mask image'); + return imageDTO; + } + } + + this.log.trace({ rect }, 'Rasterizing composite inpaint mask'); + + const canvas = this.getCompositeInpaintMaskCanvas(rect); + const blob = await canvasToBlob(canvas); + if (this.manager._isDebugging) { + previewBlob(blob, 'Composite inpaint mask canvas'); + } + + imageDTO = await uploadImage(blob, 'composite-inpaint-mask.png', 'general', true); + this.manager.cache.imageNameCache.set(hash, imageDTO.image_name); + return imageDTO; + }; + + getGenerationMode(): GenerationMode { + const { rect } = this.manager.stateApi.getBbox(); + + const compositeInpaintMaskHash = this.getCompositeInpaintMaskHash({ rect }); + const compositeRasterLayerHash = this.getCompositeRasterLayerHash({ rect }); + const hash = stableHash({ rect, compositeInpaintMaskHash, compositeRasterLayerHash }); + const cachedGenerationMode = this.manager.cache.generationModeCache.get(hash); + + if (cachedGenerationMode) { + this.log.trace({ rect, cachedGenerationMode }, 'Using cached generation mode'); + return cachedGenerationMode; + } + + const compositeInpaintMaskCanvas = this.getCompositeInpaintMaskCanvas(rect); + const compositeInpaintMaskImageData = canvasToImageData(compositeInpaintMaskCanvas); + const compositeInpaintMaskTransparency = getImageDataTransparency(compositeInpaintMaskImageData); + + const compositeRasterLayerCanvas = this.getCompositeRasterLayerCanvas(rect); + const compositeRasterLayerImageData = canvasToImageData(compositeRasterLayerCanvas); + const compositeRasterLayerTransparency = getImageDataTransparency(compositeRasterLayerImageData); + + let generationMode: GenerationMode; + if (compositeRasterLayerTransparency === 'FULLY_TRANSPARENT') { + // When the initial image is fully transparent, we are always doing txt2img + generationMode = 'txt2img'; + } else if (compositeRasterLayerTransparency === 'PARTIALLY_TRANSPARENT') { + // When the initial image is partially transparent, we are always outpainting + generationMode = 'outpaint'; + } else if (compositeInpaintMaskTransparency === 'FULLY_TRANSPARENT') { + // compositeLayerTransparency === 'OPAQUE' + // When the inpaint mask is fully transparent, we are doing img2img + generationMode = 'img2img'; + } else { + // Else at least some of the inpaint mask is opaque, so we are inpainting + generationMode = 'inpaint'; + } + + this.manager.cache.generationModeCache.set(hash, generationMode); + return generationMode; + } + + getLoggingContext = (): SerializableObject => { + return { ...this.manager.getLoggingContext(), path: this.path.join('.') }; + }; +} diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasManager.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasManager.ts index 49167fba9d..88f3792373 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasManager.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasManager.ts @@ -3,25 +3,15 @@ import { logger } from 'app/logging/logger'; import type { AppStore } from 'app/store/store'; import type { SerializableObject } from 'common/types'; import { CanvasCacheModule } from 'features/controlLayers/konva/CanvasCacheModule'; +import { CanvasCompositorModule } from 'features/controlLayers/konva/CanvasCompositorModule'; import { CanvasFilter } from 'features/controlLayers/konva/CanvasFilter'; import { CanvasRenderingModule } from 'features/controlLayers/konva/CanvasRenderingModule'; import { CanvasStageModule } from 'features/controlLayers/konva/CanvasStageModule'; import { CanvasWorkerModule } from 'features/controlLayers/konva/CanvasWorkerModule.js'; -import { - canvasToBlob, - canvasToImageData, - getImageDataTransparency, - getPrefixedId, - previewBlob, -} from 'features/controlLayers/konva/util'; -import type { GenerationMode, Rect } from 'features/controlLayers/store/types'; +import { getPrefixedId } from 'features/controlLayers/konva/util'; import type Konva from 'konva'; import { atom } from 'nanostores'; import type { Logger } from 'roarr'; -import { getImageDTO, uploadImage } from 'services/api/endpoints/images'; -import type { ImageDTO } from 'services/api/types'; -import stableHash from 'stable-hash'; -import { assert } from 'tsafe'; import { CanvasBackground } from './CanvasBackground'; import type { CanvasLayerAdapter } from './CanvasLayerAdapter'; @@ -55,6 +45,7 @@ export class CanvasManager { worker: CanvasWorkerModule; cache: CanvasCacheModule; renderer: CanvasRenderingModule; + compositor: CanvasCompositorModule; _isDebugging: boolean = false; @@ -70,6 +61,7 @@ export class CanvasManager { this.cache = new CanvasCacheModule(this); this.renderer = new CanvasRenderingModule(this); this.preview = new CanvasPreview(this); + this.compositor = new CanvasCompositorModule(this); this.stage.addLayer(this.preview.getLayer()); this.background = new CanvasBackground(this); @@ -185,212 +177,6 @@ export class CanvasManager { }; }; - getCompositeRasterLayerEntityIds = (): string[] => { - const ids = []; - for (const adapter of this.rasterLayerAdapters.values()) { - if (adapter.state.isEnabled && adapter.renderer.hasObjects()) { - ids.push(adapter.id); - } - } - return ids; - }; - - getCompositeInpaintMaskEntityIds = (): string[] => { - const ids = []; - for (const adapter of this.inpaintMaskAdapters.values()) { - if (adapter.state.isEnabled && adapter.renderer.hasObjects()) { - ids.push(adapter.id); - } - } - return ids; - }; - - getCompositeRasterLayerCanvas = (rect: Rect): HTMLCanvasElement => { - const hash = this.getCompositeRasterLayerHash({ rect }); - const cachedCanvas = this.cache.canvasElementCache.get(hash); - - if (cachedCanvas) { - this.log.trace({ rect }, 'Using cached composite inpaint mask canvas'); - return cachedCanvas; - } - - this.log.trace({ rect }, 'Building composite raster layer canvas'); - - const canvas = document.createElement('canvas'); - canvas.width = rect.width; - canvas.height = rect.height; - - const ctx = canvas.getContext('2d'); - assert(ctx !== null); - - for (const id of this.getCompositeRasterLayerEntityIds()) { - const adapter = this.rasterLayerAdapters.get(id); - if (!adapter) { - this.log.warn({ id }, 'Raster layer adapter not found'); - continue; - } - this.log.trace({ id }, 'Drawing raster layer to composite canvas'); - const adapterCanvas = adapter.getCanvas(rect); - ctx.drawImage(adapterCanvas, 0, 0); - } - this.cache.canvasElementCache.set(hash, canvas); - return canvas; - }; - - getCompositeInpaintMaskCanvas = (rect: Rect): HTMLCanvasElement => { - const hash = this.getCompositeInpaintMaskHash({ rect }); - const cachedCanvas = this.cache.canvasElementCache.get(hash); - - if (cachedCanvas) { - this.log.trace({ rect }, 'Using cached composite inpaint mask canvas'); - return cachedCanvas; - } - - this.log.trace({ rect }, 'Building composite inpaint mask canvas'); - - const canvas = document.createElement('canvas'); - canvas.width = rect.width; - canvas.height = rect.height; - - const ctx = canvas.getContext('2d'); - assert(ctx !== null); - - for (const id of this.getCompositeInpaintMaskEntityIds()) { - const adapter = this.inpaintMaskAdapters.get(id); - if (!adapter) { - this.log.warn({ id }, 'Inpaint mask adapter not found'); - continue; - } - this.log.trace({ id }, 'Drawing inpaint mask to composite canvas'); - const adapterCanvas = adapter.getCanvas(rect); - ctx.drawImage(adapterCanvas, 0, 0); - } - this.cache.canvasElementCache.set(hash, canvas); - return canvas; - }; - - getCompositeRasterLayerHash = (extra: SerializableObject): string => { - const data: Record = { - extra, - }; - for (const id of this.getCompositeRasterLayerEntityIds()) { - const adapter = this.rasterLayerAdapters.get(id); - if (!adapter) { - this.log.warn({ id }, 'Raster layer adapter not found'); - continue; - } - data[id] = adapter.getHashableState(); - } - return stableHash(data); - }; - - getCompositeInpaintMaskHash = (extra: SerializableObject): string => { - const data: Record = { - extra, - }; - for (const id of this.getCompositeInpaintMaskEntityIds()) { - const adapter = this.inpaintMaskAdapters.get(id); - if (!adapter) { - this.log.warn({ id }, 'Inpaint mask adapter not found'); - continue; - } - data[id] = adapter.getHashableState(); - } - return stableHash(data); - }; - - getCompositeRasterLayerImageDTO = async (rect: Rect): Promise => { - let imageDTO: ImageDTO | null = null; - - const hash = this.getCompositeRasterLayerHash({ rect }); - const cachedImageName = this.cache.imageNameCache.get(hash); - - if (cachedImageName) { - imageDTO = await getImageDTO(cachedImageName); - if (imageDTO) { - this.log.trace({ rect, imageName: cachedImageName, imageDTO }, 'Using cached composite raster layer image'); - return imageDTO; - } - } - - this.log.trace({ rect }, 'Rasterizing composite raster layer'); - - const canvas = this.getCompositeRasterLayerCanvas(rect); - const blob = await canvasToBlob(canvas); - if (this._isDebugging) { - previewBlob(blob, 'Composite raster layer canvas'); - } - - imageDTO = await uploadImage(blob, 'composite-raster-layer.png', 'general', true); - this.cache.imageNameCache.set(hash, imageDTO.image_name); - return imageDTO; - }; - - getCompositeInpaintMaskImageDTO = async (rect: Rect): Promise => { - let imageDTO: ImageDTO | null = null; - - const hash = this.getCompositeInpaintMaskHash({ rect }); - const cachedImageName = this.cache.imageNameCache.get(hash); - - if (cachedImageName) { - imageDTO = await getImageDTO(cachedImageName); - if (imageDTO) { - this.log.trace({ rect, cachedImageName, imageDTO }, 'Using cached composite inpaint mask image'); - return imageDTO; - } - } - - this.log.trace({ rect }, 'Rasterizing composite inpaint mask'); - - const canvas = this.getCompositeInpaintMaskCanvas(rect); - const blob = await canvasToBlob(canvas); - if (this._isDebugging) { - previewBlob(blob, 'Composite inpaint mask canvas'); - } - - imageDTO = await uploadImage(blob, 'composite-inpaint-mask.png', 'general', true); - this.cache.imageNameCache.set(hash, imageDTO.image_name); - return imageDTO; - }; - - getGenerationMode(): GenerationMode { - const { rect } = this.stateApi.getBbox(); - - const compositeInpaintMaskHash = this.getCompositeInpaintMaskHash({ rect }); - const compositeRasterLayerHash = this.getCompositeRasterLayerHash({ rect }); - const hash = stableHash({ rect, compositeInpaintMaskHash, compositeRasterLayerHash }); - const cachedGenerationMode = this.cache.generationModeCache.get(hash); - - if (cachedGenerationMode) { - this.log.trace({ rect, cachedGenerationMode }, 'Using cached generation mode'); - return cachedGenerationMode; - } - - const inpaintMaskImageData = canvasToImageData(this.getCompositeInpaintMaskCanvas(rect)); - const inpaintMaskTransparency = getImageDataTransparency(inpaintMaskImageData); - const compositeLayerImageData = canvasToImageData(this.getCompositeRasterLayerCanvas(rect)); - const compositeLayerTransparency = getImageDataTransparency(compositeLayerImageData); - - let generationMode: GenerationMode; - if (compositeLayerTransparency === 'FULLY_TRANSPARENT') { - // When the initial image is fully transparent, we are always doing txt2img - generationMode = 'txt2img'; - } else if (compositeLayerTransparency === 'PARTIALLY_TRANSPARENT') { - // When the initial image is partially transparent, we are always outpainting - generationMode = 'outpaint'; - } else if (inpaintMaskTransparency === 'FULLY_TRANSPARENT') { - // compositeLayerTransparency === 'OPAQUE' - // When the inpaint mask is fully transparent, we are doing img2img - generationMode = 'img2img'; - } else { - // Else at least some of the inpaint mask is opaque, so we are inpainting - generationMode = 'inpaint'; - } - - this.cache.generationModeCache.set(hash, generationMode); - return generationMode; - } - setCanvasManager = () => { this.log.debug('Setting canvas manager'); $canvasManager.set(this); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addImageToImage.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addImageToImage.ts index 3aca05107d..6314ef9df4 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addImageToImage.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addImageToImage.ts @@ -17,7 +17,7 @@ export const addImageToImage = async ( ): Promise> => { denoise.denoising_start = denoising_start; - const { image_name } = await manager.getCompositeRasterLayerImageDTO(bbox.rect); + const { image_name } = await manager.compositor.getCompositeRasterLayerImageDTO(bbox.rect); if (!isEqual(scaledSize, originalSize)) { // Resize the initial image to the scaled size, denoise, then resize back to the original size diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addInpaint.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addInpaint.ts index c18cbf672f..b15e55ce25 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addInpaint.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addInpaint.ts @@ -21,8 +21,8 @@ export const addInpaint = async ( ): Promise> => { denoise.denoising_start = denoising_start; - const initialImage = await manager.getCompositeRasterLayerImageDTO(bbox.rect); - const maskImage = await manager.getCompositeInpaintMaskImageDTO(bbox.rect); + const initialImage = await manager.compositor.getCompositeRasterLayerImageDTO(bbox.rect); + const maskImage = await manager.compositor.getCompositeInpaintMaskImageDTO(bbox.rect); if (!isEqual(scaledSize, originalSize)) { // Scale before processing requires some resizing diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addOutpaint.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addOutpaint.ts index 0bc2b40cd9..0798f82916 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addOutpaint.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addOutpaint.ts @@ -22,8 +22,8 @@ export const addOutpaint = async ( ): Promise> => { denoise.denoising_start = denoising_start; - const initialImage = await manager.getCompositeRasterLayerImageDTO(bbox.rect); - const maskImage = await manager.getCompositeInpaintMaskImageDTO(bbox.rect); + const initialImage = await manager.compositor.getCompositeRasterLayerImageDTO(bbox.rect); + const maskImage = await manager.compositor.getCompositeInpaintMaskImageDTO(bbox.rect); const infill = getInfill(g, compositing); if (!isEqual(scaledSize, originalSize)) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts index 580fee42b0..47d5fe44a0 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts @@ -38,7 +38,7 @@ import { addRegions } from './addRegions'; const log = logger('system'); export const buildSD1Graph = async (state: RootState, manager: CanvasManager): Promise => { - const generationMode = manager.getGenerationMode(); + const generationMode = manager.compositor.getGenerationMode(); log.debug({ generationMode }, 'Building SD1/SD2 graph'); const { bbox, params } = state.canvasV2; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts index 7a30befb42..ee34f09172 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts @@ -37,7 +37,7 @@ import { addRegions } from './addRegions'; const log = logger('system'); export const buildSDXLGraph = async (state: RootState, manager: CanvasManager): Promise => { - const generationMode = manager.getGenerationMode(); + const generationMode = manager.compositor.getGenerationMode(); log.debug({ generationMode }, 'Building SDXL graph'); const { bbox, params } = state.canvasV2;