feat(ui): txt2img, img2img, inpaint & outpaint working

This commit is contained in:
psychedelicious 2024-08-07 15:32:53 +10:00
parent a42d0ce1d2
commit 3ae7250ef7
11 changed files with 206 additions and 69 deletions

View File

@ -21,7 +21,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
assert(manager, 'No model found in state'); assert(manager, 'No model found in state');
let didStartStaging = false; let didStartStaging = false;
if (!state.canvasV2.session.isStaging && state.canvasV2.session.isActive) { if (!state.canvasV2.session.isStaging) {
dispatch(sessionStartedStaging()); dispatch(sessionStartedStaging());
didStartStaging = true; didStartStaging = true;
} }
@ -49,7 +49,8 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
); );
req.reset(); req.reset();
await req.unwrap(); await req.unwrap();
} catch { } catch (error) {
console.log('Error in enqueueRequestedLinear', error);
if (didStartStaging && getState().canvasV2.session.isStaging) { if (didStartStaging && getState().canvasV2.session.isStaging) {
dispatch(sessionStagingAreaReset()); dispatch(sessionStagingAreaReset());
} }

View File

@ -1,16 +0,0 @@
export const getImageDataTransparency = (imageData: ImageData) => {
let isFullyTransparent = true;
let isPartiallyTransparent = false;
const len = imageData.data.length;
for (let i = 3; i < len; i += 4) {
if (imageData.data[i] !== 0) {
isFullyTransparent = false;
} else {
isPartiallyTransparent = true;
}
if (!isFullyTransparent && isPartiallyTransparent) {
return { isFullyTransparent, isPartiallyTransparent };
}
}
return { isFullyTransparent, isPartiallyTransparent };
};

View File

@ -28,7 +28,6 @@ export class CanvasLayerAdapter {
renderer: CanvasObjectRenderer; renderer: CanvasObjectRenderer;
isFirstRender: boolean = true; isFirstRender: boolean = true;
bboxNeedsUpdate: boolean = true;
constructor(state: CanvasLayerAdapter['state'], manager: CanvasLayerAdapter['manager']) { constructor(state: CanvasLayerAdapter['state'], manager: CanvasLayerAdapter['manager']) {
this.id = state.id; this.id = state.id;
@ -40,6 +39,8 @@ export class CanvasLayerAdapter {
this.konva = { this.konva = {
layer: new Konva.Layer({ layer: new Konva.Layer({
// We need the ID on the layer to help with building the composite initial image
// See `getCompositeLayerStageClone()`
id: this.id, id: this.id,
name: `${this.type}:layer`, name: `${this.type}:layer`,
listening: false, listening: false,
@ -134,7 +135,6 @@ export class CanvasLayerAdapter {
id: this.id, id: this.id,
type: this.type, type: this.type,
state: deepClone(this.state), state: deepClone(this.state),
bboxNeedsUpdate: this.bboxNeedsUpdate,
transformer: this.transformer.repr(), transformer: this.transformer.repr(),
renderer: this.renderer.repr(), renderer: this.renderer.repr(),
}; };

View File

@ -13,32 +13,41 @@ import type { CanvasTransformer } from 'features/controlLayers/konva/CanvasTrans
import { import {
getCompositeLayerImage, getCompositeLayerImage,
getControlAdapterImage, getControlAdapterImage,
getGenerationMode, getImageDataTransparency,
getInpaintMaskImage, getInpaintMaskImage,
getPrefixedId, getPrefixedId,
getRegionMaskImage, getRegionMaskImage,
konvaNodeToBlob,
konvaNodeToImageData,
nanoid, nanoid,
} from 'features/controlLayers/konva/util'; } from 'features/controlLayers/konva/util';
import type { Extents, ExtentsResult, GetBboxTask, WorkerLogMessage } from 'features/controlLayers/konva/worker'; import type { Extents, ExtentsResult, GetBboxTask, WorkerLogMessage } from 'features/controlLayers/konva/worker';
import { $lastProgressEvent, $shouldShowStagedImage } from 'features/controlLayers/store/canvasV2Slice'; import { $lastProgressEvent, $shouldShowStagedImage } from 'features/controlLayers/store/canvasV2Slice';
import { import type {
type CanvasControlAdapterState, CanvasControlAdapterState,
type CanvasEntityIdentifier, CanvasEntityIdentifier,
type CanvasInpaintMaskState, CanvasInpaintMaskState,
type CanvasLayerState, CanvasLayerState,
type CanvasRegionalGuidanceState, CanvasRegionalGuidanceState,
type CanvasV2State, CanvasV2State,
type Coordinate, Coordinate,
type GenerationMode, GenerationMode,
type GetLoggingContext, GetLoggingContext,
RGBA_WHITE, Rect,
type RgbaColor, RgbaColor,
} from 'features/controlLayers/store/types'; } from 'features/controlLayers/store/types';
import { RGBA_RED } from 'features/controlLayers/store/types';
import { isValidLayer } from 'features/nodes/util/graph/generation/addLayers';
import type Konva from 'konva'; import type Konva from 'konva';
import { atom } from 'nanostores'; import { atom } from 'nanostores';
import type { Logger } from 'roarr'; import type { Logger } from 'roarr';
import { getImageDTO as defaultGetImageDTO, uploadImage as defaultUploadImage } from 'services/api/endpoints/images'; import {
getImageDTO as defaultGetImageDTO,
getImageDTO,
uploadImage as defaultUploadImage,
} from 'services/api/endpoints/images';
import type { ImageCategory, ImageDTO } from 'services/api/types'; import type { ImageCategory, ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
import { CanvasBackground } from './CanvasBackground'; import { CanvasBackground } from './CanvasBackground';
import { CanvasBbox } from './CanvasBbox'; import { CanvasBbox } from './CanvasBbox';
@ -350,7 +359,8 @@ export class CanvasManager {
if (selectedEntity) { if (selectedEntity) {
// These two entity types use a compositing rect for opacity. Their fill is always white. // These two entity types use a compositing rect for opacity. Their fill is always white.
if (selectedEntity.state.type === 'regional_guidance' || selectedEntity.state.type === 'inpaint_mask') { if (selectedEntity.state.type === 'regional_guidance' || selectedEntity.state.type === 'inpaint_mask') {
currentFill = RGBA_WHITE; currentFill = RGBA_RED;
// currentFill = RGBA_WHITE;
} }
} }
return currentFill; return currentFill;
@ -620,8 +630,96 @@ export class CanvasManager {
return pixels / this.getStageScale(); return pixels / this.getStageScale();
} }
getCompositeLayerStageClone = (): Konva.Stage => {
const layersState = this.stateApi.getLayersState();
const stageClone = this.stage.clone();
stageClone.scaleX(1);
stageClone.scaleY(1);
stageClone.x(0);
stageClone.y(0);
const validLayers = layersState.entities.filter(isValidLayer);
// getLayers() returns the internal `children` array of the stage directly - calling destroy on a layer will
// mutate that array. We need to clone the array to avoid mutating the original.
for (const konvaLayer of stageClone.getLayers().slice()) {
if (!validLayers.find((l) => l.id === konvaLayer.id())) {
konvaLayer.destroy();
}
}
return stageClone;
};
getCompositeLayerBlob = (rect?: Rect): Promise<Blob> => {
return konvaNodeToBlob(this.getCompositeLayerStageClone(), rect);
};
getCompositeLayerImageData = (rect?: Rect): ImageData => {
return konvaNodeToImageData(this.getCompositeLayerStageClone(), rect);
};
getCompositeLayerImageDTO = async (rect?: Rect): Promise<ImageDTO> => {
const blob = await this.getCompositeLayerBlob(rect);
const imageDTO = await this.util.uploadImage(blob, 'composite-layer.png', 'general', true);
this.stateApi.setLayerImageCache(imageDTO);
return imageDTO;
};
getInpaintMaskBlob = (rect?: Rect): Promise<Blob> => {
return this.inpaintMask.renderer.getBlob({ rect });
};
getInpaintMaskImageData = (rect?: Rect): ImageData => {
return this.inpaintMask.renderer.getImageData({ rect });
};
getInpaintMaskImageDTO = async (rect?: Rect): Promise<ImageDTO> => {
const blob = await this.inpaintMask.renderer.getBlob({ rect });
const imageDTO = await this.util.uploadImage(blob, 'inpaint-mask.png', 'mask', true);
this.stateApi.setInpaintMaskImageCache(imageDTO);
return imageDTO;
};
getRegionMaskImageDTO = async (id: string, rect?: Rect): Promise<ImageDTO> => {
const region = this.getEntity({ id, type: 'regional_guidance' });
assert(region?.type === 'regional_guidance');
if (region.state.imageCache) {
const imageDTO = await getImageDTO(region.state.imageCache.name);
if (imageDTO) {
return imageDTO;
}
}
return region.adapter.renderer.getImageDTO({
rect,
category: 'other',
is_intermediate: true,
onUploaded: (imageDTO) => {
this.stateApi.setRegionMaskImageCache(region.state.id, imageDTO);
},
});
};
getGenerationMode(): GenerationMode { getGenerationMode(): GenerationMode {
return getGenerationMode({ manager: this }); const { rect } = this.stateApi.getBbox();
const inpaintMaskImageData = this.getInpaintMaskImageData(rect);
const inpaintMaskTransparency = getImageDataTransparency(inpaintMaskImageData);
const compositeLayerImageData = this.getCompositeLayerImageData(rect);
const compositeLayerTransparency = getImageDataTransparency(compositeLayerImageData);
if (compositeLayerTransparency === 'FULLY_TRANSPARENT') {
// When the initial image is fully transparent, we are always doing txt2img
return 'txt2img';
} else if (compositeLayerTransparency === 'PARTIALLY_TRANSPARENT') {
// When the initial image is partially transparent, we are always outpainting
return 'outpaint';
} else if (inpaintMaskTransparency === 'FULLY_TRANSPARENT') {
// compositeLayerTransparency === 'OPAQUE'
// When the inpaint mask is fully transparent, we are doing img2img
return 'img2img';
} else {
// Else at least some of the inpaint mask is opaque, so we are inpainting
return 'inpaint';
}
} }
getControlAdapterImage(arg: Omit<Parameters<typeof getControlAdapterImage>[0], 'manager'>) { getControlAdapterImage(arg: Omit<Parameters<typeof getControlAdapterImage>[0], 'manager'>) {

View File

@ -1,3 +1,4 @@
import { deepClone } from 'common/util/deepClone';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasObjectRenderer } from 'features/controlLayers/konva/CanvasObjectRenderer'; import { CanvasObjectRenderer } from 'features/controlLayers/konva/CanvasObjectRenderer';
import { CanvasTransformer } from 'features/controlLayers/konva/CanvasTransformer'; import { CanvasTransformer } from 'features/controlLayers/konva/CanvasTransformer';
@ -41,6 +42,8 @@ export class CanvasMaskAdapter {
this.konva = { this.konva = {
layer: new Konva.Layer({ layer: new Konva.Layer({
// We need the ID on the layer to help with building the composite initial image
// See `getCompositeLayerStageClone()`
id: this.id, id: this.id,
name: `${this.type}:layer`, name: `${this.type}:layer`,
listening: false, listening: false,
@ -135,4 +138,12 @@ export class CanvasMaskAdapter {
const isEnabled = get(arg, 'isEnabled', this.state.isEnabled); const isEnabled = get(arg, 'isEnabled', this.state.isEnabled);
this.konva.layer.visible(isEnabled); this.konva.layer.visible(isEnabled);
}; };
repr = () => {
return {
id: this.id,
type: this.type,
state: deepClone(this.state),
};
};
} }

View File

@ -8,18 +8,20 @@ import type { CanvasLayerAdapter } from 'features/controlLayers/konva/CanvasLaye
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { CanvasMaskAdapter } from 'features/controlLayers/konva/CanvasMaskAdapter'; import type { CanvasMaskAdapter } from 'features/controlLayers/konva/CanvasMaskAdapter';
import { CanvasRectRenderer } from 'features/controlLayers/konva/CanvasRect'; import { CanvasRectRenderer } from 'features/controlLayers/konva/CanvasRect';
import { getPrefixedId, konvaNodeToBlob, previewBlob } from 'features/controlLayers/konva/util'; import { getPrefixedId, konvaNodeToBlob, konvaNodeToImageData, previewBlob } from 'features/controlLayers/konva/util';
import { import {
type CanvasBrushLineState, type CanvasBrushLineState,
type CanvasEraserLineState, type CanvasEraserLineState,
type CanvasImageState, type CanvasImageState,
type CanvasRectState, type CanvasRectState,
imageDTOToImageObject, imageDTOToImageObject,
type Rect,
type RgbColor, type RgbColor,
} from 'features/controlLayers/store/types'; } from 'features/controlLayers/store/types';
import Konva from 'konva'; import Konva from 'konva';
import type { Logger } from 'roarr'; import type { Logger } from 'roarr';
import { uploadImage } from 'services/api/endpoints/images'; import { uploadImage } from 'services/api/endpoints/images';
import type { ImageCategory, ImageDTO } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
/** /**
@ -348,10 +350,8 @@ export class CanvasObjectRenderer {
rasterize = async () => { rasterize = async () => {
this.log.debug('Rasterizing entity'); this.log.debug('Rasterizing entity');
const objectGroupClone = this.konva.objectGroup.clone(); const rect = this.parent.transformer.getRelativeRect();
const interactionRectClone = this.parent.transformer.konva.proxyRect.clone(); const blob = await this.getBlob({ rect });
const rect = interactionRectClone.getClientRect();
const blob = await konvaNodeToBlob(objectGroupClone, rect);
if (this.manager._isDebugging) { if (this.manager._isDebugging) {
previewBlob(blob, 'Rasterized entity'); previewBlob(blob, 'Rasterized entity');
} }
@ -365,6 +365,33 @@ export class CanvasObjectRenderer {
}); });
}; };
getBlob = ({ rect }: { rect?: Rect }): Promise<Blob> => {
return konvaNodeToBlob(this.konva.objectGroup.clone(), rect);
};
getImageData = ({ rect }: { rect?: Rect }): ImageData => {
return konvaNodeToImageData(this.konva.objectGroup.clone(), rect);
};
getImageDTO = async ({
rect,
category,
is_intermediate,
onUploaded,
}: {
rect?: Rect;
category: ImageCategory;
is_intermediate: boolean;
onUploaded?: (imageDTO: ImageDTO) => void;
}): Promise<ImageDTO> => {
const blob = await this.getBlob({ rect });
const imageDTO = await uploadImage(blob, `${this.id}.png`, category, is_intermediate);
if (onUploaded) {
onUploaded(imageDTO);
}
return imageDTO;
};
/** /**
* Destroys this renderer and all of its object renderers. * Destroys this renderer and all of its object renderers.
*/ */

View File

@ -685,6 +685,10 @@ export class CanvasTransformer {
this.calculateRect(); this.calculateRect();
}; };
getRelativeRect = (): Rect => {
return this.konva.proxyRect.getClientRect({ relativeTo: this.parent.konva.layer });
};
_enableTransform = () => { _enableTransform = () => {
this.isTransformEnabled = true; this.isTransformEnabled = true;
this.konva.transformer.visible(true); this.konva.transformer.visible(true);

View File

@ -1,4 +1,3 @@
import { getImageDataTransparency } from 'common/util/arrayBuffer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { import type {
CanvasObjectState, CanvasObjectState,
@ -329,7 +328,7 @@ export const previewBlob = async (blob: Blob, label?: string) => {
export function getInpaintMaskLayerClone(arg: { manager: CanvasManager }): Konva.Layer { export function getInpaintMaskLayerClone(arg: { manager: CanvasManager }): Konva.Layer {
const { manager } = arg; const { manager } = arg;
const layerClone = manager.inpaintMask.konva.layer.clone(); const layerClone = manager.inpaintMask.konva.layer.clone();
const objectGroupClone = manager.inpaintMask.konva.group.clone(); const objectGroupClone = manager.inpaintMask.renderer.konva.objectGroup.clone();
layerClone.destroyChildren(); layerClone.destroyChildren();
layerClone.add(objectGroupClone); layerClone.add(objectGroupClone);
@ -347,7 +346,7 @@ export function getRegionMaskLayerClone(arg: { manager: CanvasManager; id: strin
assert(canvasRegion, `Canvas region with id ${id} not found`); assert(canvasRegion, `Canvas region with id ${id} not found`);
const layerClone = canvasRegion.konva.layer.clone(); const layerClone = canvasRegion.konva.layer.clone();
const objectGroupClone = canvasRegion.konva.group.clone(); const objectGroupClone = canvasRegion.renderer.konva.objectGroup.clone();
layerClone.destroyChildren(); layerClone.destroyChildren();
layerClone.add(objectGroupClone); layerClone.add(objectGroupClone);
@ -407,27 +406,42 @@ export function getCompositeLayerStageClone(arg: { manager: CanvasManager }): Ko
const validLayers = layersState.entities.filter(isValidLayer); const validLayers = layersState.entities.filter(isValidLayer);
console.log(validLayers); console.log(validLayers);
// Konva bug (?) - when iterating over the array returned from `stage.getLayers()`, if you destroy a layer, the array // getLayers() returns the internal `children` array of the stage directly - calling destroy on a layer will
// is mutated in-place and the next iteration will skip the next layer. To avoid this, we first collect the layers // mutate that array. We need to clone the array to avoid mutating the original.
// to delete in a separate array and then destroy them. for (const konvaLayer of stageClone.getLayers().slice()) {
// TODO(psyche): Maybe report this? if (!validLayers.find((l) => l.id === konvaLayer.id())) {
const toDelete: Konva.Layer[] = []; console.log('destroying', konvaLayer.id());
for (const konvaLayer of stageClone.getLayers()) {
const layer = validLayers.find((l) => l.id === konvaLayer.id());
if (!layer) {
console.log('deleting', konvaLayer);
toDelete.push(konvaLayer);
}
}
for (const konvaLayer of toDelete) {
konvaLayer.destroy(); konvaLayer.destroy();
} }
}
return stageClone; return stageClone;
} }
export type Transparency = 'FULLY_TRANSPARENT' | 'PARTIALLY_TRANSPARENT' | 'OPAQUE';
export function getImageDataTransparency(imageData: ImageData): Transparency {
let isFullyTransparent = true;
let isPartiallyTransparent = false;
const len = imageData.data.length;
for (let i = 3; i < len; i += 4) {
if (imageData.data[i] !== 0) {
isFullyTransparent = false;
} else {
isPartiallyTransparent = true;
}
if (!isFullyTransparent && isPartiallyTransparent) {
return 'PARTIALLY_TRANSPARENT';
}
}
if (isFullyTransparent) {
return 'FULLY_TRANSPARENT';
}
if (isPartiallyTransparent) {
return 'PARTIALLY_TRANSPARENT';
}
return 'OPAQUE';
}
export function getGenerationMode(arg: { manager: CanvasManager }): GenerationMode { export function getGenerationMode(arg: { manager: CanvasManager }): GenerationMode {
const { manager } = arg; const { manager } = arg;
const { x, y, width, height } = manager.stateApi.getBbox().rect; const { x, y, width, height } = manager.stateApi.getBbox().rect;

View File

@ -2,7 +2,7 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { CanvasV2State, Dimensions } from 'features/controlLayers/store/types'; import type { CanvasV2State, Dimensions } from 'features/controlLayers/store/types';
import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ParameterPrecision } from 'features/parameters/types/parameterSchemas'; import type { ParameterPrecision } from 'features/parameters/types/parameterSchemas';
import { isEqual, pick } from 'lodash-es'; import { isEqual } from 'lodash-es';
import type { Invocation } from 'services/api/types'; import type { Invocation } from 'services/api/types';
export const addInpaint = async ( export const addInpaint = async (
@ -21,9 +21,8 @@ export const addInpaint = async (
): Promise<Invocation<'canvas_v2_mask_and_crop'>> => { ): Promise<Invocation<'canvas_v2_mask_and_crop'>> => {
denoise.denoising_start = denoising_start; denoise.denoising_start = denoising_start;
const cropBbox = pick(bbox.rect, ['x', 'y', 'width', 'height']); const initialImage = await manager.getCompositeLayerImageDTO(bbox.rect);
const initialImage = await manager.getInitialImage({ bbox: cropBbox }); const maskImage = await manager.getInpaintMaskImageDTO(bbox.rect);
const maskImage = await manager.getInpaintMaskImage({ bbox: cropBbox });
if (!isEqual(scaledSize, originalSize)) { if (!isEqual(scaledSize, originalSize)) {
// Scale before processing requires some resizing // Scale before processing requires some resizing

View File

@ -3,7 +3,7 @@ import type { CanvasV2State, Dimensions } from 'features/controlLayers/store/typ
import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { getInfill } from 'features/nodes/util/graph/graphBuilderUtils'; import { getInfill } from 'features/nodes/util/graph/graphBuilderUtils';
import type { ParameterPrecision } from 'features/parameters/types/parameterSchemas'; import type { ParameterPrecision } from 'features/parameters/types/parameterSchemas';
import { isEqual, pick } from 'lodash-es'; import { isEqual } from 'lodash-es';
import type { Invocation } from 'services/api/types'; import type { Invocation } from 'services/api/types';
export const addOutpaint = async ( export const addOutpaint = async (
@ -22,9 +22,8 @@ export const addOutpaint = async (
): Promise<Invocation<'canvas_v2_mask_and_crop'>> => { ): Promise<Invocation<'canvas_v2_mask_and_crop'>> => {
denoise.denoising_start = denoising_start; denoise.denoising_start = denoising_start;
const cropBbox = pick(bbox.rect, ['x', 'y', 'width', 'height']); const initialImage = await manager.getCompositeLayerImageDTO(bbox.rect);
const initialImage = await manager.getInitialImage({ bbox: cropBbox }); const maskImage = await manager.getInpaintMaskImageDTO(bbox.rect);
const maskImage = await manager.getInpaintMaskImage({ bbox: cropBbox });
const infill = getInfill(g, compositing); const infill = getInfill(g, compositing);
if (!isEqual(scaledSize, originalSize)) { if (!isEqual(scaledSize, originalSize)) {

View File

@ -1,6 +1,6 @@
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { CanvasIPAdapterState, Rect, CanvasRegionalGuidanceState } from 'features/controlLayers/store/types'; import type { CanvasIPAdapterState, CanvasRegionalGuidanceState, Rect } from 'features/controlLayers/store/types';
import { import {
PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX, PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX,
PROMPT_REGION_MASK_TO_TENSOR_PREFIX, PROMPT_REGION_MASK_TO_TENSOR_PREFIX,
@ -44,7 +44,7 @@ export const addRegions = async (
for (const region of validRegions) { for (const region of validRegions) {
// Upload the mask image, or get the cached image if it exists // Upload the mask image, or get the cached image if it exists
const { image_name } = await manager.getRegionMaskImage({ id: region.id, bbox }); const { image_name } = await manager.getRegionMaskImageDTO(region.id, bbox);
// The main mask-to-tensor node // The main mask-to-tensor node
const maskToTensor = g.addNode({ const maskToTensor = g.addNode({