feat(ui): add utils for getting images from canvas

This commit is contained in:
psychedelicious 2024-06-21 16:00:42 +10:00
parent 275fc2ccf9
commit 7dd11bd60a
15 changed files with 277 additions and 241 deletions

View File

@ -1,10 +1,12 @@
import { enqueueRequested } from 'app/store/actions'; import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { $nodeManager } from 'features/controlLayers/konva/renderers/renderer';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice'; import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildGenerationTabGraph } from 'features/nodes/util/graph/generation/buildGenerationTabGraph'; import { buildGenerationTabGraph } from 'features/nodes/util/graph/generation/buildGenerationTabGraph';
import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/generation/buildGenerationTabSDXLGraph'; import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/generation/buildGenerationTabSDXLGraph';
import { queueApi } from 'services/api/endpoints/queue'; import { queueApi } from 'services/api/endpoints/queue';
import { assert } from 'tsafe';
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => { export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
startAppListening({ startAppListening({
@ -18,10 +20,13 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
let graph; let graph;
const manager = $nodeManager.get();
assert(manager, 'Konva node manager not initialized');
if (model?.base === 'sdxl') { if (model?.base === 'sdxl') {
graph = await buildGenerationTabSDXLGraph(state); graph = await buildGenerationTabSDXLGraph(state, manager);
} else { } else {
graph = await buildGenerationTabGraph(state); graph = await buildGenerationTabGraph(state, manager);
} }
const batchConfig = prepareLinearUIBatch(state, graph, prepend); const batchConfig = prepareLinearUIBatch(state, graph, prepend);

View File

@ -1,3 +1,5 @@
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { blobToDataURL } from 'features/controlLayers/konva/util';
import type { import type {
BrushLine, BrushLine,
BrushLineAddedArg, BrushLineAddedArg,
@ -15,8 +17,11 @@ import type {
StageAttrs, StageAttrs,
Tool, Tool,
} from 'features/controlLayers/store/types'; } from 'features/controlLayers/store/types';
import { isValidLayer } from 'features/nodes/util/graph/generation/addLayers';
import type Konva from 'konva'; import type Konva from 'konva';
import type { Vector2d } from 'konva/lib/types'; import type { Vector2d } from 'konva/lib/types';
import { getImageDTO as defaultGetImageDTO, uploadImage as defaultUploadImage } from 'services/api/endpoints/images';
import type { ImageCategory, ImageDTO } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
export type BrushLineObjectRecord = { export type BrushLineObjectRecord = {
@ -132,24 +137,53 @@ type StateApi = {
getMetaKey: () => boolean; getMetaKey: () => boolean;
getAltKey: () => boolean; getAltKey: () => boolean;
getDocument: () => CanvasV2State['document']; getDocument: () => CanvasV2State['document'];
getLayerEntityStates: () => CanvasV2State['layers']['entities']; getLayersState: () => CanvasV2State['layers'];
getControlAdapterEntityStates: () => CanvasV2State['controlAdapters']['entities']; getControlAdaptersState: () => CanvasV2State['controlAdapters'];
getRegionEntityStates: () => CanvasV2State['regions']['entities']; getRegionsState: () => CanvasV2State['regions'];
getInpaintMaskEntityState: () => CanvasV2State['inpaintMask']; getInpaintMaskState: () => CanvasV2State['inpaintMask'];
onInpaintMaskImageCached: (imageDTO: ImageDTO) => void;
onRegionMaskImageCached: (id: string, imageDTO: ImageDTO) => void;
onLayerImageCached: (imageDTO: ImageDTO) => void;
};
type Util = {
getImageDTO: (imageName: string) => Promise<ImageDTO | null>;
uploadImage: (
blob: Blob,
fileName: string,
image_category: ImageCategory,
is_intermediate: boolean
) => Promise<ImageDTO>;
getRegionMaskImage: (arg: { id: string; bbox?: Rect; preview?: boolean }) => Promise<ImageDTO>;
getInpaintMaskImage: (arg: { bbox?: Rect; preview?: boolean }) => Promise<ImageDTO>;
getImageSourceImage: (arg: { bbox?: Rect; preview?: boolean }) => Promise<ImageDTO>;
}; };
export class KonvaNodeManager { export class KonvaNodeManager {
stage: Konva.Stage; stage: Konva.Stage;
container: HTMLDivElement; container: HTMLDivElement;
adapters: Map<string, KonvaEntityAdapter>; adapters: Map<string, KonvaEntityAdapter>;
util: Util;
_background: BackgroundLayer | null; _background: BackgroundLayer | null;
_preview: PreviewLayer | null; _preview: PreviewLayer | null;
_konvaApi: KonvaApi | null; _konvaApi: KonvaApi | null;
_stateApi: StateApi | null; _stateApi: StateApi | null;
constructor(stage: Konva.Stage, container: HTMLDivElement) { constructor(
stage: Konva.Stage,
container: HTMLDivElement,
getImageDTO: Util['getImageDTO'] = defaultGetImageDTO,
uploadImage: Util['uploadImage'] = defaultUploadImage
) {
this.stage = stage; this.stage = stage;
this.container = container; this.container = container;
this.util = {
getImageDTO,
uploadImage,
getRegionMaskImage: this._getRegionMaskImage.bind(this),
getInpaintMaskImage: this._getInpaintMaskImage.bind(this),
getImageSourceImage: this._getImageSourceImage.bind(this),
};
this._konvaApi = null; this._konvaApi = null;
this._preview = null; this._preview = null;
this._background = null; this._background = null;
@ -219,6 +253,152 @@ export class KonvaNodeManager {
assert(this._stateApi !== null, 'State API has not been set'); assert(this._stateApi !== null, 'State API has not been set');
return this._stateApi; return this._stateApi;
} }
async _getRegionMaskImage(arg: { id: string; bbox?: Rect; preview?: boolean }): Promise<ImageDTO> {
const { id, bbox, preview = false } = arg;
const region = this.stateApi.getRegionsState().entities.find((entity) => entity.id === id);
assert(region, `Region entity state with id ${id} not found`);
const adapter = this.get(region.id);
assert(adapter, `Adapter for region ${region.id} not found`);
if (region.imageCache) {
const imageDTO = await this.util.getImageDTO(region.imageCache.name);
if (imageDTO) {
return imageDTO;
}
}
const layer = adapter.konvaLayer.clone();
const objectGroup = adapter.konvaObjectGroup.clone();
layer.destroyChildren();
layer.add(objectGroup);
objectGroup.opacity(1);
objectGroup.cache();
const blob = await new Promise<Blob>((resolve) => {
layer.toBlob({
callback: (blob) => {
assert(blob, 'Blob is null');
resolve(blob);
},
...bbox,
});
});
if (preview) {
const base64 = await blobToDataURL(blob);
const caption = `${region.id}: ${region.positivePrompt} / ${region.negativePrompt}`;
openBase64ImageInTab([{ base64, caption }]);
}
layer.destroy();
const imageDTO = await this.util.uploadImage(blob, `${region.id}_mask.png`, 'mask', true);
this.stateApi.onRegionMaskImageCached(region.id, imageDTO);
return imageDTO;
}
async _getInpaintMaskImage(arg: { bbox?: Rect; preview?: boolean }): Promise<ImageDTO> {
const { bbox, preview = false } = arg;
const inpaintMask = this.stateApi.getInpaintMaskState();
const adapter = this.get(inpaintMask.id);
assert(adapter, `Adapter for ${inpaintMask.id} not found`);
if (inpaintMask.imageCache) {
const imageDTO = await this.util.getImageDTO(inpaintMask.imageCache.name);
if (imageDTO) {
return imageDTO;
}
}
const layer = adapter.konvaLayer.clone();
const objectGroup = adapter.konvaObjectGroup.clone();
layer.destroyChildren();
layer.add(objectGroup);
objectGroup.opacity(1);
objectGroup.cache();
const blob = await new Promise<Blob>((resolve) => {
layer.toBlob({
callback: (blob) => {
assert(blob, 'Blob is null');
resolve(blob);
},
...bbox,
});
});
if (preview) {
const base64 = await blobToDataURL(blob);
const caption = 'inpaint mask';
openBase64ImageInTab([{ base64, caption }]);
}
layer.destroy();
const imageDTO = await this.util.uploadImage(blob, 'inpaint_mask.png', 'mask', true);
this.stateApi.onInpaintMaskImageCached(imageDTO);
return imageDTO;
}
async _getImageSourceImage(arg: { bbox?: Rect; preview?: boolean }): Promise<ImageDTO> {
const { bbox, preview = false } = arg;
const layersState = this.stateApi.getLayersState();
const { entities, imageCache } = layersState;
if (imageCache) {
const imageDTO = await this.util.getImageDTO(imageCache.name);
if (imageDTO) {
return imageDTO;
}
}
const stage = this.stage.clone();
stage.scaleX(1);
stage.scaleY(1);
stage.x(0);
stage.y(0);
const validLayers = entities.filter(isValidLayer);
// Konva bug (?) - when iterating over the array returned from `stage.getLayers()`, if you destroy a layer, the array
// is mutated in-place and the next iteration will skip the next layer. To avoid this, we first collect the layers
// to delete in a separate array and then destroy them.
// TODO(psyche): Maybe report this?
const toDelete: Konva.Layer[] = [];
for (const konvaLayer of stage.getLayers()) {
const layer = validLayers.find((l) => l.id === konvaLayer.id());
if (!layer) {
toDelete.push(konvaLayer);
}
}
for (const konvaLayer of toDelete) {
konvaLayer.destroy();
}
const blob = await new Promise<Blob>((resolve) => {
stage.toBlob({
callback: (blob) => {
assert(blob, 'Blob is null');
resolve(blob);
},
...bbox,
});
});
if (preview) {
const base64 = await blobToDataURL(blob);
openBase64ImageInTab([{ base64, caption: 'base layer' }]);
}
stage.destroy();
const imageDTO = await this.util.uploadImage(blob, 'base_layer.png', 'general', true);
this.stateApi.onLayerImageCached(imageDTO);
return imageDTO;
}
} }
export class KonvaEntityAdapter { export class KonvaEntityAdapter {

View File

@ -6,12 +6,12 @@ import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager'
* @returns An arrange entities function * @returns An arrange entities function
*/ */
export const getArrangeEntities = (manager: KonvaNodeManager) => { export const getArrangeEntities = (manager: KonvaNodeManager) => {
const { getLayerEntityStates, getControlAdapterEntityStates, getRegionEntityStates } = manager.stateApi; const { getLayersState, getControlAdaptersState, getRegionsState } = manager.stateApi;
function arrangeEntities(): void { function arrangeEntities(): void {
const layers = getLayerEntityStates(); const layers = getLayersState().entities;
const controlAdapters = getControlAdapterEntityStates(); const controlAdapters = getControlAdaptersState().entities;
const regions = getRegionEntityStates(); const regions = getRegionsState().entities;
let zIndex = 0; let zIndex = 0;
manager.background.layer.zIndex(++zIndex); manager.background.layer.zIndex(++zIndex);
for (const layer of layers) { for (const layer of layers) {

View File

@ -100,10 +100,10 @@ export const renderControlAdapter = async (manager: KonvaNodeManager, entity: Co
* @returns A function to render all control adapters * @returns A function to render all control adapters
*/ */
export const getRenderControlAdapters = (manager: KonvaNodeManager) => { export const getRenderControlAdapters = (manager: KonvaNodeManager) => {
const { getControlAdapterEntityStates } = manager.stateApi; const { getControlAdaptersState } = manager.stateApi;
function renderControlAdapters(): void { function renderControlAdapters(): void {
const entities = getControlAdapterEntityStates(); const { entities } = getControlAdaptersState();
// Destroy nonexistent layers // Destroy nonexistent layers
for (const adapters of manager.getAll('control_adapter')) { for (const adapters of manager.getAll('control_adapter')) {
if (!entities.find((ca) => ca.id === adapters.id)) { if (!entities.find((ca) => ca.id === adapters.id)) {

View File

@ -71,10 +71,10 @@ const getInpaintMask = (
* @returns A function to render the inpaint mask * @returns A function to render the inpaint mask
*/ */
export const getRenderInpaintMask = (manager: KonvaNodeManager) => { export const getRenderInpaintMask = (manager: KonvaNodeManager) => {
const { getInpaintMaskEntityState, getMaskOpacity, getToolState, getSelectedEntity, onPosChanged } = manager.stateApi; const { getInpaintMaskState, getMaskOpacity, getToolState, getSelectedEntity, onPosChanged } = manager.stateApi;
function renderInpaintMask(): void { function renderInpaintMask(): void {
const entity = getInpaintMaskEntityState(); const entity = getInpaintMaskState();
const globalMaskLayerOpacity = getMaskOpacity(); const globalMaskLayerOpacity = getMaskOpacity();
const toolState = getToolState(); const toolState = getToolState();
const selectedEntity = getSelectedEntity(); const selectedEntity = getSelectedEntity();

View File

@ -136,10 +136,10 @@ export const renderLayer = async (
* @returns A function to render all layers * @returns A function to render all layers
*/ */
export const getRenderLayers = (manager: KonvaNodeManager) => { export const getRenderLayers = (manager: KonvaNodeManager) => {
const { getLayerEntityStates, getToolState, onPosChanged } = manager.stateApi; const { getLayersState, getToolState, onPosChanged } = manager.stateApi;
function renderLayers(): void { function renderLayers(): void {
const entities = getLayerEntityStates(); const { entities } = getLayersState();
const tool = getToolState(); const tool = getToolState();
// Destroy nonexistent layers // Destroy nonexistent layers
for (const adapter of manager.getAll('layer')) { for (const adapter of manager.getAll('layer')) {

View File

@ -233,10 +233,10 @@ export const renderRegion = (
* @returns A function to render all regions * @returns A function to render all regions
*/ */
export const getRenderRegions = (manager: KonvaNodeManager) => { export const getRenderRegions = (manager: KonvaNodeManager) => {
const { getRegionEntityStates, getMaskOpacity, getToolState, getSelectedEntity, onPosChanged } = manager.stateApi; const { getRegionsState, getMaskOpacity, getToolState, getSelectedEntity, onPosChanged } = manager.stateApi;
function renderRegions(): void { function renderRegions(): void {
const entities = getRegionEntityStates(); const { entities } = getRegionsState();
const maskOpacity = getMaskOpacity(); const maskOpacity = getMaskOpacity();
const toolState = getToolState(); const toolState = getToolState();
const selectedEntity = getSelectedEntity(); const selectedEntity = getSelectedEntity();

View File

@ -32,17 +32,20 @@ import {
imBboxChanged, imBboxChanged,
imBrushLineAdded, imBrushLineAdded,
imEraserLineAdded, imEraserLineAdded,
imImageCacheChanged,
imLinePointAdded, imLinePointAdded,
imTranslated, imTranslated,
layerBboxChanged, layerBboxChanged,
layerBrushLineAdded, layerBrushLineAdded,
layerEraserLineAdded, layerEraserLineAdded,
layerImageCacheChanged,
layerLinePointAdded, layerLinePointAdded,
layerRectAdded, layerRectAdded,
layerTranslated, layerTranslated,
rgBboxChanged, rgBboxChanged,
rgBrushLineAdded, rgBrushLineAdded,
rgEraserLineAdded, rgEraserLineAdded,
rgImageCacheChanged,
rgLinePointAdded, rgLinePointAdded,
rgRectAdded, rgRectAdded,
rgTranslated, rgTranslated,
@ -65,6 +68,7 @@ import type { IRect, Vector2d } from 'konva/lib/types';
import { debounce } from 'lodash-es'; import { debounce } from 'lodash-es';
import { atom } from 'nanostores'; import { atom } from 'nanostores';
import type { RgbaColor } from 'react-colorful'; import type { RgbaColor } from 'react-colorful';
import type { ImageDTO } from 'services/api/types';
export const $nodeManager = atom<KonvaNodeManager | null>(null); export const $nodeManager = atom<KonvaNodeManager | null>(null);
@ -175,6 +179,19 @@ export const initializeRenderer = (
logIfDebugging('Eraser width changed'); logIfDebugging('Eraser width changed');
dispatch(eraserWidthChanged(width)); dispatch(eraserWidthChanged(width));
}; };
const onRegionMaskImageCached = (id: string, imageDTO: ImageDTO) => {
logIfDebugging('Region mask image cached');
dispatch(rgImageCacheChanged({ id, imageDTO }));
};
const onInpaintMaskImageCached = (imageDTO: ImageDTO) => {
logIfDebugging('Inpaint mask image cached');
dispatch(imImageCacheChanged({ imageDTO }));
};
const onLayerImageCached = (imageDTO: ImageDTO) => {
logIfDebugging('Layer image cached');
dispatch(layerImageCacheChanged({ imageDTO }));
};
const setTool = (tool: Tool) => { const setTool = (tool: Tool) => {
logIfDebugging('Tool selection changed'); logIfDebugging('Tool selection changed');
dispatch(toolChanged(tool)); dispatch(toolChanged(tool));
@ -240,11 +257,11 @@ export const initializeRenderer = (
const getDocument = () => canvasV2.document; const getDocument = () => canvasV2.document;
const getToolState = () => canvasV2.tool; const getToolState = () => canvasV2.tool;
const getSettings = () => canvasV2.settings; const getSettings = () => canvasV2.settings;
const getRegionEntityStates = () => canvasV2.regions.entities; const getRegionsState = () => canvasV2.regions;
const getLayerEntityStates = () => canvasV2.layers.entities; const getLayersState = () => canvasV2.layers;
const getControlAdapterEntityStates = () => canvasV2.controlAdapters.entities; const getControlAdaptersState = () => canvasV2.controlAdapters;
const getInpaintMaskState = () => canvasV2.inpaintMask;
const getMaskOpacity = () => canvasV2.settings.maskOpacity; const getMaskOpacity = () => canvasV2.settings.maskOpacity;
const getInpaintMaskEntityState = () => canvasV2.inpaintMask;
// Read-write state, ephemeral interaction state // Read-write state, ephemeral interaction state
let isDrawing = false; let isDrawing = false;
@ -309,12 +326,12 @@ export const initializeRenderer = (
getCtrlKey: $ctrl.get, getCtrlKey: $ctrl.get,
getMetaKey: $meta.get, getMetaKey: $meta.get,
getShiftKey: $shift.get, getShiftKey: $shift.get,
getControlAdapterEntityStates, getControlAdaptersState,
getDocument, getDocument,
getLayerEntityStates, getLayersState,
getRegionEntityStates, getRegionsState,
getMaskOpacity, getMaskOpacity,
getInpaintMaskEntityState, getInpaintMaskState,
// Read-write state // Read-write state
setTool, setTool,
@ -342,6 +359,9 @@ export const initializeRenderer = (
onEraserWidthChanged, onEraserWidthChanged,
onPosChanged, onPosChanged,
onBboxTransformed, onBboxTransformed,
onRegionMaskImageCached,
onInpaintMaskImageCached,
onLayerImageCached,
}; };
const cleanupListeners = setStageEventHandlers(manager); const cleanupListeners = setStageEventHandlers(manager);

View File

@ -24,7 +24,7 @@ import { DEFAULT_RGBA_COLOR } from './types';
const initialState: CanvasV2State = { const initialState: CanvasV2State = {
_version: 3, _version: 3,
selectedEntityIdentifier: { type: 'inpaint_mask', id: 'inpaint_mask' }, selectedEntityIdentifier: { type: 'inpaint_mask', id: 'inpaint_mask' },
layers: { entities: [], baseLayerImageCache: null }, layers: { entities: [], imageCache: null },
controlAdapters: { entities: [] }, controlAdapters: { entities: [] },
ipAdapters: { entities: [] }, ipAdapters: { entities: [] },
regions: { entities: [] }, regions: { entities: [] },
@ -161,7 +161,7 @@ export const canvasV2Slice = createSlice({
allEntitiesDeleted: (state) => { allEntitiesDeleted: (state) => {
state.regions.entities = []; state.regions.entities = [];
state.layers.entities = []; state.layers.entities = [];
state.layers.baseLayerImageCache = null; state.layers.imageCache = null;
state.ipAdapters.entities = []; state.ipAdapters.entities = [];
state.controlAdapters.entities = []; state.controlAdapters.entities = [];
}, },
@ -185,7 +185,6 @@ export const {
scaledBboxChanged, scaledBboxChanged,
bboxScaleMethodChanged, bboxScaleMethodChanged,
clipToBboxChanged, clipToBboxChanged,
baseLayerImageCacheChanged,
// layers // layers
layerAdded, layerAdded,
layerRecalled, layerRecalled,
@ -205,6 +204,7 @@ export const {
layerRectAdded, layerRectAdded,
layerImageAdded, layerImageAdded,
layerAllDeleted, layerAllDeleted,
layerImageCacheChanged,
// IP Adapters // IP Adapters
ipaAdded, ipaAdded,
ipaRecalled, ipaRecalled,
@ -255,7 +255,7 @@ export const {
rgPositivePromptChanged, rgPositivePromptChanged,
rgNegativePromptChanged, rgNegativePromptChanged,
rgFillChanged, rgFillChanged,
rgMaskImageUploaded, rgImageCacheChanged,
rgAutoNegativeChanged, rgAutoNegativeChanged,
rgIPAdapterAdded, rgIPAdapterAdded,
rgIPAdapterDeleted, rgIPAdapterDeleted,

View File

@ -40,7 +40,7 @@ export const layersReducers = {
y: 0, y: 0,
}); });
state.selectedEntityIdentifier = { type: 'layer', id }; state.selectedEntityIdentifier = { type: 'layer', id };
state.layers.baseLayerImageCache = null; state.layers.imageCache = null;
}, },
prepare: () => ({ payload: { id: uuidv4() } }), prepare: () => ({ payload: { id: uuidv4() } }),
}, },
@ -48,7 +48,7 @@ export const layersReducers = {
const { data } = action.payload; const { data } = action.payload;
state.layers.entities.push(data); state.layers.entities.push(data);
state.selectedEntityIdentifier = { type: 'layer', id: data.id }; state.selectedEntityIdentifier = { type: 'layer', id: data.id };
state.layers.baseLayerImageCache = null; state.layers.imageCache = null;
}, },
layerIsEnabledToggled: (state, action: PayloadAction<{ id: string }>) => { layerIsEnabledToggled: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload; const { id } = action.payload;
@ -57,7 +57,7 @@ export const layersReducers = {
return; return;
} }
layer.isEnabled = !layer.isEnabled; layer.isEnabled = !layer.isEnabled;
state.layers.baseLayerImageCache = null; state.layers.imageCache = null;
}, },
layerTranslated: (state, action: PayloadAction<{ id: string; x: number; y: number }>) => { layerTranslated: (state, action: PayloadAction<{ id: string; x: number; y: number }>) => {
const { id, x, y } = action.payload; const { id, x, y } = action.payload;
@ -67,7 +67,7 @@ export const layersReducers = {
} }
layer.x = x; layer.x = x;
layer.y = y; layer.y = y;
state.layers.baseLayerImageCache = null; state.layers.imageCache = null;
}, },
layerBboxChanged: (state, action: PayloadAction<{ id: string; bbox: IRect | null }>) => { layerBboxChanged: (state, action: PayloadAction<{ id: string; bbox: IRect | null }>) => {
const { id, bbox } = action.payload; const { id, bbox } = action.payload;
@ -93,16 +93,16 @@ export const layersReducers = {
layer.objects = []; layer.objects = [];
layer.bbox = null; layer.bbox = null;
layer.bboxNeedsUpdate = false; layer.bboxNeedsUpdate = false;
state.layers.baseLayerImageCache = null; state.layers.imageCache = null;
}, },
layerDeleted: (state, action: PayloadAction<{ id: string }>) => { layerDeleted: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload; const { id } = action.payload;
state.layers.entities = state.layers.entities.filter((l) => l.id !== id); state.layers.entities = state.layers.entities.filter((l) => l.id !== id);
state.layers.baseLayerImageCache = null; state.layers.imageCache = null;
}, },
layerAllDeleted: (state) => { layerAllDeleted: (state) => {
state.layers.entities = []; state.layers.entities = [];
state.layers.baseLayerImageCache = null; state.layers.imageCache = null;
}, },
layerOpacityChanged: (state, action: PayloadAction<{ id: string; opacity: number }>) => { layerOpacityChanged: (state, action: PayloadAction<{ id: string; opacity: number }>) => {
const { id, opacity } = action.payload; const { id, opacity } = action.payload;
@ -111,7 +111,7 @@ export const layersReducers = {
return; return;
} }
layer.opacity = opacity; layer.opacity = opacity;
state.layers.baseLayerImageCache = null; state.layers.imageCache = null;
}, },
layerMovedForwardOne: (state, action: PayloadAction<{ id: string }>) => { layerMovedForwardOne: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload; const { id } = action.payload;
@ -120,7 +120,7 @@ export const layersReducers = {
return; return;
} }
moveOneToEnd(state.layers.entities, layer); moveOneToEnd(state.layers.entities, layer);
state.layers.baseLayerImageCache = null; state.layers.imageCache = null;
}, },
layerMovedToFront: (state, action: PayloadAction<{ id: string }>) => { layerMovedToFront: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload; const { id } = action.payload;
@ -129,7 +129,7 @@ export const layersReducers = {
return; return;
} }
moveToEnd(state.layers.entities, layer); moveToEnd(state.layers.entities, layer);
state.layers.baseLayerImageCache = null; state.layers.imageCache = null;
}, },
layerMovedBackwardOne: (state, action: PayloadAction<{ id: string }>) => { layerMovedBackwardOne: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload; const { id } = action.payload;
@ -138,7 +138,7 @@ export const layersReducers = {
return; return;
} }
moveOneToStart(state.layers.entities, layer); moveOneToStart(state.layers.entities, layer);
state.layers.baseLayerImageCache = null; state.layers.imageCache = null;
}, },
layerMovedToBack: (state, action: PayloadAction<{ id: string }>) => { layerMovedToBack: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload; const { id } = action.payload;
@ -147,7 +147,7 @@ export const layersReducers = {
return; return;
} }
moveToStart(state.layers.entities, layer); moveToStart(state.layers.entities, layer);
state.layers.baseLayerImageCache = null; state.layers.imageCache = null;
}, },
layerBrushLineAdded: { layerBrushLineAdded: {
reducer: (state, action: PayloadAction<BrushLineAddedArg & { lineId: string }>) => { reducer: (state, action: PayloadAction<BrushLineAddedArg & { lineId: string }>) => {
@ -166,7 +166,7 @@ export const layersReducers = {
clip, clip,
}); });
layer.bboxNeedsUpdate = true; layer.bboxNeedsUpdate = true;
state.layers.baseLayerImageCache = null; state.layers.imageCache = null;
}, },
prepare: (payload: BrushLineAddedArg) => ({ prepare: (payload: BrushLineAddedArg) => ({
payload: { ...payload, lineId: uuidv4() }, payload: { ...payload, lineId: uuidv4() },
@ -188,7 +188,7 @@ export const layersReducers = {
clip, clip,
}); });
layer.bboxNeedsUpdate = true; layer.bboxNeedsUpdate = true;
state.layers.baseLayerImageCache = null; state.layers.imageCache = null;
}, },
prepare: (payload: EraserLineAddedArg) => ({ prepare: (payload: EraserLineAddedArg) => ({
payload: { ...payload, lineId: uuidv4() }, payload: { ...payload, lineId: uuidv4() },
@ -206,7 +206,7 @@ export const layersReducers = {
} }
lastObject.points.push(...point); lastObject.points.push(...point);
layer.bboxNeedsUpdate = true; layer.bboxNeedsUpdate = true;
state.layers.baseLayerImageCache = null; state.layers.imageCache = null;
}, },
layerRectAdded: { layerRectAdded: {
reducer: (state, action: PayloadAction<RectShapeAddedArg & { rectId: string }>) => { reducer: (state, action: PayloadAction<RectShapeAddedArg & { rectId: string }>) => {
@ -226,7 +226,7 @@ export const layersReducers = {
color, color,
}); });
layer.bboxNeedsUpdate = true; layer.bboxNeedsUpdate = true;
state.layers.baseLayerImageCache = null; state.layers.imageCache = null;
}, },
prepare: (payload: RectShapeAddedArg) => ({ payload: { ...payload, rectId: uuidv4() } }), prepare: (payload: RectShapeAddedArg) => ({ payload: { ...payload, rectId: uuidv4() } }),
}, },
@ -239,11 +239,12 @@ export const layersReducers = {
} }
layer.objects.push(imageDTOToImageObject(id, objectId, imageDTO)); layer.objects.push(imageDTOToImageObject(id, objectId, imageDTO));
layer.bboxNeedsUpdate = true; layer.bboxNeedsUpdate = true;
state.layers.baseLayerImageCache = null; state.layers.imageCache = null;
}, },
prepare: (payload: ImageObjectAddedArg) => ({ payload: { ...payload, objectId: uuidv4() } }), prepare: (payload: ImageObjectAddedArg) => ({ payload: { ...payload, objectId: uuidv4() } }),
}, },
baseLayerImageCacheChanged: (state, action: PayloadAction<ImageDTO | null>) => { layerImageCacheChanged: (state, action: PayloadAction<{ imageDTO: ImageDTO | null }>) => {
state.layers.baseLayerImageCache = action.payload ? imageDTOToImageWithDims(action.payload) : null; const { imageDTO } = action.payload;
state.layers.imageCache = imageDTO ? imageDTOToImageWithDims(imageDTO) : null;
}, },
} satisfies SliceCaseReducers<CanvasV2State>; } satisfies SliceCaseReducers<CanvasV2State>;

View File

@ -1,11 +1,7 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit'; import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils'; import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils';
import { getBrushLineId, getEraserLineId, getRectShapeId } from 'features/controlLayers/konva/naming'; import { getBrushLineId, getEraserLineId, getRectShapeId } from 'features/controlLayers/konva/naming';
import type { import type { CanvasV2State, CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
CanvasV2State,
CLIPVisionModelV2,
IPMethodV2,
} from 'features/controlLayers/store/types';
import { imageDTOToImageObject, imageDTOToImageWithDims } from 'features/controlLayers/store/types'; import { imageDTOToImageObject, imageDTOToImageWithDims } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common'; import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas'; import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas';
@ -182,7 +178,7 @@ export const regionsReducers = {
} }
rg.fill = fill; rg.fill = fill;
}, },
rgMaskImageUploaded: (state, action: PayloadAction<{ id: string; imageDTO: ImageDTO }>) => { rgImageCacheChanged: (state, action: PayloadAction<{ id: string; imageDTO: ImageDTO }>) => {
const { id, imageDTO } = action.payload; const { id, imageDTO } = action.payload;
const rg = selectRG(state, id); const rg = selectRG(state, id);
if (!rg) { if (!rg) {

View File

@ -797,7 +797,7 @@ export type CanvasV2State = {
selectedEntityIdentifier: CanvasEntityIdentifier | null; selectedEntityIdentifier: CanvasEntityIdentifier | null;
inpaintMask: InpaintMaskEntity; inpaintMask: InpaintMaskEntity;
layers: { layers: {
baseLayerImageCache: ImageWithDims | null; imageCache: ImageWithDims | null;
entities: LayerEntity[]; entities: LayerEntity[];
}; };
controlAdapters: { entities: ControlAdapterEntity[] }; controlAdapters: { entities: ControlAdapterEntity[] };

View File

@ -1,96 +1,9 @@
import { getStore } from 'app/store/nanostores/store';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { $nodeManager } from 'features/controlLayers/konva/renderers/renderer';
import { blobToDataURL } from 'features/controlLayers/konva/util';
import { baseLayerImageCacheChanged } from 'features/controlLayers/store/canvasV2Slice';
import type { LayerEntity } from 'features/controlLayers/store/types'; import type { LayerEntity } from 'features/controlLayers/store/types';
import type Konva from 'konva';
import type { IRect } from 'konva/lib/types';
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
const isValidLayer = (entity: LayerEntity) => { export const isValidLayer = (entity: LayerEntity) => {
return ( return (
entity.isEnabled && entity.isEnabled &&
// Boolean(entity.bbox) && TODO(psyche): Re-enable this check when we have a way to calculate bbox for all layers // Boolean(entity.bbox) && TODO(psyche): Re-enable this check when we have a way to calculate bbox for all layers
entity.objects.length > 0 entity.objects.length > 0
); );
}; };
/**
* Get the blobs of all regional prompt layers. Only visible layers are returned.
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
* @param preview Whether to open a new tab displaying each layer.
* @returns A map of layer IDs to blobs.
*/
const getBaseLayer = async (layers: LayerEntity[], bbox: IRect, preview: boolean = false): Promise<Blob> => {
const manager = $nodeManager.get();
assert(manager, 'Node manager is null');
const stage = manager.stage.clone();
stage.scaleX(1);
stage.scaleY(1);
stage.x(0);
stage.y(0);
const validLayers = layers.filter(isValidLayer);
// Konva bug (?) - when iterating over the array returned from `stage.getLayers()`, if you destroy a layer, the array
// is mutated in-place and the next iteration will skip the next layer. To avoid this, we first collect the layers
// to delete in a separate array and then destroy them.
// TODO(psyche): Maybe report this?
const toDelete: Konva.Layer[] = [];
for (const konvaLayer of stage.getLayers()) {
const layer = validLayers.find((l) => l.id === konvaLayer.id());
if (!layer) {
toDelete.push(konvaLayer);
}
}
for (const konvaLayer of toDelete) {
konvaLayer.destroy();
}
const blob = await new Promise<Blob>((resolve) => {
stage.toBlob({
callback: (blob) => {
assert(blob, 'Blob is null');
resolve(blob);
},
...bbox,
});
});
if (preview) {
const base64 = await blobToDataURL(blob);
openBase64ImageInTab([{ base64, caption: 'base layer' }]);
}
stage.destroy();
return blob;
};
export const getBaseLayerImage = async (): Promise<ImageDTO> => {
const { dispatch, getState } = getStore();
const state = getState();
if (state.canvasV2.layers.baseLayerImageCache) {
const imageDTO = await getImageDTO(state.canvasV2.layers.baseLayerImageCache.name);
if (imageDTO) {
return imageDTO;
}
}
const blob = await getBaseLayer(state.canvasV2.layers.entities, state.canvasV2.bbox, true);
const file = new File([blob], 'image.png', { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'general', is_intermediate: true })
);
req.reset();
const imageDTO = await req.unwrap();
dispatch(baseLayerImageCacheChanged(imageDTO));
return imageDTO;
};

View File

@ -1,10 +1,5 @@
import { getStore } from 'app/store/nanostores/store';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab'; import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager';
import type { KonvaEntityAdapter } from 'features/controlLayers/konva/nodeManager';
import { $nodeManager } from 'features/controlLayers/konva/renderers/renderer';
import { blobToDataURL } from 'features/controlLayers/konva/util';
import { rgMaskImageUploaded } from 'features/controlLayers/store/canvasV2Slice';
import type { Dimensions, IPAdapterEntity, RegionEntity } from 'features/controlLayers/store/types'; import type { Dimensions, IPAdapterEntity, RegionEntity } from 'features/controlLayers/store/types';
import { import {
PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX, PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX,
@ -16,8 +11,7 @@ import {
import { addIPAdapterCollectorSafe, isValidIPAdapter } from 'features/nodes/util/graph/generation/addIPAdapters'; import { addIPAdapterCollectorSafe, isValidIPAdapter } from 'features/nodes/util/graph/generation/addIPAdapters';
import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { IRect } from 'konva/lib/types'; import type { IRect } from 'konva/lib/types';
import { getImageDTO, imagesApi } from 'services/api/endpoints/images'; import type { BaseModelType, Invocation } from 'services/api/types';
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
/** /**
@ -34,6 +28,7 @@ import { assert } from 'tsafe';
*/ */
export const addRegions = async ( export const addRegions = async (
manager: KonvaNodeManager,
regions: RegionEntity[], regions: RegionEntity[],
g: Graph, g: Graph,
documentSize: Dimensions, documentSize: Dimensions,
@ -51,7 +46,7 @@ export const addRegions = async (
for (const region of validRegions) { for (const region of validRegions) {
// Upload the mask image, or get the cached image if it exists // Upload the mask image, or get the cached image if it exists
const { image_name } = await getRegionMaskImage(region, bbox, true); const { image_name } = await manager.util.getRegionMaskImage({ id: region.id, bbox, preview: true });
// The main mask-to-tensor node // The main mask-to-tensor node
const maskToTensor = g.addNode({ const maskToTensor = g.addNode({
@ -217,90 +212,3 @@ export const isValidRegion = (rg: RegionEntity, base: BaseModelType) => {
const hasIPAdapter = rg.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base)).length > 0; const hasIPAdapter = rg.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base)).length > 0;
return hasTextPrompt || hasIPAdapter; return hasTextPrompt || hasIPAdapter;
}; };
export const getMaskImage = async (rg: RegionEntity, blob: Blob): Promise<ImageDTO> => {
const { id, imageCache } = rg;
if (imageCache) {
const imageDTO = await getImageDTO(imageCache.name);
if (imageDTO) {
return imageDTO;
}
}
const { dispatch } = getStore();
// No cached mask, or the cached image no longer exists - we need to upload the mask image
const file = new File([blob], `${rg.id}_mask.png`, { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
);
req.reset();
const imageDTO = await req.unwrap();
dispatch(rgMaskImageUploaded({ id, imageDTO }));
return imageDTO;
};
export const uploadMaskImage = async ({ id }: RegionEntity, blob: Blob): Promise<ImageDTO> => {
const { dispatch } = getStore();
// No cached mask, or the cached image no longer exists - we need to upload the mask image
const file = new File([blob], `${id}_mask.png`, { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
);
req.reset();
const imageDTO = await req.unwrap();
dispatch(rgMaskImageUploaded({ id, imageDTO }));
return imageDTO;
};
/**
* Get the blobs of all regional prompt layers. Only visible layers are returned.
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
* @param preview Whether to open a new tab displaying each layer.
* @returns A map of layer IDs to blobs.
*/
export const getRegionMaskImage = async (
region: RegionEntity,
bbox: IRect,
preview: boolean = false
): Promise<ImageDTO> => {
const manager = $nodeManager.get();
assert(manager, 'Node manager is null');
// TODO(psyche): Why do I need to annotate this? TS must have some kind of circular ref w/ this type but I can't figure it out...
const adapter: KonvaEntityAdapter | undefined = manager.get(region.id);
assert(adapter, `Adapter for region ${region.id} not found`);
if (region.imageCache) {
const imageDTO = await getImageDTO(region.imageCache.name);
if (imageDTO) {
return imageDTO;
}
}
const layer = adapter.konvaLayer.clone();
const objectGroup = adapter.konvaObjectGroup.clone();
layer.destroyChildren();
layer.add(objectGroup);
objectGroup.opacity(1);
objectGroup.cache();
const blob = await new Promise<Blob>((resolve) => {
layer.toBlob({
callback: (blob) => {
assert(blob, 'Blob is null');
resolve(blob);
},
...bbox,
});
});
if (preview) {
const base64 = await blobToDataURL(blob);
const caption = `${region.id}: ${region.positivePrompt} / ${region.negativePrompt}`;
openBase64ImageInTab([{ base64, caption }]);
}
layer.destroy();
return await uploadMaskImage(region, blob);
};

View File

@ -588,3 +588,16 @@ export const getImageDTO = async (image_name: string, forceRefetch?: boolean): P
return null; return null;
} }
}; };
export const uploadImage = async (
blob: Blob,
fileName: string,
image_category: ImageCategory,
is_intermediate: boolean
): Promise<ImageDTO> => {
const { dispatch } = getStore();
const file = new File([blob], fileName, { type: 'image/png' });
const req = dispatch(imagesApi.endpoints.uploadImage.initiate({ file, image_category, is_intermediate }));
req.reset();
return await req.unwrap();
};