tidy(ui): abstract compositing logic to module

This commit is contained in:
psychedelicious 2024-08-22 14:57:11 +10:00
parent 21ed6bccd8
commit f442d206be
7 changed files with 254 additions and 225 deletions

View File

@ -0,0 +1,243 @@
import type { SerializableObject } from 'common/types';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import {
canvasToBlob,
canvasToImageData,
getImageDataTransparency,
getPrefixedId,
previewBlob,
} from 'features/controlLayers/konva/util';
import type { GenerationMode, Rect } from 'features/controlLayers/store/types';
import type { Logger } from 'roarr';
import { getImageDTO, uploadImage } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import stableHash from 'stable-hash';
import { assert } from 'tsafe';
export class CanvasCompositorModule {
id: string;
path: string[];
log: Logger;
manager: CanvasManager;
constructor(manager: CanvasManager) {
this.id = getPrefixedId('canvas_compositor');
this.manager = manager;
this.path = this.manager.path.concat(this.id);
this.log = this.manager.buildLogger(this.getLoggingContext);
this.log.debug('Creating canvas compositor');
}
getCompositeRasterLayerEntityIds = (): string[] => {
const ids = [];
for (const adapter of this.manager.rasterLayerAdapters.values()) {
if (adapter.state.isEnabled && adapter.renderer.hasObjects()) {
ids.push(adapter.id);
}
}
return ids;
};
getCompositeInpaintMaskEntityIds = (): string[] => {
const ids = [];
for (const adapter of this.manager.inpaintMaskAdapters.values()) {
if (adapter.state.isEnabled && adapter.renderer.hasObjects()) {
ids.push(adapter.id);
}
}
return ids;
};
getCompositeRasterLayerCanvas = (rect: Rect): HTMLCanvasElement => {
const hash = this.getCompositeRasterLayerHash({ rect });
const cachedCanvas = this.manager.cache.canvasElementCache.get(hash);
if (cachedCanvas) {
this.log.trace({ rect }, 'Using cached composite inpaint mask canvas');
return cachedCanvas;
}
this.log.trace({ rect }, 'Building composite raster layer canvas');
const canvas = document.createElement('canvas');
canvas.width = rect.width;
canvas.height = rect.height;
const ctx = canvas.getContext('2d');
assert(ctx !== null, 'Canvas 2D context is null');
for (const id of this.getCompositeRasterLayerEntityIds()) {
const adapter = this.manager.rasterLayerAdapters.get(id);
if (!adapter) {
this.log.warn({ id }, 'Raster layer adapter not found');
continue;
}
this.log.trace({ id }, 'Drawing raster layer to composite canvas');
const adapterCanvas = adapter.getCanvas(rect);
ctx.drawImage(adapterCanvas, 0, 0);
}
this.manager.cache.canvasElementCache.set(hash, canvas);
return canvas;
};
getCompositeInpaintMaskCanvas = (rect: Rect): HTMLCanvasElement => {
const hash = this.getCompositeInpaintMaskHash({ rect });
const cachedCanvas = this.manager.cache.canvasElementCache.get(hash);
if (cachedCanvas) {
this.log.trace({ rect }, 'Using cached composite inpaint mask canvas');
return cachedCanvas;
}
this.log.trace({ rect }, 'Building composite inpaint mask canvas');
const canvas = document.createElement('canvas');
canvas.width = rect.width;
canvas.height = rect.height;
const ctx = canvas.getContext('2d');
assert(ctx !== null);
for (const id of this.getCompositeInpaintMaskEntityIds()) {
const adapter = this.manager.inpaintMaskAdapters.get(id);
if (!adapter) {
this.log.warn({ id }, 'Inpaint mask adapter not found');
continue;
}
this.log.trace({ id }, 'Drawing inpaint mask to composite canvas');
const adapterCanvas = adapter.getCanvas(rect);
ctx.drawImage(adapterCanvas, 0, 0);
}
this.manager.cache.canvasElementCache.set(hash, canvas);
return canvas;
};
getCompositeRasterLayerHash = (extra: SerializableObject): string => {
const data: Record<string, SerializableObject> = {
extra,
};
for (const id of this.getCompositeRasterLayerEntityIds()) {
const adapter = this.manager.rasterLayerAdapters.get(id);
if (!adapter) {
this.log.warn({ id }, 'Raster layer adapter not found');
continue;
}
data[id] = adapter.getHashableState();
}
return stableHash(data);
};
getCompositeInpaintMaskHash = (extra: SerializableObject): string => {
const data: Record<string, SerializableObject> = {
extra,
};
for (const id of this.getCompositeInpaintMaskEntityIds()) {
const adapter = this.manager.inpaintMaskAdapters.get(id);
if (!adapter) {
this.log.warn({ id }, 'Inpaint mask adapter not found');
continue;
}
data[id] = adapter.getHashableState();
}
return stableHash(data);
};
getCompositeRasterLayerImageDTO = async (rect: Rect): Promise<ImageDTO> => {
let imageDTO: ImageDTO | null = null;
const hash = this.getCompositeRasterLayerHash({ rect });
const cachedImageName = this.manager.cache.imageNameCache.get(hash);
if (cachedImageName) {
imageDTO = await getImageDTO(cachedImageName);
if (imageDTO) {
this.log.trace({ rect, imageName: cachedImageName, imageDTO }, 'Using cached composite raster layer image');
return imageDTO;
}
}
this.log.trace({ rect }, 'Rasterizing composite raster layer');
const canvas = this.getCompositeRasterLayerCanvas(rect);
const blob = await canvasToBlob(canvas);
if (this.manager._isDebugging) {
previewBlob(blob, 'Composite raster layer canvas');
}
imageDTO = await uploadImage(blob, 'composite-raster-layer.png', 'general', true);
this.manager.cache.imageNameCache.set(hash, imageDTO.image_name);
return imageDTO;
};
getCompositeInpaintMaskImageDTO = async (rect: Rect): Promise<ImageDTO> => {
let imageDTO: ImageDTO | null = null;
const hash = this.getCompositeInpaintMaskHash({ rect });
const cachedImageName = this.manager.cache.imageNameCache.get(hash);
if (cachedImageName) {
imageDTO = await getImageDTO(cachedImageName);
if (imageDTO) {
this.log.trace({ rect, cachedImageName, imageDTO }, 'Using cached composite inpaint mask image');
return imageDTO;
}
}
this.log.trace({ rect }, 'Rasterizing composite inpaint mask');
const canvas = this.getCompositeInpaintMaskCanvas(rect);
const blob = await canvasToBlob(canvas);
if (this.manager._isDebugging) {
previewBlob(blob, 'Composite inpaint mask canvas');
}
imageDTO = await uploadImage(blob, 'composite-inpaint-mask.png', 'general', true);
this.manager.cache.imageNameCache.set(hash, imageDTO.image_name);
return imageDTO;
};
getGenerationMode(): GenerationMode {
const { rect } = this.manager.stateApi.getBbox();
const compositeInpaintMaskHash = this.getCompositeInpaintMaskHash({ rect });
const compositeRasterLayerHash = this.getCompositeRasterLayerHash({ rect });
const hash = stableHash({ rect, compositeInpaintMaskHash, compositeRasterLayerHash });
const cachedGenerationMode = this.manager.cache.generationModeCache.get(hash);
if (cachedGenerationMode) {
this.log.trace({ rect, cachedGenerationMode }, 'Using cached generation mode');
return cachedGenerationMode;
}
const compositeInpaintMaskCanvas = this.getCompositeInpaintMaskCanvas(rect);
const compositeInpaintMaskImageData = canvasToImageData(compositeInpaintMaskCanvas);
const compositeInpaintMaskTransparency = getImageDataTransparency(compositeInpaintMaskImageData);
const compositeRasterLayerCanvas = this.getCompositeRasterLayerCanvas(rect);
const compositeRasterLayerImageData = canvasToImageData(compositeRasterLayerCanvas);
const compositeRasterLayerTransparency = getImageDataTransparency(compositeRasterLayerImageData);
let generationMode: GenerationMode;
if (compositeRasterLayerTransparency === 'FULLY_TRANSPARENT') {
// When the initial image is fully transparent, we are always doing txt2img
generationMode = 'txt2img';
} else if (compositeRasterLayerTransparency === 'PARTIALLY_TRANSPARENT') {
// When the initial image is partially transparent, we are always outpainting
generationMode = 'outpaint';
} else if (compositeInpaintMaskTransparency === 'FULLY_TRANSPARENT') {
// compositeLayerTransparency === 'OPAQUE'
// When the inpaint mask is fully transparent, we are doing img2img
generationMode = 'img2img';
} else {
// Else at least some of the inpaint mask is opaque, so we are inpainting
generationMode = 'inpaint';
}
this.manager.cache.generationModeCache.set(hash, generationMode);
return generationMode;
}
getLoggingContext = (): SerializableObject => {
return { ...this.manager.getLoggingContext(), path: this.path.join('.') };
};
}

View File

@ -3,25 +3,15 @@ import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store'; import type { AppStore } from 'app/store/store';
import type { SerializableObject } from 'common/types'; import type { SerializableObject } from 'common/types';
import { CanvasCacheModule } from 'features/controlLayers/konva/CanvasCacheModule'; import { CanvasCacheModule } from 'features/controlLayers/konva/CanvasCacheModule';
import { CanvasCompositorModule } from 'features/controlLayers/konva/CanvasCompositorModule';
import { CanvasFilter } from 'features/controlLayers/konva/CanvasFilter'; import { CanvasFilter } from 'features/controlLayers/konva/CanvasFilter';
import { CanvasRenderingModule } from 'features/controlLayers/konva/CanvasRenderingModule'; import { CanvasRenderingModule } from 'features/controlLayers/konva/CanvasRenderingModule';
import { CanvasStageModule } from 'features/controlLayers/konva/CanvasStageModule'; import { CanvasStageModule } from 'features/controlLayers/konva/CanvasStageModule';
import { CanvasWorkerModule } from 'features/controlLayers/konva/CanvasWorkerModule.js'; import { CanvasWorkerModule } from 'features/controlLayers/konva/CanvasWorkerModule.js';
import { import { getPrefixedId } from 'features/controlLayers/konva/util';
canvasToBlob,
canvasToImageData,
getImageDataTransparency,
getPrefixedId,
previewBlob,
} from 'features/controlLayers/konva/util';
import type { GenerationMode, Rect } from 'features/controlLayers/store/types';
import type Konva from 'konva'; import type Konva from 'konva';
import { atom } from 'nanostores'; import { atom } from 'nanostores';
import type { Logger } from 'roarr'; import type { Logger } from 'roarr';
import { getImageDTO, uploadImage } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import stableHash from 'stable-hash';
import { assert } from 'tsafe';
import { CanvasBackground } from './CanvasBackground'; import { CanvasBackground } from './CanvasBackground';
import type { CanvasLayerAdapter } from './CanvasLayerAdapter'; import type { CanvasLayerAdapter } from './CanvasLayerAdapter';
@ -55,6 +45,7 @@ export class CanvasManager {
worker: CanvasWorkerModule; worker: CanvasWorkerModule;
cache: CanvasCacheModule; cache: CanvasCacheModule;
renderer: CanvasRenderingModule; renderer: CanvasRenderingModule;
compositor: CanvasCompositorModule;
_isDebugging: boolean = false; _isDebugging: boolean = false;
@ -70,6 +61,7 @@ export class CanvasManager {
this.cache = new CanvasCacheModule(this); this.cache = new CanvasCacheModule(this);
this.renderer = new CanvasRenderingModule(this); this.renderer = new CanvasRenderingModule(this);
this.preview = new CanvasPreview(this); this.preview = new CanvasPreview(this);
this.compositor = new CanvasCompositorModule(this);
this.stage.addLayer(this.preview.getLayer()); this.stage.addLayer(this.preview.getLayer());
this.background = new CanvasBackground(this); this.background = new CanvasBackground(this);
@ -185,212 +177,6 @@ export class CanvasManager {
}; };
}; };
getCompositeRasterLayerEntityIds = (): string[] => {
const ids = [];
for (const adapter of this.rasterLayerAdapters.values()) {
if (adapter.state.isEnabled && adapter.renderer.hasObjects()) {
ids.push(adapter.id);
}
}
return ids;
};
getCompositeInpaintMaskEntityIds = (): string[] => {
const ids = [];
for (const adapter of this.inpaintMaskAdapters.values()) {
if (adapter.state.isEnabled && adapter.renderer.hasObjects()) {
ids.push(adapter.id);
}
}
return ids;
};
getCompositeRasterLayerCanvas = (rect: Rect): HTMLCanvasElement => {
const hash = this.getCompositeRasterLayerHash({ rect });
const cachedCanvas = this.cache.canvasElementCache.get(hash);
if (cachedCanvas) {
this.log.trace({ rect }, 'Using cached composite inpaint mask canvas');
return cachedCanvas;
}
this.log.trace({ rect }, 'Building composite raster layer canvas');
const canvas = document.createElement('canvas');
canvas.width = rect.width;
canvas.height = rect.height;
const ctx = canvas.getContext('2d');
assert(ctx !== null);
for (const id of this.getCompositeRasterLayerEntityIds()) {
const adapter = this.rasterLayerAdapters.get(id);
if (!adapter) {
this.log.warn({ id }, 'Raster layer adapter not found');
continue;
}
this.log.trace({ id }, 'Drawing raster layer to composite canvas');
const adapterCanvas = adapter.getCanvas(rect);
ctx.drawImage(adapterCanvas, 0, 0);
}
this.cache.canvasElementCache.set(hash, canvas);
return canvas;
};
getCompositeInpaintMaskCanvas = (rect: Rect): HTMLCanvasElement => {
const hash = this.getCompositeInpaintMaskHash({ rect });
const cachedCanvas = this.cache.canvasElementCache.get(hash);
if (cachedCanvas) {
this.log.trace({ rect }, 'Using cached composite inpaint mask canvas');
return cachedCanvas;
}
this.log.trace({ rect }, 'Building composite inpaint mask canvas');
const canvas = document.createElement('canvas');
canvas.width = rect.width;
canvas.height = rect.height;
const ctx = canvas.getContext('2d');
assert(ctx !== null);
for (const id of this.getCompositeInpaintMaskEntityIds()) {
const adapter = this.inpaintMaskAdapters.get(id);
if (!adapter) {
this.log.warn({ id }, 'Inpaint mask adapter not found');
continue;
}
this.log.trace({ id }, 'Drawing inpaint mask to composite canvas');
const adapterCanvas = adapter.getCanvas(rect);
ctx.drawImage(adapterCanvas, 0, 0);
}
this.cache.canvasElementCache.set(hash, canvas);
return canvas;
};
getCompositeRasterLayerHash = (extra: SerializableObject): string => {
const data: Record<string, SerializableObject> = {
extra,
};
for (const id of this.getCompositeRasterLayerEntityIds()) {
const adapter = this.rasterLayerAdapters.get(id);
if (!adapter) {
this.log.warn({ id }, 'Raster layer adapter not found');
continue;
}
data[id] = adapter.getHashableState();
}
return stableHash(data);
};
getCompositeInpaintMaskHash = (extra: SerializableObject): string => {
const data: Record<string, SerializableObject> = {
extra,
};
for (const id of this.getCompositeInpaintMaskEntityIds()) {
const adapter = this.inpaintMaskAdapters.get(id);
if (!adapter) {
this.log.warn({ id }, 'Inpaint mask adapter not found');
continue;
}
data[id] = adapter.getHashableState();
}
return stableHash(data);
};
getCompositeRasterLayerImageDTO = async (rect: Rect): Promise<ImageDTO> => {
let imageDTO: ImageDTO | null = null;
const hash = this.getCompositeRasterLayerHash({ rect });
const cachedImageName = this.cache.imageNameCache.get(hash);
if (cachedImageName) {
imageDTO = await getImageDTO(cachedImageName);
if (imageDTO) {
this.log.trace({ rect, imageName: cachedImageName, imageDTO }, 'Using cached composite raster layer image');
return imageDTO;
}
}
this.log.trace({ rect }, 'Rasterizing composite raster layer');
const canvas = this.getCompositeRasterLayerCanvas(rect);
const blob = await canvasToBlob(canvas);
if (this._isDebugging) {
previewBlob(blob, 'Composite raster layer canvas');
}
imageDTO = await uploadImage(blob, 'composite-raster-layer.png', 'general', true);
this.cache.imageNameCache.set(hash, imageDTO.image_name);
return imageDTO;
};
getCompositeInpaintMaskImageDTO = async (rect: Rect): Promise<ImageDTO> => {
let imageDTO: ImageDTO | null = null;
const hash = this.getCompositeInpaintMaskHash({ rect });
const cachedImageName = this.cache.imageNameCache.get(hash);
if (cachedImageName) {
imageDTO = await getImageDTO(cachedImageName);
if (imageDTO) {
this.log.trace({ rect, cachedImageName, imageDTO }, 'Using cached composite inpaint mask image');
return imageDTO;
}
}
this.log.trace({ rect }, 'Rasterizing composite inpaint mask');
const canvas = this.getCompositeInpaintMaskCanvas(rect);
const blob = await canvasToBlob(canvas);
if (this._isDebugging) {
previewBlob(blob, 'Composite inpaint mask canvas');
}
imageDTO = await uploadImage(blob, 'composite-inpaint-mask.png', 'general', true);
this.cache.imageNameCache.set(hash, imageDTO.image_name);
return imageDTO;
};
getGenerationMode(): GenerationMode {
const { rect } = this.stateApi.getBbox();
const compositeInpaintMaskHash = this.getCompositeInpaintMaskHash({ rect });
const compositeRasterLayerHash = this.getCompositeRasterLayerHash({ rect });
const hash = stableHash({ rect, compositeInpaintMaskHash, compositeRasterLayerHash });
const cachedGenerationMode = this.cache.generationModeCache.get(hash);
if (cachedGenerationMode) {
this.log.trace({ rect, cachedGenerationMode }, 'Using cached generation mode');
return cachedGenerationMode;
}
const inpaintMaskImageData = canvasToImageData(this.getCompositeInpaintMaskCanvas(rect));
const inpaintMaskTransparency = getImageDataTransparency(inpaintMaskImageData);
const compositeLayerImageData = canvasToImageData(this.getCompositeRasterLayerCanvas(rect));
const compositeLayerTransparency = getImageDataTransparency(compositeLayerImageData);
let generationMode: GenerationMode;
if (compositeLayerTransparency === 'FULLY_TRANSPARENT') {
// When the initial image is fully transparent, we are always doing txt2img
generationMode = 'txt2img';
} else if (compositeLayerTransparency === 'PARTIALLY_TRANSPARENT') {
// When the initial image is partially transparent, we are always outpainting
generationMode = 'outpaint';
} else if (inpaintMaskTransparency === 'FULLY_TRANSPARENT') {
// compositeLayerTransparency === 'OPAQUE'
// When the inpaint mask is fully transparent, we are doing img2img
generationMode = 'img2img';
} else {
// Else at least some of the inpaint mask is opaque, so we are inpainting
generationMode = 'inpaint';
}
this.cache.generationModeCache.set(hash, generationMode);
return generationMode;
}
setCanvasManager = () => { setCanvasManager = () => {
this.log.debug('Setting canvas manager'); this.log.debug('Setting canvas manager');
$canvasManager.set(this); $canvasManager.set(this);

View File

@ -17,7 +17,7 @@ export const addImageToImage = async (
): Promise<Invocation<'img_resize' | 'l2i'>> => { ): Promise<Invocation<'img_resize' | 'l2i'>> => {
denoise.denoising_start = denoising_start; denoise.denoising_start = denoising_start;
const { image_name } = await manager.getCompositeRasterLayerImageDTO(bbox.rect); const { image_name } = await manager.compositor.getCompositeRasterLayerImageDTO(bbox.rect);
if (!isEqual(scaledSize, originalSize)) { if (!isEqual(scaledSize, originalSize)) {
// Resize the initial image to the scaled size, denoise, then resize back to the original size // Resize the initial image to the scaled size, denoise, then resize back to the original size

View File

@ -21,8 +21,8 @@ export const addInpaint = async (
): Promise<Invocation<'canvas_v2_mask_and_crop'>> => { ): Promise<Invocation<'canvas_v2_mask_and_crop'>> => {
denoise.denoising_start = denoising_start; denoise.denoising_start = denoising_start;
const initialImage = await manager.getCompositeRasterLayerImageDTO(bbox.rect); const initialImage = await manager.compositor.getCompositeRasterLayerImageDTO(bbox.rect);
const maskImage = await manager.getCompositeInpaintMaskImageDTO(bbox.rect); const maskImage = await manager.compositor.getCompositeInpaintMaskImageDTO(bbox.rect);
if (!isEqual(scaledSize, originalSize)) { if (!isEqual(scaledSize, originalSize)) {
// Scale before processing requires some resizing // Scale before processing requires some resizing

View File

@ -22,8 +22,8 @@ export const addOutpaint = async (
): Promise<Invocation<'canvas_v2_mask_and_crop'>> => { ): Promise<Invocation<'canvas_v2_mask_and_crop'>> => {
denoise.denoising_start = denoising_start; denoise.denoising_start = denoising_start;
const initialImage = await manager.getCompositeRasterLayerImageDTO(bbox.rect); const initialImage = await manager.compositor.getCompositeRasterLayerImageDTO(bbox.rect);
const maskImage = await manager.getCompositeInpaintMaskImageDTO(bbox.rect); const maskImage = await manager.compositor.getCompositeInpaintMaskImageDTO(bbox.rect);
const infill = getInfill(g, compositing); const infill = getInfill(g, compositing);
if (!isEqual(scaledSize, originalSize)) { if (!isEqual(scaledSize, originalSize)) {

View File

@ -38,7 +38,7 @@ import { addRegions } from './addRegions';
const log = logger('system'); const log = logger('system');
export const buildSD1Graph = async (state: RootState, manager: CanvasManager): Promise<Graph> => { export const buildSD1Graph = async (state: RootState, manager: CanvasManager): Promise<Graph> => {
const generationMode = manager.getGenerationMode(); const generationMode = manager.compositor.getGenerationMode();
log.debug({ generationMode }, 'Building SD1/SD2 graph'); log.debug({ generationMode }, 'Building SD1/SD2 graph');
const { bbox, params } = state.canvasV2; const { bbox, params } = state.canvasV2;

View File

@ -37,7 +37,7 @@ import { addRegions } from './addRegions';
const log = logger('system'); const log = logger('system');
export const buildSDXLGraph = async (state: RootState, manager: CanvasManager): Promise<Graph> => { export const buildSDXLGraph = async (state: RootState, manager: CanvasManager): Promise<Graph> => {
const generationMode = manager.getGenerationMode(); const generationMode = manager.compositor.getGenerationMode();
log.debug({ generationMode }, 'Building SDXL graph'); log.debug({ generationMode }, 'Building SDXL graph');
const { bbox, params } = state.canvasV2; const { bbox, params } = state.canvasV2;