mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): add utils for getting images from canvas
This commit is contained in:
parent
275fc2ccf9
commit
7dd11bd60a
@ -1,10 +1,12 @@
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { $nodeManager } from 'features/controlLayers/konva/renderers/renderer';
|
||||
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildGenerationTabGraph } from 'features/nodes/util/graph/generation/buildGenerationTabGraph';
|
||||
import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/generation/buildGenerationTabSDXLGraph';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
@ -18,10 +20,13 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
|
||||
let graph;
|
||||
|
||||
const manager = $nodeManager.get();
|
||||
assert(manager, 'Konva node manager not initialized');
|
||||
|
||||
if (model?.base === 'sdxl') {
|
||||
graph = await buildGenerationTabSDXLGraph(state);
|
||||
graph = await buildGenerationTabSDXLGraph(state, manager);
|
||||
} else {
|
||||
graph = await buildGenerationTabGraph(state);
|
||||
graph = await buildGenerationTabGraph(state, manager);
|
||||
}
|
||||
|
||||
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
|
||||
|
@ -1,3 +1,5 @@
|
||||
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||
import { blobToDataURL } from 'features/controlLayers/konva/util';
|
||||
import type {
|
||||
BrushLine,
|
||||
BrushLineAddedArg,
|
||||
@ -15,8 +17,11 @@ import type {
|
||||
StageAttrs,
|
||||
Tool,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { isValidLayer } from 'features/nodes/util/graph/generation/addLayers';
|
||||
import type Konva from 'konva';
|
||||
import type { Vector2d } from 'konva/lib/types';
|
||||
import { getImageDTO as defaultGetImageDTO, uploadImage as defaultUploadImage } from 'services/api/endpoints/images';
|
||||
import type { ImageCategory, ImageDTO } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export type BrushLineObjectRecord = {
|
||||
@ -132,24 +137,53 @@ type StateApi = {
|
||||
getMetaKey: () => boolean;
|
||||
getAltKey: () => boolean;
|
||||
getDocument: () => CanvasV2State['document'];
|
||||
getLayerEntityStates: () => CanvasV2State['layers']['entities'];
|
||||
getControlAdapterEntityStates: () => CanvasV2State['controlAdapters']['entities'];
|
||||
getRegionEntityStates: () => CanvasV2State['regions']['entities'];
|
||||
getInpaintMaskEntityState: () => CanvasV2State['inpaintMask'];
|
||||
getLayersState: () => CanvasV2State['layers'];
|
||||
getControlAdaptersState: () => CanvasV2State['controlAdapters'];
|
||||
getRegionsState: () => CanvasV2State['regions'];
|
||||
getInpaintMaskState: () => CanvasV2State['inpaintMask'];
|
||||
onInpaintMaskImageCached: (imageDTO: ImageDTO) => void;
|
||||
onRegionMaskImageCached: (id: string, imageDTO: ImageDTO) => void;
|
||||
onLayerImageCached: (imageDTO: ImageDTO) => void;
|
||||
};
|
||||
|
||||
type Util = {
|
||||
getImageDTO: (imageName: string) => Promise<ImageDTO | null>;
|
||||
uploadImage: (
|
||||
blob: Blob,
|
||||
fileName: string,
|
||||
image_category: ImageCategory,
|
||||
is_intermediate: boolean
|
||||
) => Promise<ImageDTO>;
|
||||
getRegionMaskImage: (arg: { id: string; bbox?: Rect; preview?: boolean }) => Promise<ImageDTO>;
|
||||
getInpaintMaskImage: (arg: { bbox?: Rect; preview?: boolean }) => Promise<ImageDTO>;
|
||||
getImageSourceImage: (arg: { bbox?: Rect; preview?: boolean }) => Promise<ImageDTO>;
|
||||
};
|
||||
|
||||
export class KonvaNodeManager {
|
||||
stage: Konva.Stage;
|
||||
container: HTMLDivElement;
|
||||
adapters: Map<string, KonvaEntityAdapter>;
|
||||
util: Util;
|
||||
_background: BackgroundLayer | null;
|
||||
_preview: PreviewLayer | null;
|
||||
_konvaApi: KonvaApi | null;
|
||||
_stateApi: StateApi | null;
|
||||
|
||||
constructor(stage: Konva.Stage, container: HTMLDivElement) {
|
||||
constructor(
|
||||
stage: Konva.Stage,
|
||||
container: HTMLDivElement,
|
||||
getImageDTO: Util['getImageDTO'] = defaultGetImageDTO,
|
||||
uploadImage: Util['uploadImage'] = defaultUploadImage
|
||||
) {
|
||||
this.stage = stage;
|
||||
this.container = container;
|
||||
this.util = {
|
||||
getImageDTO,
|
||||
uploadImage,
|
||||
getRegionMaskImage: this._getRegionMaskImage.bind(this),
|
||||
getInpaintMaskImage: this._getInpaintMaskImage.bind(this),
|
||||
getImageSourceImage: this._getImageSourceImage.bind(this),
|
||||
};
|
||||
this._konvaApi = null;
|
||||
this._preview = null;
|
||||
this._background = null;
|
||||
@ -219,6 +253,152 @@ export class KonvaNodeManager {
|
||||
assert(this._stateApi !== null, 'State API has not been set');
|
||||
return this._stateApi;
|
||||
}
|
||||
|
||||
async _getRegionMaskImage(arg: { id: string; bbox?: Rect; preview?: boolean }): Promise<ImageDTO> {
|
||||
const { id, bbox, preview = false } = arg;
|
||||
const region = this.stateApi.getRegionsState().entities.find((entity) => entity.id === id);
|
||||
assert(region, `Region entity state with id ${id} not found`);
|
||||
const adapter = this.get(region.id);
|
||||
assert(adapter, `Adapter for region ${region.id} not found`);
|
||||
|
||||
if (region.imageCache) {
|
||||
const imageDTO = await this.util.getImageDTO(region.imageCache.name);
|
||||
if (imageDTO) {
|
||||
return imageDTO;
|
||||
}
|
||||
}
|
||||
|
||||
const layer = adapter.konvaLayer.clone();
|
||||
const objectGroup = adapter.konvaObjectGroup.clone();
|
||||
layer.destroyChildren();
|
||||
layer.add(objectGroup);
|
||||
objectGroup.opacity(1);
|
||||
objectGroup.cache();
|
||||
|
||||
const blob = await new Promise<Blob>((resolve) => {
|
||||
layer.toBlob({
|
||||
callback: (blob) => {
|
||||
assert(blob, 'Blob is null');
|
||||
resolve(blob);
|
||||
},
|
||||
...bbox,
|
||||
});
|
||||
});
|
||||
|
||||
if (preview) {
|
||||
const base64 = await blobToDataURL(blob);
|
||||
const caption = `${region.id}: ${region.positivePrompt} / ${region.negativePrompt}`;
|
||||
openBase64ImageInTab([{ base64, caption }]);
|
||||
}
|
||||
|
||||
layer.destroy();
|
||||
|
||||
const imageDTO = await this.util.uploadImage(blob, `${region.id}_mask.png`, 'mask', true);
|
||||
this.stateApi.onRegionMaskImageCached(region.id, imageDTO);
|
||||
return imageDTO;
|
||||
}
|
||||
|
||||
async _getInpaintMaskImage(arg: { bbox?: Rect; preview?: boolean }): Promise<ImageDTO> {
|
||||
const { bbox, preview = false } = arg;
|
||||
const inpaintMask = this.stateApi.getInpaintMaskState();
|
||||
const adapter = this.get(inpaintMask.id);
|
||||
assert(adapter, `Adapter for ${inpaintMask.id} not found`);
|
||||
|
||||
if (inpaintMask.imageCache) {
|
||||
const imageDTO = await this.util.getImageDTO(inpaintMask.imageCache.name);
|
||||
if (imageDTO) {
|
||||
return imageDTO;
|
||||
}
|
||||
}
|
||||
|
||||
const layer = adapter.konvaLayer.clone();
|
||||
const objectGroup = adapter.konvaObjectGroup.clone();
|
||||
layer.destroyChildren();
|
||||
layer.add(objectGroup);
|
||||
objectGroup.opacity(1);
|
||||
objectGroup.cache();
|
||||
|
||||
const blob = await new Promise<Blob>((resolve) => {
|
||||
layer.toBlob({
|
||||
callback: (blob) => {
|
||||
assert(blob, 'Blob is null');
|
||||
resolve(blob);
|
||||
},
|
||||
...bbox,
|
||||
});
|
||||
});
|
||||
|
||||
if (preview) {
|
||||
const base64 = await blobToDataURL(blob);
|
||||
const caption = 'inpaint mask';
|
||||
openBase64ImageInTab([{ base64, caption }]);
|
||||
}
|
||||
|
||||
layer.destroy();
|
||||
|
||||
const imageDTO = await this.util.uploadImage(blob, 'inpaint_mask.png', 'mask', true);
|
||||
this.stateApi.onInpaintMaskImageCached(imageDTO);
|
||||
return imageDTO;
|
||||
}
|
||||
|
||||
async _getImageSourceImage(arg: { bbox?: Rect; preview?: boolean }): Promise<ImageDTO> {
|
||||
const { bbox, preview = false } = arg;
|
||||
const layersState = this.stateApi.getLayersState();
|
||||
const { entities, imageCache } = layersState;
|
||||
if (imageCache) {
|
||||
const imageDTO = await this.util.getImageDTO(imageCache.name);
|
||||
if (imageDTO) {
|
||||
return imageDTO;
|
||||
}
|
||||
}
|
||||
|
||||
const stage = this.stage.clone();
|
||||
|
||||
stage.scaleX(1);
|
||||
stage.scaleY(1);
|
||||
stage.x(0);
|
||||
stage.y(0);
|
||||
|
||||
const validLayers = entities.filter(isValidLayer);
|
||||
|
||||
// Konva bug (?) - when iterating over the array returned from `stage.getLayers()`, if you destroy a layer, the array
|
||||
// is mutated in-place and the next iteration will skip the next layer. To avoid this, we first collect the layers
|
||||
// to delete in a separate array and then destroy them.
|
||||
// TODO(psyche): Maybe report this?
|
||||
const toDelete: Konva.Layer[] = [];
|
||||
|
||||
for (const konvaLayer of stage.getLayers()) {
|
||||
const layer = validLayers.find((l) => l.id === konvaLayer.id());
|
||||
if (!layer) {
|
||||
toDelete.push(konvaLayer);
|
||||
}
|
||||
}
|
||||
|
||||
for (const konvaLayer of toDelete) {
|
||||
konvaLayer.destroy();
|
||||
}
|
||||
|
||||
const blob = await new Promise<Blob>((resolve) => {
|
||||
stage.toBlob({
|
||||
callback: (blob) => {
|
||||
assert(blob, 'Blob is null');
|
||||
resolve(blob);
|
||||
},
|
||||
...bbox,
|
||||
});
|
||||
});
|
||||
|
||||
if (preview) {
|
||||
const base64 = await blobToDataURL(blob);
|
||||
openBase64ImageInTab([{ base64, caption: 'base layer' }]);
|
||||
}
|
||||
|
||||
stage.destroy();
|
||||
|
||||
const imageDTO = await this.util.uploadImage(blob, 'base_layer.png', 'general', true);
|
||||
this.stateApi.onLayerImageCached(imageDTO);
|
||||
return imageDTO;
|
||||
}
|
||||
}
|
||||
|
||||
export class KonvaEntityAdapter {
|
||||
|
@ -6,12 +6,12 @@ import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager'
|
||||
* @returns An arrange entities function
|
||||
*/
|
||||
export const getArrangeEntities = (manager: KonvaNodeManager) => {
|
||||
const { getLayerEntityStates, getControlAdapterEntityStates, getRegionEntityStates } = manager.stateApi;
|
||||
const { getLayersState, getControlAdaptersState, getRegionsState } = manager.stateApi;
|
||||
|
||||
function arrangeEntities(): void {
|
||||
const layers = getLayerEntityStates();
|
||||
const controlAdapters = getControlAdapterEntityStates();
|
||||
const regions = getRegionEntityStates();
|
||||
const layers = getLayersState().entities;
|
||||
const controlAdapters = getControlAdaptersState().entities;
|
||||
const regions = getRegionsState().entities;
|
||||
let zIndex = 0;
|
||||
manager.background.layer.zIndex(++zIndex);
|
||||
for (const layer of layers) {
|
||||
|
@ -100,10 +100,10 @@ export const renderControlAdapter = async (manager: KonvaNodeManager, entity: Co
|
||||
* @returns A function to render all control adapters
|
||||
*/
|
||||
export const getRenderControlAdapters = (manager: KonvaNodeManager) => {
|
||||
const { getControlAdapterEntityStates } = manager.stateApi;
|
||||
const { getControlAdaptersState } = manager.stateApi;
|
||||
|
||||
function renderControlAdapters(): void {
|
||||
const entities = getControlAdapterEntityStates();
|
||||
const { entities } = getControlAdaptersState();
|
||||
// Destroy nonexistent layers
|
||||
for (const adapters of manager.getAll('control_adapter')) {
|
||||
if (!entities.find((ca) => ca.id === adapters.id)) {
|
||||
|
@ -71,10 +71,10 @@ const getInpaintMask = (
|
||||
* @returns A function to render the inpaint mask
|
||||
*/
|
||||
export const getRenderInpaintMask = (manager: KonvaNodeManager) => {
|
||||
const { getInpaintMaskEntityState, getMaskOpacity, getToolState, getSelectedEntity, onPosChanged } = manager.stateApi;
|
||||
const { getInpaintMaskState, getMaskOpacity, getToolState, getSelectedEntity, onPosChanged } = manager.stateApi;
|
||||
|
||||
function renderInpaintMask(): void {
|
||||
const entity = getInpaintMaskEntityState();
|
||||
const entity = getInpaintMaskState();
|
||||
const globalMaskLayerOpacity = getMaskOpacity();
|
||||
const toolState = getToolState();
|
||||
const selectedEntity = getSelectedEntity();
|
||||
|
@ -136,10 +136,10 @@ export const renderLayer = async (
|
||||
* @returns A function to render all layers
|
||||
*/
|
||||
export const getRenderLayers = (manager: KonvaNodeManager) => {
|
||||
const { getLayerEntityStates, getToolState, onPosChanged } = manager.stateApi;
|
||||
const { getLayersState, getToolState, onPosChanged } = manager.stateApi;
|
||||
|
||||
function renderLayers(): void {
|
||||
const entities = getLayerEntityStates();
|
||||
const { entities } = getLayersState();
|
||||
const tool = getToolState();
|
||||
// Destroy nonexistent layers
|
||||
for (const adapter of manager.getAll('layer')) {
|
||||
|
@ -233,10 +233,10 @@ export const renderRegion = (
|
||||
* @returns A function to render all regions
|
||||
*/
|
||||
export const getRenderRegions = (manager: KonvaNodeManager) => {
|
||||
const { getRegionEntityStates, getMaskOpacity, getToolState, getSelectedEntity, onPosChanged } = manager.stateApi;
|
||||
const { getRegionsState, getMaskOpacity, getToolState, getSelectedEntity, onPosChanged } = manager.stateApi;
|
||||
|
||||
function renderRegions(): void {
|
||||
const entities = getRegionEntityStates();
|
||||
const { entities } = getRegionsState();
|
||||
const maskOpacity = getMaskOpacity();
|
||||
const toolState = getToolState();
|
||||
const selectedEntity = getSelectedEntity();
|
||||
|
@ -32,17 +32,20 @@ import {
|
||||
imBboxChanged,
|
||||
imBrushLineAdded,
|
||||
imEraserLineAdded,
|
||||
imImageCacheChanged,
|
||||
imLinePointAdded,
|
||||
imTranslated,
|
||||
layerBboxChanged,
|
||||
layerBrushLineAdded,
|
||||
layerEraserLineAdded,
|
||||
layerImageCacheChanged,
|
||||
layerLinePointAdded,
|
||||
layerRectAdded,
|
||||
layerTranslated,
|
||||
rgBboxChanged,
|
||||
rgBrushLineAdded,
|
||||
rgEraserLineAdded,
|
||||
rgImageCacheChanged,
|
||||
rgLinePointAdded,
|
||||
rgRectAdded,
|
||||
rgTranslated,
|
||||
@ -65,6 +68,7 @@ import type { IRect, Vector2d } from 'konva/lib/types';
|
||||
import { debounce } from 'lodash-es';
|
||||
import { atom } from 'nanostores';
|
||||
import type { RgbaColor } from 'react-colorful';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
export const $nodeManager = atom<KonvaNodeManager | null>(null);
|
||||
|
||||
@ -175,6 +179,19 @@ export const initializeRenderer = (
|
||||
logIfDebugging('Eraser width changed');
|
||||
dispatch(eraserWidthChanged(width));
|
||||
};
|
||||
const onRegionMaskImageCached = (id: string, imageDTO: ImageDTO) => {
|
||||
logIfDebugging('Region mask image cached');
|
||||
dispatch(rgImageCacheChanged({ id, imageDTO }));
|
||||
};
|
||||
const onInpaintMaskImageCached = (imageDTO: ImageDTO) => {
|
||||
logIfDebugging('Inpaint mask image cached');
|
||||
dispatch(imImageCacheChanged({ imageDTO }));
|
||||
};
|
||||
const onLayerImageCached = (imageDTO: ImageDTO) => {
|
||||
logIfDebugging('Layer image cached');
|
||||
dispatch(layerImageCacheChanged({ imageDTO }));
|
||||
};
|
||||
|
||||
const setTool = (tool: Tool) => {
|
||||
logIfDebugging('Tool selection changed');
|
||||
dispatch(toolChanged(tool));
|
||||
@ -240,11 +257,11 @@ export const initializeRenderer = (
|
||||
const getDocument = () => canvasV2.document;
|
||||
const getToolState = () => canvasV2.tool;
|
||||
const getSettings = () => canvasV2.settings;
|
||||
const getRegionEntityStates = () => canvasV2.regions.entities;
|
||||
const getLayerEntityStates = () => canvasV2.layers.entities;
|
||||
const getControlAdapterEntityStates = () => canvasV2.controlAdapters.entities;
|
||||
const getRegionsState = () => canvasV2.regions;
|
||||
const getLayersState = () => canvasV2.layers;
|
||||
const getControlAdaptersState = () => canvasV2.controlAdapters;
|
||||
const getInpaintMaskState = () => canvasV2.inpaintMask;
|
||||
const getMaskOpacity = () => canvasV2.settings.maskOpacity;
|
||||
const getInpaintMaskEntityState = () => canvasV2.inpaintMask;
|
||||
|
||||
// Read-write state, ephemeral interaction state
|
||||
let isDrawing = false;
|
||||
@ -309,12 +326,12 @@ export const initializeRenderer = (
|
||||
getCtrlKey: $ctrl.get,
|
||||
getMetaKey: $meta.get,
|
||||
getShiftKey: $shift.get,
|
||||
getControlAdapterEntityStates,
|
||||
getControlAdaptersState,
|
||||
getDocument,
|
||||
getLayerEntityStates,
|
||||
getRegionEntityStates,
|
||||
getLayersState,
|
||||
getRegionsState,
|
||||
getMaskOpacity,
|
||||
getInpaintMaskEntityState,
|
||||
getInpaintMaskState,
|
||||
|
||||
// Read-write state
|
||||
setTool,
|
||||
@ -342,6 +359,9 @@ export const initializeRenderer = (
|
||||
onEraserWidthChanged,
|
||||
onPosChanged,
|
||||
onBboxTransformed,
|
||||
onRegionMaskImageCached,
|
||||
onInpaintMaskImageCached,
|
||||
onLayerImageCached,
|
||||
};
|
||||
|
||||
const cleanupListeners = setStageEventHandlers(manager);
|
||||
|
@ -24,7 +24,7 @@ import { DEFAULT_RGBA_COLOR } from './types';
|
||||
const initialState: CanvasV2State = {
|
||||
_version: 3,
|
||||
selectedEntityIdentifier: { type: 'inpaint_mask', id: 'inpaint_mask' },
|
||||
layers: { entities: [], baseLayerImageCache: null },
|
||||
layers: { entities: [], imageCache: null },
|
||||
controlAdapters: { entities: [] },
|
||||
ipAdapters: { entities: [] },
|
||||
regions: { entities: [] },
|
||||
@ -161,7 +161,7 @@ export const canvasV2Slice = createSlice({
|
||||
allEntitiesDeleted: (state) => {
|
||||
state.regions.entities = [];
|
||||
state.layers.entities = [];
|
||||
state.layers.baseLayerImageCache = null;
|
||||
state.layers.imageCache = null;
|
||||
state.ipAdapters.entities = [];
|
||||
state.controlAdapters.entities = [];
|
||||
},
|
||||
@ -185,7 +185,6 @@ export const {
|
||||
scaledBboxChanged,
|
||||
bboxScaleMethodChanged,
|
||||
clipToBboxChanged,
|
||||
baseLayerImageCacheChanged,
|
||||
// layers
|
||||
layerAdded,
|
||||
layerRecalled,
|
||||
@ -205,6 +204,7 @@ export const {
|
||||
layerRectAdded,
|
||||
layerImageAdded,
|
||||
layerAllDeleted,
|
||||
layerImageCacheChanged,
|
||||
// IP Adapters
|
||||
ipaAdded,
|
||||
ipaRecalled,
|
||||
@ -255,7 +255,7 @@ export const {
|
||||
rgPositivePromptChanged,
|
||||
rgNegativePromptChanged,
|
||||
rgFillChanged,
|
||||
rgMaskImageUploaded,
|
||||
rgImageCacheChanged,
|
||||
rgAutoNegativeChanged,
|
||||
rgIPAdapterAdded,
|
||||
rgIPAdapterDeleted,
|
||||
|
@ -40,7 +40,7 @@ export const layersReducers = {
|
||||
y: 0,
|
||||
});
|
||||
state.selectedEntityIdentifier = { type: 'layer', id };
|
||||
state.layers.baseLayerImageCache = null;
|
||||
state.layers.imageCache = null;
|
||||
},
|
||||
prepare: () => ({ payload: { id: uuidv4() } }),
|
||||
},
|
||||
@ -48,7 +48,7 @@ export const layersReducers = {
|
||||
const { data } = action.payload;
|
||||
state.layers.entities.push(data);
|
||||
state.selectedEntityIdentifier = { type: 'layer', id: data.id };
|
||||
state.layers.baseLayerImageCache = null;
|
||||
state.layers.imageCache = null;
|
||||
},
|
||||
layerIsEnabledToggled: (state, action: PayloadAction<{ id: string }>) => {
|
||||
const { id } = action.payload;
|
||||
@ -57,7 +57,7 @@ export const layersReducers = {
|
||||
return;
|
||||
}
|
||||
layer.isEnabled = !layer.isEnabled;
|
||||
state.layers.baseLayerImageCache = null;
|
||||
state.layers.imageCache = null;
|
||||
},
|
||||
layerTranslated: (state, action: PayloadAction<{ id: string; x: number; y: number }>) => {
|
||||
const { id, x, y } = action.payload;
|
||||
@ -67,7 +67,7 @@ export const layersReducers = {
|
||||
}
|
||||
layer.x = x;
|
||||
layer.y = y;
|
||||
state.layers.baseLayerImageCache = null;
|
||||
state.layers.imageCache = null;
|
||||
},
|
||||
layerBboxChanged: (state, action: PayloadAction<{ id: string; bbox: IRect | null }>) => {
|
||||
const { id, bbox } = action.payload;
|
||||
@ -93,16 +93,16 @@ export const layersReducers = {
|
||||
layer.objects = [];
|
||||
layer.bbox = null;
|
||||
layer.bboxNeedsUpdate = false;
|
||||
state.layers.baseLayerImageCache = null;
|
||||
state.layers.imageCache = null;
|
||||
},
|
||||
layerDeleted: (state, action: PayloadAction<{ id: string }>) => {
|
||||
const { id } = action.payload;
|
||||
state.layers.entities = state.layers.entities.filter((l) => l.id !== id);
|
||||
state.layers.baseLayerImageCache = null;
|
||||
state.layers.imageCache = null;
|
||||
},
|
||||
layerAllDeleted: (state) => {
|
||||
state.layers.entities = [];
|
||||
state.layers.baseLayerImageCache = null;
|
||||
state.layers.imageCache = null;
|
||||
},
|
||||
layerOpacityChanged: (state, action: PayloadAction<{ id: string; opacity: number }>) => {
|
||||
const { id, opacity } = action.payload;
|
||||
@ -111,7 +111,7 @@ export const layersReducers = {
|
||||
return;
|
||||
}
|
||||
layer.opacity = opacity;
|
||||
state.layers.baseLayerImageCache = null;
|
||||
state.layers.imageCache = null;
|
||||
},
|
||||
layerMovedForwardOne: (state, action: PayloadAction<{ id: string }>) => {
|
||||
const { id } = action.payload;
|
||||
@ -120,7 +120,7 @@ export const layersReducers = {
|
||||
return;
|
||||
}
|
||||
moveOneToEnd(state.layers.entities, layer);
|
||||
state.layers.baseLayerImageCache = null;
|
||||
state.layers.imageCache = null;
|
||||
},
|
||||
layerMovedToFront: (state, action: PayloadAction<{ id: string }>) => {
|
||||
const { id } = action.payload;
|
||||
@ -129,7 +129,7 @@ export const layersReducers = {
|
||||
return;
|
||||
}
|
||||
moveToEnd(state.layers.entities, layer);
|
||||
state.layers.baseLayerImageCache = null;
|
||||
state.layers.imageCache = null;
|
||||
},
|
||||
layerMovedBackwardOne: (state, action: PayloadAction<{ id: string }>) => {
|
||||
const { id } = action.payload;
|
||||
@ -138,7 +138,7 @@ export const layersReducers = {
|
||||
return;
|
||||
}
|
||||
moveOneToStart(state.layers.entities, layer);
|
||||
state.layers.baseLayerImageCache = null;
|
||||
state.layers.imageCache = null;
|
||||
},
|
||||
layerMovedToBack: (state, action: PayloadAction<{ id: string }>) => {
|
||||
const { id } = action.payload;
|
||||
@ -147,7 +147,7 @@ export const layersReducers = {
|
||||
return;
|
||||
}
|
||||
moveToStart(state.layers.entities, layer);
|
||||
state.layers.baseLayerImageCache = null;
|
||||
state.layers.imageCache = null;
|
||||
},
|
||||
layerBrushLineAdded: {
|
||||
reducer: (state, action: PayloadAction<BrushLineAddedArg & { lineId: string }>) => {
|
||||
@ -166,7 +166,7 @@ export const layersReducers = {
|
||||
clip,
|
||||
});
|
||||
layer.bboxNeedsUpdate = true;
|
||||
state.layers.baseLayerImageCache = null;
|
||||
state.layers.imageCache = null;
|
||||
},
|
||||
prepare: (payload: BrushLineAddedArg) => ({
|
||||
payload: { ...payload, lineId: uuidv4() },
|
||||
@ -188,7 +188,7 @@ export const layersReducers = {
|
||||
clip,
|
||||
});
|
||||
layer.bboxNeedsUpdate = true;
|
||||
state.layers.baseLayerImageCache = null;
|
||||
state.layers.imageCache = null;
|
||||
},
|
||||
prepare: (payload: EraserLineAddedArg) => ({
|
||||
payload: { ...payload, lineId: uuidv4() },
|
||||
@ -206,7 +206,7 @@ export const layersReducers = {
|
||||
}
|
||||
lastObject.points.push(...point);
|
||||
layer.bboxNeedsUpdate = true;
|
||||
state.layers.baseLayerImageCache = null;
|
||||
state.layers.imageCache = null;
|
||||
},
|
||||
layerRectAdded: {
|
||||
reducer: (state, action: PayloadAction<RectShapeAddedArg & { rectId: string }>) => {
|
||||
@ -226,7 +226,7 @@ export const layersReducers = {
|
||||
color,
|
||||
});
|
||||
layer.bboxNeedsUpdate = true;
|
||||
state.layers.baseLayerImageCache = null;
|
||||
state.layers.imageCache = null;
|
||||
},
|
||||
prepare: (payload: RectShapeAddedArg) => ({ payload: { ...payload, rectId: uuidv4() } }),
|
||||
},
|
||||
@ -239,11 +239,12 @@ export const layersReducers = {
|
||||
}
|
||||
layer.objects.push(imageDTOToImageObject(id, objectId, imageDTO));
|
||||
layer.bboxNeedsUpdate = true;
|
||||
state.layers.baseLayerImageCache = null;
|
||||
state.layers.imageCache = null;
|
||||
},
|
||||
prepare: (payload: ImageObjectAddedArg) => ({ payload: { ...payload, objectId: uuidv4() } }),
|
||||
},
|
||||
baseLayerImageCacheChanged: (state, action: PayloadAction<ImageDTO | null>) => {
|
||||
state.layers.baseLayerImageCache = action.payload ? imageDTOToImageWithDims(action.payload) : null;
|
||||
layerImageCacheChanged: (state, action: PayloadAction<{ imageDTO: ImageDTO | null }>) => {
|
||||
const { imageDTO } = action.payload;
|
||||
state.layers.imageCache = imageDTO ? imageDTOToImageWithDims(imageDTO) : null;
|
||||
},
|
||||
} satisfies SliceCaseReducers<CanvasV2State>;
|
||||
|
@ -1,11 +1,7 @@
|
||||
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
|
||||
import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils';
|
||||
import { getBrushLineId, getEraserLineId, getRectShapeId } from 'features/controlLayers/konva/naming';
|
||||
import type {
|
||||
CanvasV2State,
|
||||
CLIPVisionModelV2,
|
||||
IPMethodV2,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import type { CanvasV2State, CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject, imageDTOToImageWithDims } from 'features/controlLayers/store/types';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas';
|
||||
@ -182,7 +178,7 @@ export const regionsReducers = {
|
||||
}
|
||||
rg.fill = fill;
|
||||
},
|
||||
rgMaskImageUploaded: (state, action: PayloadAction<{ id: string; imageDTO: ImageDTO }>) => {
|
||||
rgImageCacheChanged: (state, action: PayloadAction<{ id: string; imageDTO: ImageDTO }>) => {
|
||||
const { id, imageDTO } = action.payload;
|
||||
const rg = selectRG(state, id);
|
||||
if (!rg) {
|
||||
|
@ -797,7 +797,7 @@ export type CanvasV2State = {
|
||||
selectedEntityIdentifier: CanvasEntityIdentifier | null;
|
||||
inpaintMask: InpaintMaskEntity;
|
||||
layers: {
|
||||
baseLayerImageCache: ImageWithDims | null;
|
||||
imageCache: ImageWithDims | null;
|
||||
entities: LayerEntity[];
|
||||
};
|
||||
controlAdapters: { entities: ControlAdapterEntity[] };
|
||||
|
@ -1,96 +1,9 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||
import { $nodeManager } from 'features/controlLayers/konva/renderers/renderer';
|
||||
import { blobToDataURL } from 'features/controlLayers/konva/util';
|
||||
import { baseLayerImageCacheChanged } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import type { LayerEntity } from 'features/controlLayers/store/types';
|
||||
import type Konva from 'konva';
|
||||
import type { IRect } from 'konva/lib/types';
|
||||
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const isValidLayer = (entity: LayerEntity) => {
|
||||
export const isValidLayer = (entity: LayerEntity) => {
|
||||
return (
|
||||
entity.isEnabled &&
|
||||
// Boolean(entity.bbox) && TODO(psyche): Re-enable this check when we have a way to calculate bbox for all layers
|
||||
entity.objects.length > 0
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Get the blobs of all regional prompt layers. Only visible layers are returned.
|
||||
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
|
||||
* @param preview Whether to open a new tab displaying each layer.
|
||||
* @returns A map of layer IDs to blobs.
|
||||
*/
|
||||
|
||||
const getBaseLayer = async (layers: LayerEntity[], bbox: IRect, preview: boolean = false): Promise<Blob> => {
|
||||
const manager = $nodeManager.get();
|
||||
assert(manager, 'Node manager is null');
|
||||
|
||||
const stage = manager.stage.clone();
|
||||
|
||||
stage.scaleX(1);
|
||||
stage.scaleY(1);
|
||||
stage.x(0);
|
||||
stage.y(0);
|
||||
|
||||
const validLayers = layers.filter(isValidLayer);
|
||||
|
||||
// Konva bug (?) - when iterating over the array returned from `stage.getLayers()`, if you destroy a layer, the array
|
||||
// is mutated in-place and the next iteration will skip the next layer. To avoid this, we first collect the layers
|
||||
// to delete in a separate array and then destroy them.
|
||||
// TODO(psyche): Maybe report this?
|
||||
const toDelete: Konva.Layer[] = [];
|
||||
|
||||
for (const konvaLayer of stage.getLayers()) {
|
||||
const layer = validLayers.find((l) => l.id === konvaLayer.id());
|
||||
if (!layer) {
|
||||
toDelete.push(konvaLayer);
|
||||
}
|
||||
}
|
||||
|
||||
for (const konvaLayer of toDelete) {
|
||||
konvaLayer.destroy();
|
||||
}
|
||||
|
||||
const blob = await new Promise<Blob>((resolve) => {
|
||||
stage.toBlob({
|
||||
callback: (blob) => {
|
||||
assert(blob, 'Blob is null');
|
||||
resolve(blob);
|
||||
},
|
||||
...bbox,
|
||||
});
|
||||
});
|
||||
|
||||
if (preview) {
|
||||
const base64 = await blobToDataURL(blob);
|
||||
openBase64ImageInTab([{ base64, caption: 'base layer' }]);
|
||||
}
|
||||
|
||||
stage.destroy();
|
||||
|
||||
return blob;
|
||||
};
|
||||
|
||||
export const getBaseLayerImage = async (): Promise<ImageDTO> => {
|
||||
const { dispatch, getState } = getStore();
|
||||
const state = getState();
|
||||
if (state.canvasV2.layers.baseLayerImageCache) {
|
||||
const imageDTO = await getImageDTO(state.canvasV2.layers.baseLayerImageCache.name);
|
||||
if (imageDTO) {
|
||||
return imageDTO;
|
||||
}
|
||||
}
|
||||
const blob = await getBaseLayer(state.canvasV2.layers.entities, state.canvasV2.bbox, true);
|
||||
const file = new File([blob], 'image.png', { type: 'image/png' });
|
||||
const req = dispatch(
|
||||
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'general', is_intermediate: true })
|
||||
);
|
||||
req.reset();
|
||||
const imageDTO = await req.unwrap();
|
||||
dispatch(baseLayerImageCacheChanged(imageDTO));
|
||||
return imageDTO;
|
||||
};
|
||||
|
@ -1,10 +1,5 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||
import type { KonvaEntityAdapter } from 'features/controlLayers/konva/nodeManager';
|
||||
import { $nodeManager } from 'features/controlLayers/konva/renderers/renderer';
|
||||
import { blobToDataURL } from 'features/controlLayers/konva/util';
|
||||
import { rgMaskImageUploaded } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager';
|
||||
import type { Dimensions, IPAdapterEntity, RegionEntity } from 'features/controlLayers/store/types';
|
||||
import {
|
||||
PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX,
|
||||
@ -16,8 +11,7 @@ import {
|
||||
import { addIPAdapterCollectorSafe, isValidIPAdapter } from 'features/nodes/util/graph/generation/addIPAdapters';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import type { IRect } from 'konva/lib/types';
|
||||
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
|
||||
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
|
||||
import type { BaseModelType, Invocation } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
/**
|
||||
@ -34,6 +28,7 @@ import { assert } from 'tsafe';
|
||||
*/
|
||||
|
||||
export const addRegions = async (
|
||||
manager: KonvaNodeManager,
|
||||
regions: RegionEntity[],
|
||||
g: Graph,
|
||||
documentSize: Dimensions,
|
||||
@ -51,7 +46,7 @@ export const addRegions = async (
|
||||
|
||||
for (const region of validRegions) {
|
||||
// Upload the mask image, or get the cached image if it exists
|
||||
const { image_name } = await getRegionMaskImage(region, bbox, true);
|
||||
const { image_name } = await manager.util.getRegionMaskImage({ id: region.id, bbox, preview: true });
|
||||
|
||||
// The main mask-to-tensor node
|
||||
const maskToTensor = g.addNode({
|
||||
@ -217,90 +212,3 @@ export const isValidRegion = (rg: RegionEntity, base: BaseModelType) => {
|
||||
const hasIPAdapter = rg.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base)).length > 0;
|
||||
return hasTextPrompt || hasIPAdapter;
|
||||
};
|
||||
|
||||
export const getMaskImage = async (rg: RegionEntity, blob: Blob): Promise<ImageDTO> => {
|
||||
const { id, imageCache } = rg;
|
||||
if (imageCache) {
|
||||
const imageDTO = await getImageDTO(imageCache.name);
|
||||
if (imageDTO) {
|
||||
return imageDTO;
|
||||
}
|
||||
}
|
||||
const { dispatch } = getStore();
|
||||
// No cached mask, or the cached image no longer exists - we need to upload the mask image
|
||||
const file = new File([blob], `${rg.id}_mask.png`, { type: 'image/png' });
|
||||
const req = dispatch(
|
||||
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
|
||||
);
|
||||
req.reset();
|
||||
|
||||
const imageDTO = await req.unwrap();
|
||||
dispatch(rgMaskImageUploaded({ id, imageDTO }));
|
||||
return imageDTO;
|
||||
};
|
||||
|
||||
export const uploadMaskImage = async ({ id }: RegionEntity, blob: Blob): Promise<ImageDTO> => {
|
||||
const { dispatch } = getStore();
|
||||
// No cached mask, or the cached image no longer exists - we need to upload the mask image
|
||||
const file = new File([blob], `${id}_mask.png`, { type: 'image/png' });
|
||||
const req = dispatch(
|
||||
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
|
||||
);
|
||||
req.reset();
|
||||
|
||||
const imageDTO = await req.unwrap();
|
||||
dispatch(rgMaskImageUploaded({ id, imageDTO }));
|
||||
return imageDTO;
|
||||
};
|
||||
|
||||
/**
|
||||
* Get the blobs of all regional prompt layers. Only visible layers are returned.
|
||||
* @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used.
|
||||
* @param preview Whether to open a new tab displaying each layer.
|
||||
* @returns A map of layer IDs to blobs.
|
||||
*/
|
||||
|
||||
export const getRegionMaskImage = async (
|
||||
region: RegionEntity,
|
||||
bbox: IRect,
|
||||
preview: boolean = false
|
||||
): Promise<ImageDTO> => {
|
||||
const manager = $nodeManager.get();
|
||||
assert(manager, 'Node manager is null');
|
||||
|
||||
// TODO(psyche): Why do I need to annotate this? TS must have some kind of circular ref w/ this type but I can't figure it out...
|
||||
const adapter: KonvaEntityAdapter | undefined = manager.get(region.id);
|
||||
assert(adapter, `Adapter for region ${region.id} not found`);
|
||||
if (region.imageCache) {
|
||||
const imageDTO = await getImageDTO(region.imageCache.name);
|
||||
if (imageDTO) {
|
||||
return imageDTO;
|
||||
}
|
||||
}
|
||||
const layer = adapter.konvaLayer.clone();
|
||||
const objectGroup = adapter.konvaObjectGroup.clone();
|
||||
layer.destroyChildren();
|
||||
layer.add(objectGroup);
|
||||
objectGroup.opacity(1);
|
||||
objectGroup.cache();
|
||||
|
||||
const blob = await new Promise<Blob>((resolve) => {
|
||||
layer.toBlob({
|
||||
callback: (blob) => {
|
||||
assert(blob, 'Blob is null');
|
||||
resolve(blob);
|
||||
},
|
||||
...bbox,
|
||||
});
|
||||
});
|
||||
|
||||
if (preview) {
|
||||
const base64 = await blobToDataURL(blob);
|
||||
const caption = `${region.id}: ${region.positivePrompt} / ${region.negativePrompt}`;
|
||||
openBase64ImageInTab([{ base64, caption }]);
|
||||
}
|
||||
|
||||
layer.destroy();
|
||||
|
||||
return await uploadMaskImage(region, blob);
|
||||
};
|
||||
|
@ -588,3 +588,16 @@ export const getImageDTO = async (image_name: string, forceRefetch?: boolean): P
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
export const uploadImage = async (
|
||||
blob: Blob,
|
||||
fileName: string,
|
||||
image_category: ImageCategory,
|
||||
is_intermediate: boolean
|
||||
): Promise<ImageDTO> => {
|
||||
const { dispatch } = getStore();
|
||||
const file = new File([blob], fileName, { type: 'image/png' });
|
||||
const req = dispatch(imagesApi.endpoints.uploadImage.initiate({ file, image_category, is_intermediate }));
|
||||
req.reset();
|
||||
return await req.unwrap();
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user