feat(ui): generation mode calculation, fudged graphs

This commit is contained in:
psychedelicious 2024-06-21 22:30:57 +10:00
parent 32da98ab8f
commit b703884763
7 changed files with 179 additions and 132 deletions

View File

@ -23,6 +23,8 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
const manager = $nodeManager.get();
assert(manager, 'Konva node manager not initialized');
console.log('generation mode', manager.util.getGenerationMode());
if (model?.base === 'sdxl') {
graph = await buildGenerationTabSDXLGraph(state, manager);
} else {

View File

@ -1,10 +1,9 @@
export const getImageDataTransparency = (pixels: Uint8ClampedArray) => {
export const getImageDataTransparency = (imageData: ImageData) => {
let isFullyTransparent = true;
let isPartiallyTransparent = false;
const len = pixels.length;
let i = 3;
for (i; i < len; i += 4) {
if (pixels[i] === 255) {
const len = imageData.data.length;
for (let i = 3; i < len; i += 4) {
if (imageData.data[i] === 255) {
isFullyTransparent = false;
} else {
isPartiallyTransparent = true;
@ -18,8 +17,8 @@ export const getImageDataTransparency = (pixels: Uint8ClampedArray) => {
export const areAnyPixelsBlack = (pixels: Uint8ClampedArray) => {
const len = pixels.length;
let i = 0;
for (i; i < len; ) {
const i = 0;
for (let i = 0; i < len; i) {
if (pixels[i++] === 0 && pixels[i++] === 0 && pixels[i++] === 0 && pixels[i++] === 255) {
return true;
}

View File

@ -1,5 +1,5 @@
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { blobToDataURL } from 'features/controlLayers/konva/util';
import { getImageDataTransparency } from 'common/util/arrayBuffer';
import { konvaNodeToBlob, konvaNodeToImageData, previewBlob } from 'features/controlLayers/konva/util';
import type {
BrushLine,
BrushLineAddedArg,
@ -7,6 +7,7 @@ import type {
CanvasV2State,
EraserLine,
EraserLineAddedArg,
GenerationMode,
ImageObject,
PointAddedToLineArg,
PosChangedArg,
@ -157,6 +158,9 @@ type Util = {
getRegionMaskImage: (arg: { id: string; bbox?: Rect; preview?: boolean }) => Promise<ImageDTO>;
getInpaintMaskImage: (arg: { bbox?: Rect; preview?: boolean }) => Promise<ImageDTO>;
getImageSourceImage: (arg: { bbox?: Rect; preview?: boolean }) => Promise<ImageDTO>;
getMaskLayerClone: (arg: { id: string }) => Konva.Layer;
getCompositeLayerStageClone: () => Konva.Stage;
getGenerationMode: () => GenerationMode;
};
export class KonvaNodeManager {
@ -183,6 +187,9 @@ export class KonvaNodeManager {
getRegionMaskImage: this._getRegionMaskImage.bind(this),
getInpaintMaskImage: this._getInpaintMaskImage.bind(this),
getImageSourceImage: this._getImageSourceImage.bind(this),
getMaskLayerClone: this._getMaskLayerClone.bind(this),
getCompositeLayerStageClone: this._getCompositeLayerStageClone.bind(this),
getGenerationMode: this._getGenerationMode.bind(this),
};
this._konvaApi = null;
this._preview = null;
@ -254,112 +261,34 @@ export class KonvaNodeManager {
return this._stateApi;
}
async _getRegionMaskImage(arg: { id: string; bbox?: Rect; preview?: boolean }): Promise<ImageDTO> {
const { id, bbox, preview = false } = arg;
const region = this.stateApi.getRegionsState().entities.find((entity) => entity.id === id);
assert(region, `Region entity state with id ${id} not found`);
const adapter = this.get(region.id);
assert(adapter, `Adapter for region ${region.id} not found`);
_getMaskLayerClone(arg: { id: string }): Konva.Layer {
const { id } = arg;
const adapter = this.get(id);
assert(adapter, `Adapter for entity ${id} not found`);
if (region.imageCache) {
const imageDTO = await this.util.getImageDTO(region.imageCache.name);
if (imageDTO) {
return imageDTO;
}
}
const layerClone = adapter.konvaLayer.clone();
const objectGroupClone = adapter.konvaObjectGroup.clone();
const layer = adapter.konvaLayer.clone();
const objectGroup = adapter.konvaObjectGroup.clone();
layer.destroyChildren();
layer.add(objectGroup);
objectGroup.opacity(1);
objectGroup.cache();
layerClone.destroyChildren();
layerClone.add(objectGroupClone);
const blob = await new Promise<Blob>((resolve) => {
layer.toBlob({
callback: (blob) => {
assert(blob, 'Blob is null');
resolve(blob);
},
...bbox,
});
});
objectGroupClone.opacity(1);
objectGroupClone.cache();
if (preview) {
const base64 = await blobToDataURL(blob);
const caption = `${region.id}: ${region.positivePrompt} / ${region.negativePrompt}`;
openBase64ImageInTab([{ base64, caption }]);
}
layer.destroy();
const imageDTO = await this.util.uploadImage(blob, `${region.id}_mask.png`, 'mask', true);
this.stateApi.onRegionMaskImageCached(region.id, imageDTO);
return imageDTO;
return layerClone;
}
async _getInpaintMaskImage(arg: { bbox?: Rect; preview?: boolean }): Promise<ImageDTO> {
const { bbox, preview = false } = arg;
const inpaintMask = this.stateApi.getInpaintMaskState();
const adapter = this.get(inpaintMask.id);
assert(adapter, `Adapter for ${inpaintMask.id} not found`);
if (inpaintMask.imageCache) {
const imageDTO = await this.util.getImageDTO(inpaintMask.imageCache.name);
if (imageDTO) {
return imageDTO;
}
}
const layer = adapter.konvaLayer.clone();
const objectGroup = adapter.konvaObjectGroup.clone();
layer.destroyChildren();
layer.add(objectGroup);
objectGroup.opacity(1);
objectGroup.cache();
const blob = await new Promise<Blob>((resolve) => {
layer.toBlob({
callback: (blob) => {
assert(blob, 'Blob is null');
resolve(blob);
},
...bbox,
});
});
if (preview) {
const base64 = await blobToDataURL(blob);
const caption = 'inpaint mask';
openBase64ImageInTab([{ base64, caption }]);
}
layer.destroy();
const imageDTO = await this.util.uploadImage(blob, 'inpaint_mask.png', 'mask', true);
this.stateApi.onInpaintMaskImageCached(imageDTO);
return imageDTO;
}
async _getImageSourceImage(arg: { bbox?: Rect; preview?: boolean }): Promise<ImageDTO> {
const { bbox, preview = false } = arg;
_getCompositeLayerStageClone(): Konva.Stage {
const layersState = this.stateApi.getLayersState();
const { entities, imageCache } = layersState;
if (imageCache) {
const imageDTO = await this.util.getImageDTO(imageCache.name);
if (imageDTO) {
return imageDTO;
}
}
const stage = this.stage.clone();
const stageClone = this.stage.clone();
stage.scaleX(1);
stage.scaleY(1);
stage.x(0);
stage.y(0);
stageClone.scaleX(1);
stageClone.scaleY(1);
stageClone.x(0);
stageClone.y(0);
const validLayers = entities.filter(isValidLayer);
const validLayers = layersState.entities.filter(isValidLayer);
// Konva bug (?) - when iterating over the array returned from `stage.getLayers()`, if you destroy a layer, the array
// is mutated in-place and the next iteration will skip the next layer. To avoid this, we first collect the layers
@ -367,7 +296,7 @@ export class KonvaNodeManager {
// TODO(psyche): Maybe report this?
const toDelete: Konva.Layer[] = [];
for (const konvaLayer of stage.getLayers()) {
for (const konvaLayer of stageClone.getLayers()) {
const layer = validLayers.find((l) => l.id === konvaLayer.id());
if (!layer) {
toDelete.push(konvaLayer);
@ -378,22 +307,100 @@ export class KonvaNodeManager {
konvaLayer.destroy();
}
const blob = await new Promise<Blob>((resolve) => {
stage.toBlob({
callback: (blob) => {
assert(blob, 'Blob is null');
resolve(blob);
},
...bbox,
});
});
return stageClone;
}
if (preview) {
const base64 = await blobToDataURL(blob);
openBase64ImageInTab([{ base64, caption: 'base layer' }]);
_getGenerationMode(): GenerationMode {
const { x, y, width, height } = this.stateApi.getBbox();
const inpaintMaskLayer = this.util.getMaskLayerClone({ id: 'inpaint_mask' });
const inpaintMaskImageData = konvaNodeToImageData(inpaintMaskLayer, { x, y, width, height });
const inpaintMaskTransparency = getImageDataTransparency(inpaintMaskImageData);
const compositeLayer = this.util.getCompositeLayerStageClone();
const compositeLayerImageData = konvaNodeToImageData(compositeLayer, { x, y, width, height });
const compositeLayerTransparency = getImageDataTransparency(compositeLayerImageData);
if (compositeLayerTransparency.isPartiallyTransparent) {
if (compositeLayerTransparency.isFullyTransparent) {
return 'txt2img';
}
return 'outpaint';
} else {
if (!inpaintMaskTransparency.isFullyTransparent) {
return 'inpaint';
}
return 'img2img';
}
}
async _getRegionMaskImage(arg: { id: string; bbox?: Rect; preview?: boolean }): Promise<ImageDTO> {
const { id, bbox, preview = false } = arg;
const region = this.stateApi.getRegionsState().entities.find((entity) => entity.id === id);
assert(region, `Region entity state with id ${id} not found`);
if (region.imageCache) {
const imageDTO = await this.util.getImageDTO(region.imageCache.name);
if (imageDTO) {
return imageDTO;
}
}
stage.destroy();
const layerClone = this.util.getMaskLayerClone({ id });
const blob = await konvaNodeToBlob(layerClone, bbox);
if (preview) {
previewBlob(blob, `region ${region.id} mask`);
}
layerClone.destroy();
const imageDTO = await this.util.uploadImage(blob, `${region.id}_mask.png`, 'mask', true);
this.stateApi.onRegionMaskImageCached(region.id, imageDTO);
return imageDTO;
}
async _getInpaintMaskImage(arg: { bbox?: Rect; preview?: boolean }): Promise<ImageDTO> {
const { bbox, preview = false } = arg;
const inpaintMask = this.stateApi.getInpaintMaskState();
if (inpaintMask.imageCache) {
const imageDTO = await this.util.getImageDTO(inpaintMask.imageCache.name);
if (imageDTO) {
return imageDTO;
}
}
const layerClone = this.util.getMaskLayerClone({ id: inpaintMask.id });
const blob = await konvaNodeToBlob(layerClone, bbox);
if (preview) {
previewBlob(blob, 'inpaint mask');
}
layerClone.destroy();
const imageDTO = await this.util.uploadImage(blob, 'inpaint_mask.png', 'mask', true);
this.stateApi.onInpaintMaskImageCached(imageDTO);
return imageDTO;
}
async _getImageSourceImage(arg: { bbox?: Rect; preview?: boolean }): Promise<ImageDTO> {
const { bbox, preview = false } = arg;
const { imageCache } = this.stateApi.getLayersState();
if (imageCache) {
const imageDTO = await this.util.getImageDTO(imageCache.name);
if (imageDTO) {
return imageDTO;
}
}
const stageClone = this.util.getCompositeLayerStageClone();
const blob = await konvaNodeToBlob(stageClone, bbox);
if (preview) {
previewBlob(blob, 'image source');
}
stageClone.destroy();
const imageDTO = await this.util.uploadImage(blob, 'base_layer.png', 'general', true);
this.stateApi.onLayerImageCached(imageDTO);

View File

@ -11,10 +11,11 @@ import {
RG_LAYER_NAME,
RG_LAYER_RECT_SHAPE_NAME,
} from 'features/controlLayers/konva/naming';
import type { RgbaColor } from 'features/controlLayers/store/types';
import type { Rect, RgbaColor } from 'features/controlLayers/store/types';
import Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node';
import type { IRect, Vector2d } from 'konva/lib/types';
import type { Vector2d } from 'konva/lib/types';
import { assert } from 'tsafe';
/**
* Gets the scaled and floored cursor position on the stage. If the cursor is not currently over the stage, returns null.
@ -203,24 +204,33 @@ export const dataURLToImageData = async (dataURL: string, width: number, height:
/**
* Converts a Konva node to a Blob
* @param node - The Konva node to convert to a Blob
* @param boundingBox - The bounding box to crop to
* @param bbox - The bounding box to crop to
* @returns A Promise that resolves with Blob of the node cropped to the bounding box
*/
export const konvaNodeToBlob = async (node: Konva.Node, boundingBox: IRect): Promise<Blob> => {
return await canvasToBlob(node.toCanvas(boundingBox));
export const konvaNodeToBlob = async (node: Konva.Node, bbox?: Rect): Promise<Blob> => {
return await new Promise<Blob>((resolve) => {
node.toBlob({
callback: (blob) => {
assert(blob, 'Blob is null');
resolve(blob);
},
...(bbox ?? {}),
});
});
};
/**
* Converts a Konva node to an ImageData object
* @param node - The Konva node to convert to an ImageData object
* @param boundingBox - The bounding box to crop to
* @param bbox - The bounding box to crop to
* @returns A Promise that resolves with ImageData object of the node cropped to the bounding box
*/
export const konvaNodeToImageData = async (node: Konva.Node, boundingBox: IRect): Promise<ImageData> => {
export const konvaNodeToImageData = (node: Konva.Node, bbox?: Rect): ImageData => {
// get a dataURL of the bbox'd region
const dataURL = node.toDataURL(boundingBox);
return await dataURLToImageData(dataURL, boundingBox.width, boundingBox.height);
const canvas = node.toCanvas({ ...(bbox ?? {}) });
const ctx = canvas.getContext('2d');
assert(ctx, 'ctx is null');
return ctx.getImageData(0, 0, canvas.width, canvas.height);
};
/**
@ -246,3 +256,16 @@ export const getPixelUnderCursor = (stage: Konva.Stage): RgbaColor | null => {
return { r, g, b, a };
};
export const previewBlob = async (blob: Blob, label?: string) => {
const url = URL.createObjectURL(blob);
const w = window.open('');
if (!w) {
return;
}
if (label) {
w.document.write(label);
w.document.write('</br>');
}
w.document.write(`<img src="${url}" style="border: 1px solid red;" />`);
};

View File

@ -906,3 +906,5 @@ export const isLine = (obj: RenderableObject): obj is BrushLine | EraserLine =>
export type RemoveIndexString<T> = {
[K in keyof T as string extends K ? never : K]: T[K];
};
export type GenerationMode = 'txt2img' | 'img2img' | 'inpaint' | 'outpaint';

View File

@ -1,4 +1,5 @@
import type { RootState } from 'app/store/store';
import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
CLIP_SKIP,
@ -29,7 +30,7 @@ import { assert } from 'tsafe';
import { addRegions } from './addRegions';
export const buildGenerationTabGraph = async (state: RootState): Promise<GraphType> => {
export const buildGenerationTabGraph = async (state: RootState, manager: KonvaNodeManager): Promise<GraphType> => {
const {
model,
cfgScale: cfg_scale,
@ -159,6 +160,7 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<GraphTy
const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base);
const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base);
const _addedRegions = await addRegions(
manager,
state.canvasV2.regions.entities,
g,
state.canvasV2.document,

View File

@ -1,4 +1,5 @@
import type { RootState } from 'app/store/store';
import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
LATENTS_TO_IMAGE,
@ -27,7 +28,10 @@ import { assert } from 'tsafe';
import { addRegions } from './addRegions';
export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<NonNullableGraph> => {
export const buildGenerationTabSDXLGraph = async (
state: RootState,
manager: KonvaNodeManager
): Promise<NonNullableGraph> => {
const {
model,
cfgScale: cfg_scale,
@ -42,6 +46,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
negativePrompt,
refinerModel,
refinerStart,
img2imgStrength,
} = state.canvasV2.params;
const { width, height } = state.canvasV2.bbox;
@ -76,6 +81,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
id: NEGATIVE_CONDITIONING_COLLECT,
});
const noise = g.addNode({ type: 'noise', id: NOISE, seed, width, height, use_cpu: shouldUseCpuNoise });
const i2l = g.addNode({ type: 'i2l', id: 'i2l' });
const denoise = g.addNode({
type: 'denoise_latents',
id: SDXL_DENOISE_LATENTS,
@ -83,7 +89,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
cfg_rescale_multiplier,
scheduler,
steps,
denoising_start: 0,
denoising_start: refinerModel ? Math.min(refinerStart, 1 - img2imgStrength) : 1 - img2imgStrength,
denoising_end: refinerModel ? refinerStart : 1,
});
const l2i = g.addNode({
@ -116,6 +122,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');
g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning');
g.addEdge(noise, 'noise', denoise, 'noise');
g.addEdge(i2l, 'latents', denoise, 'latents');
g.addEdge(denoise, 'latents', l2i, 'latents');
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
@ -146,6 +153,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
// We might get the VAE from the main model, custom VAE, or seamless node.
const vaeSource = seamless ?? vaeLoader ?? modelLoader;
g.addEdge(vaeSource, 'vae', l2i, 'vae');
g.addEdge(vaeSource, 'vae', i2l, 'vae');
// Add Refiner if enabled
if (refinerModel) {
@ -155,6 +163,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base);
const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base);
const _addedRegions = await addRegions(
manager,
state.canvasV2.regions.entities,
g,
state.canvasV2.document,
@ -166,6 +175,9 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
posCondCollect,
negCondCollect
);
const { image_name } = await manager.util.getImageSourceImage({ bbox: state.canvasV2.bbox, preview: true });
await manager.util.getInpaintMaskImage({ bbox: state.canvasV2.bbox, preview: true });
i2l.image = { image_name };
if (state.system.shouldUseNSFWChecker) {
imageOutput = addNSFWChecker(g, imageOutput);