feat(ui): add utils for getting images from canvas

This commit is contained in:
psychedelicious 2024-06-21 16:00:42 +10:00
parent 275fc2ccf9
commit 7dd11bd60a
15 changed files with 277 additions and 241 deletions

View File

@ -1,10 +1,12 @@
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { $nodeManager } from 'features/controlLayers/konva/renderers/renderer';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildGenerationTabGraph } from 'features/nodes/util/graph/generation/buildGenerationTabGraph';
import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/generation/buildGenerationTabSDXLGraph';
import { queueApi } from 'services/api/endpoints/queue';
import { assert } from 'tsafe';
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
startAppListening({
@ -18,10 +20,13 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
let graph;
const manager = $nodeManager.get();
assert(manager, 'Konva node manager not initialized');
if (model?.base === 'sdxl') {
graph = await buildGenerationTabSDXLGraph(state);
graph = await buildGenerationTabSDXLGraph(state, manager);
} else {
graph = await buildGenerationTabGraph(state);
graph = await buildGenerationTabGraph(state, manager);
}
const batchConfig = prepareLinearUIBatch(state, graph, prepend);

View File

@ -1,3 +1,5 @@
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { blobToDataURL } from 'features/controlLayers/konva/util';
import type {
BrushLine,
BrushLineAddedArg,
@ -15,8 +17,11 @@ import type {
StageAttrs,
Tool,
} from 'features/controlLayers/store/types';
import { isValidLayer } from 'features/nodes/util/graph/generation/addLayers';
import type Konva from 'konva';
import type { Vector2d } from 'konva/lib/types';
import { getImageDTO as defaultGetImageDTO, uploadImage as defaultUploadImage } from 'services/api/endpoints/images';
import type { ImageCategory, ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
export type BrushLineObjectRecord = {
@ -132,24 +137,53 @@ type StateApi = {
getMetaKey: () => boolean;
getAltKey: () => boolean;
getDocument: () => CanvasV2State['document'];
getLayerEntityStates: () => CanvasV2State['layers']['entities'];
getControlAdapterEntityStates: () => CanvasV2State['controlAdapters']['entities'];
getRegionEntityStates: () => CanvasV2State['regions']['entities'];
getInpaintMaskEntityState: () => CanvasV2State['inpaintMask'];
getLayersState: () => CanvasV2State['layers'];
getControlAdaptersState: () => CanvasV2State['controlAdapters'];
getRegionsState: () => CanvasV2State['regions'];
getInpaintMaskState: () => CanvasV2State['inpaintMask'];
onInpaintMaskImageCached: (imageDTO: ImageDTO) => void;
onRegionMaskImageCached: (id: string, imageDTO: ImageDTO) => void;
onLayerImageCached: (imageDTO: ImageDTO) => void;
};
type Util = {
getImageDTO: (imageName: string) => Promise<ImageDTO | null>;
uploadImage: (
blob: Blob,
fileName: string,
image_category: ImageCategory,
is_intermediate: boolean
) => Promise<ImageDTO>;
getRegionMaskImage: (arg: { id: string; bbox?: Rect; preview?: boolean }) => Promise<ImageDTO>;
getInpaintMaskImage: (arg: { bbox?: Rect; preview?: boolean }) => Promise<ImageDTO>;
getImageSourceImage: (arg: { bbox?: Rect; preview?: boolean }) => Promise<ImageDTO>;
};
export class KonvaNodeManager {
stage: Konva.Stage;
container: HTMLDivElement;
adapters: Map<string, KonvaEntityAdapter>;
util: Util;
_background: BackgroundLayer | null;
_preview: PreviewLayer | null;
_konvaApi: KonvaApi | null;
_stateApi: StateApi | null;
constructor(stage: Konva.Stage, container: HTMLDivElement) {
constructor(
stage: Konva.Stage,
container: HTMLDivElement,
getImageDTO: Util['getImageDTO'] = defaultGetImageDTO,
uploadImage: Util['uploadImage'] = defaultUploadImage
) {
this.stage = stage;
this.container = container;
this.util = {
getImageDTO,
uploadImage,
getRegionMaskImage: this._getRegionMaskImage.bind(this),
getInpaintMaskImage: this._getInpaintMaskImage.bind(this),
getImageSourceImage: this._getImageSourceImage.bind(this),
};
this._konvaApi = null;
this._preview = null;
this._background = null;
@ -219,6 +253,152 @@ export class KonvaNodeManager {
assert(this._stateApi !== null, 'State API has not been set');
return this._stateApi;
}
async _getRegionMaskImage(arg: { id: string; bbox?: Rect; preview?: boolean }): Promise<ImageDTO> {
const { id, bbox, preview = false } = arg;
const region = this.stateApi.getRegionsState().entities.find((entity) => entity.id === id);
assert(region, `Region entity state with id ${id} not found`);
const adapter = this.get(region.id);
assert(adapter, `Adapter for region ${region.id} not found`);
if (region.imageCache) {
const imageDTO = await this.util.getImageDTO(region.imageCache.name);
if (imageDTO) {
return imageDTO;
}
}
const layer = adapter.konvaLayer.clone();
const objectGroup = adapter.konvaObjectGroup.clone();
layer.destroyChildren();
layer.add(objectGroup);
objectGroup.opacity(1);
objectGroup.cache();
const blob = await new Promise<Blob>((resolve) => {
layer.toBlob({
callback: (blob) => {
assert(blob, 'Blob is null');
resolve(blob);
},
...bbox,
});
});
if (preview) {
const base64 = await blobToDataURL(blob);
const caption = `${region.id}: ${region.positivePrompt} / ${region.negativePrompt}`;
openBase64ImageInTab([{ base64, caption }]);
}
layer.destroy();
const imageDTO = await this.util.uploadImage(blob, `${region.id}_mask.png`, 'mask', true);
this.stateApi.onRegionMaskImageCached(region.id, imageDTO);
return imageDTO;
}
async _getInpaintMaskImage(arg: { bbox?: Rect; preview?: boolean }): Promise<ImageDTO> {
const { bbox, preview = false } = arg;
const inpaintMask = this.stateApi.getInpaintMaskState();
const adapter = this.get(inpaintMask.id);
assert(adapter, `Adapter for ${inpaintMask.id} not found`);
if (inpaintMask.imageCache) {
const imageDTO = await this.util.getImageDTO(inpaintMask.imageCache.name);
if (imageDTO) {
return imageDTO;
}
}
const layer = adapter.konvaLayer.clone();
const objectGroup = adapter.konvaObjectGroup.clone();
layer.destroyChildren();
layer.add(objectGroup);
objectGroup.opacity(1);
objectGroup.cache();
const blob = await new Promise<Blob>((resolve) => {
layer.toBlob({
callback: (blob) => {
assert(blob, 'Blob is null');
resolve(blob);
},
...bbox,
});
});
if (preview) {
const base64 = await blobToDataURL(blob);
const caption = 'inpaint mask';
openBase64ImageInTab([{ base64, caption }]);
}
layer.destroy();
const imageDTO = await this.util.uploadImage(blob, 'inpaint_mask.png', 'mask', true);
this.stateApi.onInpaintMaskImageCached(imageDTO);
return imageDTO;
}
async _getImageSourceImage(arg: { bbox?: Rect; preview?: boolean }): Promise<ImageDTO> {
const { bbox, preview = false } = arg;
const layersState = this.stateApi.getLayersState();
const { entities, imageCache } = layersState;
if (imageCache) {
const imageDTO = await this.util.getImageDTO(imageCache.name);
if (imageDTO) {
return imageDTO;
}
}
const stage = this.stage.clone();
stage.scaleX(1);
stage.scaleY(1);
stage.x(0);
stage.y(0);
const validLayers = entities.filter(isValidLayer);
// Konva bug (?) - when iterating over the array returned from `stage.getLayers()`, if you destroy a layer, the array
// is mutated in-place and the next iteration will skip the next layer. To avoid this, we first collect the layers
// to delete in a separate array and then destroy them.
// TODO(psyche): Maybe report this?
const toDelete: Konva.Layer[] = [];
for (const konvaLayer of stage.getLayers()) {
const layer = validLayers.find((l) => l.id === konvaLayer.id());
if (!layer) {
toDelete.push(konvaLayer);
}
}
for (const konvaLayer of toDelete) {
konvaLayer.destroy();
}
const blob = await new Promise<Blob>((resolve) => {
stage.toBlob({
callback: (blob) => {
assert(blob, 'Blob is null');
resolve(blob);
},
...bbox,
});
});
if (preview) {
const base64 = await blobToDataURL(blob);
openBase64ImageInTab([{ base64, caption: 'base layer' }]);
}
stage.destroy();
const imageDTO = await this.util.uploadImage(blob, 'base_layer.png', 'general', true);
this.stateApi.onLayerImageCached(imageDTO);
return imageDTO;
}
}
export class KonvaEntityAdapter {

View File

@ -6,12 +6,12 @@ import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager'
* @returns An arrange entities function
*/
export const getArrangeEntities = (manager: KonvaNodeManager) => {
const { getLayerEntityStates, getControlAdapterEntityStates, getRegionEntityStates } = manager.stateApi;
const { getLayersState, getControlAdaptersState, getRegionsState } = manager.stateApi;
function arrangeEntities(): void {
const layers = getLayerEntityStates();
const controlAdapters = getControlAdapterEntityStates();
const regions = getRegionEntityStates();
const layers = getLayersState().entities;
const controlAdapters = getControlAdaptersState().entities;
const regions = getRegionsState().entities;
let zIndex = 0;
manager.background.layer.zIndex(++zIndex);
for (const layer of layers) {

View File

@ -100,10 +100,10 @@ export const renderControlAdapter = async (manager: KonvaNodeManager, entity: Co
* @returns A function to render all control adapters
*/
export const getRenderControlAdapters = (manager: KonvaNodeManager) => {
const { getControlAdapterEntityStates } = manager.stateApi;
const { getControlAdaptersState } = manager.stateApi;
function renderControlAdapters(): void {
const entities = getControlAdapterEntityStates();
const { entities } = getControlAdaptersState();
// Destroy nonexistent layers
for (const adapters of manager.getAll('control_adapter')) {
if (!entities.find((ca) => ca.id === adapters.id)) {

View File

@ -71,10 +71,10 @@ const getInpaintMask = (
* @returns A function to render the inpaint mask
*/
export const getRenderInpaintMask = (manager: KonvaNodeManager) => {
const { getInpaintMaskEntityState, getMaskOpacity, getToolState, getSelectedEntity, onPosChanged } = manager.stateApi;
const { getInpaintMaskState, getMaskOpacity, getToolState, getSelectedEntity, onPosChanged } = manager.stateApi;
function renderInpaintMask(): void {
const entity = getInpaintMaskEntityState();
const entity = getInpaintMaskState();
const globalMaskLayerOpacity = getMaskOpacity();
const toolState = getToolState();
const selectedEntity = getSelectedEntity();

View File

@ -136,10 +136,10 @@ export const renderLayer = async (
* @returns A function to render all layers
*/
export const getRenderLayers = (manager: KonvaNodeManager) => {
const { getLayerEntityStates, getToolState, onPosChanged } = manager.stateApi;
const { getLayersState, getToolState, onPosChanged } = manager.stateApi;
function renderLayers(): void {
const entities = getLayerEntityStates();
const { entities } = getLayersState();
const tool = getToolState();
// Destroy nonexistent layers
for (const adapter of manager.getAll('layer')) {

View File

@ -233,10 +233,10 @@ export const renderRegion = (
* @returns A function to render all regions
*/
export const getRenderRegions = (manager: KonvaNodeManager) => {
const { getRegionEntityStates, getMaskOpacity, getToolState, getSelectedEntity, onPosChanged } = manager.stateApi;
const { getRegionsState, getMaskOpacity, getToolState, getSelectedEntity, onPosChanged } = manager.stateApi;
function renderRegions(): void {
const entities = getRegionEntityStates();
const { entities } = getRegionsState();
const maskOpacity = getMaskOpacity();
const toolState = getToolState();
const selectedEntity = getSelectedEntity();

View File

@ -32,17 +32,20 @@ import {
imBboxChanged,
imBrushLineAdded,
imEraserLineAdded,
imImageCacheChanged,
imLinePointAdded,
imTranslated,
layerBboxChanged,
layerBrushLineAdded,
layerEraserLineAdded,
layerImageCacheChanged,
layerLinePointAdded,
layerRectAdded,
layerTranslated,
rgBboxChanged,
rgBrushLineAdded,
rgEraserLineAdded,
rgImageCacheChanged,
rgLinePointAdded,
rgRectAdded,
rgTranslated,
@ -65,6 +68,7 @@ import type { IRect, Vector2d } from 'konva/lib/types';
import { debounce } from 'lodash-es';
import { atom } from 'nanostores';
import type { RgbaColor } from 'react-colorful';
import type { ImageDTO } from 'services/api/types';
export const $nodeManager = atom<KonvaNodeManager | null>(null);
@ -175,6 +179,19 @@ export const initializeRenderer = (
logIfDebugging('Eraser width changed');
dispatch(eraserWidthChanged(width));
};
const onRegionMaskImageCached = (id: string, imageDTO: ImageDTO) => {
logIfDebugging('Region mask image cached');
dispatch(rgImageCacheChanged({ id, imageDTO }));
};
const onInpaintMaskImageCached = (imageDTO: ImageDTO) => {
logIfDebugging('Inpaint mask image cached');
dispatch(imImageCacheChanged({ imageDTO }));
};
const onLayerImageCached = (imageDTO: ImageDTO) => {
logIfDebugging('Layer image cached');
dispatch(layerImageCacheChanged({ imageDTO }));
};
const setTool = (tool: Tool) => {
logIfDebugging('Tool selection changed');
dispatch(toolChanged(tool));
@ -240,11 +257,11 @@ export const initializeRenderer = (
const getDocument = () => canvasV2.document;
const getToolState = () => canvasV2.tool;
const getSettings = () => canvasV2.settings;
const getRegionEntityStates = () => canvasV2.regions.entities;
const getLayerEntityStates = () => canvasV2.layers.entities;
const getControlAdapterEntityStates = () => canvasV2.controlAdapters.entities;
const getRegionsState = () => canvasV2.regions;
const getLayersState = () => canvasV2.layers;
const getControlAdaptersState = () => canvasV2.controlAdapters;
const getInpaintMaskState = () => canvasV2.inpaintMask;
const getMaskOpacity = () => canvasV2.settings.maskOpacity;
const getInpaintMaskEntityState = () => canvasV2.inpaintMask;
// Read-write state, ephemeral interaction state
let isDrawing = false;
@ -309,12 +326,12 @@ export const initializeRenderer = (
getCtrlKey: $ctrl.get,
getMetaKey: $meta.get,
getShiftKey: $shift.get,
getControlAdapterEntityStates,
getControlAdaptersState,
getDocument,
getLayerEntityStates,
getRegionEntityStates,
getLayersState,
getRegionsState,
getMaskOpacity,
getInpaintMaskEntityState,
getInpaintMaskState,
// Read-write state
setTool,
@ -342,6 +359,9 @@ export const initializeRenderer = (
onEraserWidthChanged,
onPosChanged,
onBboxTransformed,
onRegionMaskImageCached,
onInpaintMaskImageCached,
onLayerImageCached,
};
const cleanupListeners = setStageEventHandlers(manager);

View File

@ -24,7 +24,7 @@ import { DEFAULT_RGBA_COLOR } from './types';
const initialState: CanvasV2State = {
_version: 3,
selectedEntityIdentifier: { type: 'inpaint_mask', id: 'inpaint_mask' },
layers: { entities: [], baseLayerImageCache: null },
layers: { entities: [], imageCache: null },
controlAdapters: { entities: [] },
ipAdapters: { entities: [] },
regions: { entities: [] },
@ -161,7 +161,7 @@ export const canvasV2Slice = createSlice({
allEntitiesDeleted: (state) => {
state.regions.entities = [];
state.layers.entities = [];
state.layers.baseLayerImageCache = null;
state.layers.imageCache = null;
state.ipAdapters.entities = [];
state.controlAdapters.entities = [];
},
@ -185,7 +185,6 @@ export const {
scaledBboxChanged,
bboxScaleMethodChanged,
clipToBboxChanged,
baseLayerImageCacheChanged,
// layers
layerAdded,
layerRecalled,
@ -205,6 +204,7 @@ export const {
layerRectAdded,
layerImageAdded,
layerAllDeleted,
layerImageCacheChanged,
// IP Adapters
ipaAdded,
ipaRecalled,
@ -255,7 +255,7 @@ export const {
rgPositivePromptChanged,
rgNegativePromptChanged,
rgFillChanged,
rgMaskImageUploaded,
rgImageCacheChanged,
rgAutoNegativeChanged,
rgIPAdapterAdded,
rgIPAdapterDeleted,

View File

@ -40,7 +40,7 @@ export const layersReducers = {
y: 0,
});
state.selectedEntityIdentifier = { type: 'layer', id };
state.layers.baseLayerImageCache = null;
state.layers.imageCache = null;
},
prepare: () => ({ payload: { id: uuidv4() } }),
},
@ -48,7 +48,7 @@ export const layersReducers = {
const { data } = action.payload;
state.layers.entities.push(data);
state.selectedEntityIdentifier = { type: 'layer', id: data.id };
state.layers.baseLayerImageCache = null;
state.layers.imageCache = null;
},
layerIsEnabledToggled: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
@ -57,7 +57,7 @@ export const layersReducers = {
return;
}
layer.isEnabled = !layer.isEnabled;
state.layers.baseLayerImageCache = null;
state.layers.imageCache = null;
},
layerTranslated: (state, action: PayloadAction<{ id: string; x: number; y: number }>) => {
const { id, x, y } = action.payload;
@ -67,7 +67,7 @@ export const layersReducers = {
}
layer.x = x;
layer.y = y;
state.layers.baseLayerImageCache = null;
state.layers.imageCache = null;
},
layerBboxChanged: (state, action: PayloadAction<{ id: string; bbox: IRect | null }>) => {
const { id, bbox } = action.payload;
@ -93,16 +93,16 @@ export const layersReducers = {
layer.objects = [];
layer.bbox = null;
layer.bboxNeedsUpdate = false;
state.layers.baseLayerImageCache = null;
state.layers.imageCache = null;
},
layerDeleted: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
state.layers.entities = state.layers.entities.filter((l) => l.id !== id);
state.layers.baseLayerImageCache = null;
state.layers.imageCache = null;
},
layerAllDeleted: (state) => {
state.layers.entities = [];
state.layers.baseLayerImageCache = null;
state.layers.imageCache = null;
},
layerOpacityChanged: (state, action: PayloadAction<{ id: string; opacity: number }>) => {
const { id, opacity } = action.payload;
@ -111,7 +111,7 @@ export const layersReducers = {
return;
}
layer.opacity = opacity;
state.layers.baseLayerImageCache = null;
state.layers.imageCache = null;
},
layerMovedForwardOne: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
@ -120,7 +120,7 @@ export const layersReducers = {
return;
}
moveOneToEnd(state.layers.entities, layer);
state.layers.baseLayerImageCache = null;
state.layers.imageCache = null;
},
layerMovedToFront: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
@ -129,7 +129,7 @@ export const layersReducers = {
return;
}
moveToEnd(state.layers.entities, layer);
state.layers.baseLayerImageCache = null;
state.layers.imageCache = null;
},
layerMovedBackwardOne: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
@ -138,7 +138,7 @@ export const layersReducers = {
return;
}
moveOneToStart(state.layers.entities, layer);
state.layers.baseLayerImageCache = null;
state.layers.imageCache = null;
},
layerMovedToBack: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
@ -147,7 +147,7 @@ export const layersReducers = {
return;
}
moveToStart(state.layers.entities, layer);
state.layers.baseLayerImageCache = null;
state.layers.imageCache = null;
},
layerBrushLineAdded: {
reducer: (state, action: PayloadAction<BrushLineAddedArg & { lineId: string }>) => {
@ -166,7 +166,7 @@ export const layersReducers = {
clip,
});
layer.bboxNeedsUpdate = true;
state.layers.baseLayerImageCache = null;
state.layers.imageCache = null;
},
prepare: (payload: BrushLineAddedArg) => ({
payload: { ...payload, lineId: uuidv4() },
@ -188,7 +188,7 @@ export const layersReducers = {
clip,
});
layer.bboxNeedsUpdate = true;
state.layers.baseLayerImageCache = null;
state.layers.imageCache = null;
},
prepare: (payload: EraserLineAddedArg) => ({
payload: { ...payload, lineId: uuidv4() },
@ -206,7 +206,7 @@ export const layersReducers = {
}
lastObject.points.push(...point);
layer.bboxNeedsUpdate = true;
state.layers.baseLayerImageCache = null;
state.layers.imageCache = null;
},
layerRectAdded: {
reducer: (state, action: PayloadAction<RectShapeAddedArg & { rectId: string }>) => {
@ -226,7 +226,7 @@ export const layersReducers = {
color,
});
layer.bboxNeedsUpdate = true;
state.layers.baseLayerImageCache = null;
state.layers.imageCache = null;
},
prepare: (payload: RectShapeAddedArg) => ({ payload: { ...payload, rectId: uuidv4() } }),
},
@ -239,11 +239,12 @@ export const layersReducers = {
}
layer.objects.push(imageDTOToImageObject(id, objectId, imageDTO));
layer.bboxNeedsUpdate = true;
state.layers.baseLayerImageCache = null;
state.layers.imageCache = null;
},
prepare: (payload: ImageObjectAddedArg) => ({ payload: { ...payload, objectId: uuidv4() } }),
},
baseLayerImageCacheChanged: (state, action: PayloadAction<ImageDTO | null>) => {
state.layers.baseLayerImageCache = action.payload ? imageDTOToImageWithDims(action.payload) : null;
layerImageCacheChanged: (state, action: PayloadAction<{ imageDTO: ImageDTO | null }>) => {
const { imageDTO } = action.payload;
state.layers.imageCache = imageDTO ? imageDTOToImageWithDims(imageDTO) : null;
},
} satisfies SliceCaseReducers<CanvasV2State>;

View File

@ -1,11 +1,7 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils';
import { getBrushLineId, getEraserLineId, getRectShapeId } from 'features/controlLayers/konva/naming';
import type {
CanvasV2State,
CLIPVisionModelV2,
IPMethodV2,
} from 'features/controlLayers/store/types';
import type { CanvasV2State, CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
import { imageDTOToImageObject, imageDTOToImageWithDims } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas';
@ -182,7 +178,7 @@ export const regionsReducers = {
}
rg.fill = fill;
},
rgMaskImageUploaded: (state, action: PayloadAction<{ id: string; imageDTO: ImageDTO }>) => {
rgImageCacheChanged: (state, action: PayloadAction<{ id: string; imageDTO: ImageDTO }>) => {
const { id, imageDTO } = action.payload;
const rg = selectRG(state, id);
if (!rg) {

View File

@ -797,7 +797,7 @@ export type CanvasV2State = {
selectedEntityIdentifier: CanvasEntityIdentifier | null;
inpaintMask: InpaintMaskEntity;
layers: {
baseLayerImageCache: ImageWithDims | null;
imageCache: ImageWithDims | null;
entities: LayerEntity[];
};
controlAdapters: { entities: ControlAdapterEntity[] };

View File

@ -1,96 +1,9 @@
import { getStore } from 'app/store/nanostores/store';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { $nodeManager } from 'features/controlLayers/konva/renderers/renderer';
import { blobToDataURL } from 'features/controlLayers/konva/util';
import { baseLayerImageCacheChanged } from 'features/controlLayers/store/canvasV2Slice';
import type { LayerEntity } from 'features/controlLayers/store/types';
import type Konva from 'konva';
import type { IRect } from 'konva/lib/types';
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
const isValidLayer = (entity: LayerEntity) => {
export const isValidLayer = (entity: LayerEntity) => {
return (
entity.isEnabled &&
// Boolean(entity.bbox) && TODO(psyche): Re-enable this check when we have a way to calculate bbox for all layers
entity.objects.length > 0
);
};
/**
* Get the blobs of all regional prompt layers. Only visible layers are returned.
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
* @param preview Whether to open a new tab displaying each layer.
* @returns A map of layer IDs to blobs.
*/
const getBaseLayer = async (layers: LayerEntity[], bbox: IRect, preview: boolean = false): Promise<Blob> => {
const manager = $nodeManager.get();
assert(manager, 'Node manager is null');
const stage = manager.stage.clone();
stage.scaleX(1);
stage.scaleY(1);
stage.x(0);
stage.y(0);
const validLayers = layers.filter(isValidLayer);
// Konva bug (?) - when iterating over the array returned from `stage.getLayers()`, if you destroy a layer, the array
// is mutated in-place and the next iteration will skip the next layer. To avoid this, we first collect the layers
// to delete in a separate array and then destroy them.
// TODO(psyche): Maybe report this?
const toDelete: Konva.Layer[] = [];
for (const konvaLayer of stage.getLayers()) {
const layer = validLayers.find((l) => l.id === konvaLayer.id());
if (!layer) {
toDelete.push(konvaLayer);
}
}
for (const konvaLayer of toDelete) {
konvaLayer.destroy();
}
const blob = await new Promise<Blob>((resolve) => {
stage.toBlob({
callback: (blob) => {
assert(blob, 'Blob is null');
resolve(blob);
},
...bbox,
});
});
if (preview) {
const base64 = await blobToDataURL(blob);
openBase64ImageInTab([{ base64, caption: 'base layer' }]);
}
stage.destroy();
return blob;
};
export const getBaseLayerImage = async (): Promise<ImageDTO> => {
const { dispatch, getState } = getStore();
const state = getState();
if (state.canvasV2.layers.baseLayerImageCache) {
const imageDTO = await getImageDTO(state.canvasV2.layers.baseLayerImageCache.name);
if (imageDTO) {
return imageDTO;
}
}
const blob = await getBaseLayer(state.canvasV2.layers.entities, state.canvasV2.bbox, true);
const file = new File([blob], 'image.png', { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'general', is_intermediate: true })
);
req.reset();
const imageDTO = await req.unwrap();
dispatch(baseLayerImageCacheChanged(imageDTO));
return imageDTO;
};

View File

@ -1,10 +1,5 @@
import { getStore } from 'app/store/nanostores/store';
import { deepClone } from 'common/util/deepClone';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import type { KonvaEntityAdapter } from 'features/controlLayers/konva/nodeManager';
import { $nodeManager } from 'features/controlLayers/konva/renderers/renderer';
import { blobToDataURL } from 'features/controlLayers/konva/util';
import { rgMaskImageUploaded } from 'features/controlLayers/store/canvasV2Slice';
import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager';
import type { Dimensions, IPAdapterEntity, RegionEntity } from 'features/controlLayers/store/types';
import {
PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX,
@ -16,8 +11,7 @@ import {
import { addIPAdapterCollectorSafe, isValidIPAdapter } from 'features/nodes/util/graph/generation/addIPAdapters';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { IRect } from 'konva/lib/types';
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
import type { BaseModelType, Invocation } from 'services/api/types';
import { assert } from 'tsafe';
/**
@ -34,6 +28,7 @@ import { assert } from 'tsafe';
*/
export const addRegions = async (
manager: KonvaNodeManager,
regions: RegionEntity[],
g: Graph,
documentSize: Dimensions,
@ -51,7 +46,7 @@ export const addRegions = async (
for (const region of validRegions) {
// Upload the mask image, or get the cached image if it exists
const { image_name } = await getRegionMaskImage(region, bbox, true);
const { image_name } = await manager.util.getRegionMaskImage({ id: region.id, bbox, preview: true });
// The main mask-to-tensor node
const maskToTensor = g.addNode({
@ -217,90 +212,3 @@ export const isValidRegion = (rg: RegionEntity, base: BaseModelType) => {
const hasIPAdapter = rg.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base)).length > 0;
return hasTextPrompt || hasIPAdapter;
};
export const getMaskImage = async (rg: RegionEntity, blob: Blob): Promise<ImageDTO> => {
const { id, imageCache } = rg;
if (imageCache) {
const imageDTO = await getImageDTO(imageCache.name);
if (imageDTO) {
return imageDTO;
}
}
const { dispatch } = getStore();
// No cached mask, or the cached image no longer exists - we need to upload the mask image
const file = new File([blob], `${rg.id}_mask.png`, { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
);
req.reset();
const imageDTO = await req.unwrap();
dispatch(rgMaskImageUploaded({ id, imageDTO }));
return imageDTO;
};
export const uploadMaskImage = async ({ id }: RegionEntity, blob: Blob): Promise<ImageDTO> => {
const { dispatch } = getStore();
// No cached mask, or the cached image no longer exists - we need to upload the mask image
const file = new File([blob], `${id}_mask.png`, { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
);
req.reset();
const imageDTO = await req.unwrap();
dispatch(rgMaskImageUploaded({ id, imageDTO }));
return imageDTO;
};
/**
* Get the blobs of all regional prompt layers. Only visible layers are returned.
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
* @param preview Whether to open a new tab displaying each layer.
* @returns A map of layer IDs to blobs.
*/
export const getRegionMaskImage = async (
region: RegionEntity,
bbox: IRect,
preview: boolean = false
): Promise<ImageDTO> => {
const manager = $nodeManager.get();
assert(manager, 'Node manager is null');
// TODO(psyche): Why do I need to annotate this? TS must have some kind of circular ref w/ this type but I can't figure it out...
const adapter: KonvaEntityAdapter | undefined = manager.get(region.id);
assert(adapter, `Adapter for region ${region.id} not found`);
if (region.imageCache) {
const imageDTO = await getImageDTO(region.imageCache.name);
if (imageDTO) {
return imageDTO;
}
}
const layer = adapter.konvaLayer.clone();
const objectGroup = adapter.konvaObjectGroup.clone();
layer.destroyChildren();
layer.add(objectGroup);
objectGroup.opacity(1);
objectGroup.cache();
const blob = await new Promise<Blob>((resolve) => {
layer.toBlob({
callback: (blob) => {
assert(blob, 'Blob is null');
resolve(blob);
},
...bbox,
});
});
if (preview) {
const base64 = await blobToDataURL(blob);
const caption = `${region.id}: ${region.positivePrompt} / ${region.negativePrompt}`;
openBase64ImageInTab([{ base64, caption }]);
}
layer.destroy();
return await uploadMaskImage(region, blob);
};

View File

@ -588,3 +588,16 @@ export const getImageDTO = async (image_name: string, forceRefetch?: boolean): P
return null;
}
};
export const uploadImage = async (
blob: Blob,
fileName: string,
image_category: ImageCategory,
is_intermediate: boolean
): Promise<ImageDTO> => {
const { dispatch } = getStore();
const file = new File([blob], fileName, { type: 'image/png' });
const req = dispatch(imagesApi.endpoints.uploadImage.initiate({ file, image_category, is_intermediate }));
req.reset();
return await req.unwrap();
};