diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts index 6da0b82dc3..b76d8acf85 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts @@ -1,10 +1,12 @@ import { enqueueRequested } from 'app/store/actions'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; +import { $nodeManager } from 'features/controlLayers/konva/renderers/renderer'; import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice'; import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; import { buildGenerationTabGraph } from 'features/nodes/util/graph/generation/buildGenerationTabGraph'; import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/generation/buildGenerationTabSDXLGraph'; import { queueApi } from 'services/api/endpoints/queue'; +import { assert } from 'tsafe'; export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => { startAppListening({ @@ -18,10 +20,13 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) let graph; + const manager = $nodeManager.get(); + assert(manager, 'Konva node manager not initialized'); + if (model?.base === 'sdxl') { - graph = await buildGenerationTabSDXLGraph(state); + graph = await buildGenerationTabSDXLGraph(state, manager); } else { - graph = await buildGenerationTabGraph(state); + graph = await buildGenerationTabGraph(state, manager); } const batchConfig = prepareLinearUIBatch(state, graph, prepend); diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/nodeManager.ts b/invokeai/frontend/web/src/features/controlLayers/konva/nodeManager.ts index c093f87ee0..19e095703c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/nodeManager.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/nodeManager.ts @@ -1,3 +1,5 @@ +import openBase64ImageInTab from 'common/util/openBase64ImageInTab'; +import { blobToDataURL } from 'features/controlLayers/konva/util'; import type { BrushLine, BrushLineAddedArg, @@ -15,8 +17,11 @@ import type { StageAttrs, Tool, } from 'features/controlLayers/store/types'; +import { isValidLayer } from 'features/nodes/util/graph/generation/addLayers'; import type Konva from 'konva'; import type { Vector2d } from 'konva/lib/types'; +import { getImageDTO as defaultGetImageDTO, uploadImage as defaultUploadImage } from 'services/api/endpoints/images'; +import type { ImageCategory, ImageDTO } from 'services/api/types'; import { assert } from 'tsafe'; export type BrushLineObjectRecord = { @@ -132,24 +137,53 @@ type StateApi = { getMetaKey: () => boolean; getAltKey: () => boolean; getDocument: () => CanvasV2State['document']; - getLayerEntityStates: () => CanvasV2State['layers']['entities']; - getControlAdapterEntityStates: () => CanvasV2State['controlAdapters']['entities']; - getRegionEntityStates: () => CanvasV2State['regions']['entities']; - getInpaintMaskEntityState: () => CanvasV2State['inpaintMask']; + getLayersState: () => CanvasV2State['layers']; + getControlAdaptersState: () => CanvasV2State['controlAdapters']; + getRegionsState: () => CanvasV2State['regions']; + getInpaintMaskState: () => CanvasV2State['inpaintMask']; + onInpaintMaskImageCached: (imageDTO: ImageDTO) => void; + onRegionMaskImageCached: (id: string, imageDTO: ImageDTO) => void; + onLayerImageCached: (imageDTO: ImageDTO) => void; +}; + +type Util = { + getImageDTO: (imageName: string) => Promise; + uploadImage: ( + blob: Blob, + fileName: string, + image_category: ImageCategory, + is_intermediate: boolean + ) => Promise; + getRegionMaskImage: (arg: { id: string; bbox?: Rect; preview?: boolean }) => Promise; + getInpaintMaskImage: (arg: { bbox?: Rect; preview?: boolean }) => Promise; + getImageSourceImage: (arg: { bbox?: Rect; preview?: boolean }) => Promise; }; export class KonvaNodeManager { stage: Konva.Stage; container: HTMLDivElement; adapters: Map; + util: Util; _background: BackgroundLayer | null; _preview: PreviewLayer | null; _konvaApi: KonvaApi | null; _stateApi: StateApi | null; - constructor(stage: Konva.Stage, container: HTMLDivElement) { + constructor( + stage: Konva.Stage, + container: HTMLDivElement, + getImageDTO: Util['getImageDTO'] = defaultGetImageDTO, + uploadImage: Util['uploadImage'] = defaultUploadImage + ) { this.stage = stage; this.container = container; + this.util = { + getImageDTO, + uploadImage, + getRegionMaskImage: this._getRegionMaskImage.bind(this), + getInpaintMaskImage: this._getInpaintMaskImage.bind(this), + getImageSourceImage: this._getImageSourceImage.bind(this), + }; this._konvaApi = null; this._preview = null; this._background = null; @@ -219,6 +253,152 @@ export class KonvaNodeManager { assert(this._stateApi !== null, 'State API has not been set'); return this._stateApi; } + + async _getRegionMaskImage(arg: { id: string; bbox?: Rect; preview?: boolean }): Promise { + const { id, bbox, preview = false } = arg; + const region = this.stateApi.getRegionsState().entities.find((entity) => entity.id === id); + assert(region, `Region entity state with id ${id} not found`); + const adapter = this.get(region.id); + assert(adapter, `Adapter for region ${region.id} not found`); + + if (region.imageCache) { + const imageDTO = await this.util.getImageDTO(region.imageCache.name); + if (imageDTO) { + return imageDTO; + } + } + + const layer = adapter.konvaLayer.clone(); + const objectGroup = adapter.konvaObjectGroup.clone(); + layer.destroyChildren(); + layer.add(objectGroup); + objectGroup.opacity(1); + objectGroup.cache(); + + const blob = await new Promise((resolve) => { + layer.toBlob({ + callback: (blob) => { + assert(blob, 'Blob is null'); + resolve(blob); + }, + ...bbox, + }); + }); + + if (preview) { + const base64 = await blobToDataURL(blob); + const caption = `${region.id}: ${region.positivePrompt} / ${region.negativePrompt}`; + openBase64ImageInTab([{ base64, caption }]); + } + + layer.destroy(); + + const imageDTO = await this.util.uploadImage(blob, `${region.id}_mask.png`, 'mask', true); + this.stateApi.onRegionMaskImageCached(region.id, imageDTO); + return imageDTO; + } + + async _getInpaintMaskImage(arg: { bbox?: Rect; preview?: boolean }): Promise { + const { bbox, preview = false } = arg; + const inpaintMask = this.stateApi.getInpaintMaskState(); + const adapter = this.get(inpaintMask.id); + assert(adapter, `Adapter for ${inpaintMask.id} not found`); + + if (inpaintMask.imageCache) { + const imageDTO = await this.util.getImageDTO(inpaintMask.imageCache.name); + if (imageDTO) { + return imageDTO; + } + } + + const layer = adapter.konvaLayer.clone(); + const objectGroup = adapter.konvaObjectGroup.clone(); + layer.destroyChildren(); + layer.add(objectGroup); + objectGroup.opacity(1); + objectGroup.cache(); + + const blob = await new Promise((resolve) => { + layer.toBlob({ + callback: (blob) => { + assert(blob, 'Blob is null'); + resolve(blob); + }, + ...bbox, + }); + }); + + if (preview) { + const base64 = await blobToDataURL(blob); + const caption = 'inpaint mask'; + openBase64ImageInTab([{ base64, caption }]); + } + + layer.destroy(); + + const imageDTO = await this.util.uploadImage(blob, 'inpaint_mask.png', 'mask', true); + this.stateApi.onInpaintMaskImageCached(imageDTO); + return imageDTO; + } + + async _getImageSourceImage(arg: { bbox?: Rect; preview?: boolean }): Promise { + const { bbox, preview = false } = arg; + const layersState = this.stateApi.getLayersState(); + const { entities, imageCache } = layersState; + if (imageCache) { + const imageDTO = await this.util.getImageDTO(imageCache.name); + if (imageDTO) { + return imageDTO; + } + } + + const stage = this.stage.clone(); + + stage.scaleX(1); + stage.scaleY(1); + stage.x(0); + stage.y(0); + + const validLayers = entities.filter(isValidLayer); + + // Konva bug (?) - when iterating over the array returned from `stage.getLayers()`, if you destroy a layer, the array + // is mutated in-place and the next iteration will skip the next layer. To avoid this, we first collect the layers + // to delete in a separate array and then destroy them. + // TODO(psyche): Maybe report this? + const toDelete: Konva.Layer[] = []; + + for (const konvaLayer of stage.getLayers()) { + const layer = validLayers.find((l) => l.id === konvaLayer.id()); + if (!layer) { + toDelete.push(konvaLayer); + } + } + + for (const konvaLayer of toDelete) { + konvaLayer.destroy(); + } + + const blob = await new Promise((resolve) => { + stage.toBlob({ + callback: (blob) => { + assert(blob, 'Blob is null'); + resolve(blob); + }, + ...bbox, + }); + }); + + if (preview) { + const base64 = await blobToDataURL(blob); + openBase64ImageInTab([{ base64, caption: 'base layer' }]); + } + + stage.destroy(); + + const imageDTO = await this.util.uploadImage(blob, 'base_layer.png', 'general', true); + this.stateApi.onLayerImageCached(imageDTO); + return imageDTO; + } } export class KonvaEntityAdapter { diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/arrange.ts b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/arrange.ts index dc640b10bb..2fa3039932 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/arrange.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/arrange.ts @@ -6,12 +6,12 @@ import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager' * @returns An arrange entities function */ export const getArrangeEntities = (manager: KonvaNodeManager) => { - const { getLayerEntityStates, getControlAdapterEntityStates, getRegionEntityStates } = manager.stateApi; + const { getLayersState, getControlAdaptersState, getRegionsState } = manager.stateApi; function arrangeEntities(): void { - const layers = getLayerEntityStates(); - const controlAdapters = getControlAdapterEntityStates(); - const regions = getRegionEntityStates(); + const layers = getLayersState().entities; + const controlAdapters = getControlAdaptersState().entities; + const regions = getRegionsState().entities; let zIndex = 0; manager.background.layer.zIndex(++zIndex); for (const layer of layers) { diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/controlAdapters.ts b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/controlAdapters.ts index d9db711147..db61ced5db 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/controlAdapters.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/controlAdapters.ts @@ -100,10 +100,10 @@ export const renderControlAdapter = async (manager: KonvaNodeManager, entity: Co * @returns A function to render all control adapters */ export const getRenderControlAdapters = (manager: KonvaNodeManager) => { - const { getControlAdapterEntityStates } = manager.stateApi; + const { getControlAdaptersState } = manager.stateApi; function renderControlAdapters(): void { - const entities = getControlAdapterEntityStates(); + const { entities } = getControlAdaptersState(); // Destroy nonexistent layers for (const adapters of manager.getAll('control_adapter')) { if (!entities.find((ca) => ca.id === adapters.id)) { diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/inpaintMask.ts b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/inpaintMask.ts index 3c1c103ff4..fc7bed9a64 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/inpaintMask.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/inpaintMask.ts @@ -71,10 +71,10 @@ const getInpaintMask = ( * @returns A function to render the inpaint mask */ export const getRenderInpaintMask = (manager: KonvaNodeManager) => { - const { getInpaintMaskEntityState, getMaskOpacity, getToolState, getSelectedEntity, onPosChanged } = manager.stateApi; + const { getInpaintMaskState, getMaskOpacity, getToolState, getSelectedEntity, onPosChanged } = manager.stateApi; function renderInpaintMask(): void { - const entity = getInpaintMaskEntityState(); + const entity = getInpaintMaskState(); const globalMaskLayerOpacity = getMaskOpacity(); const toolState = getToolState(); const selectedEntity = getSelectedEntity(); diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/layers.ts b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/layers.ts index eff2f7de85..8dfb803c5b 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/layers.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/layers.ts @@ -136,10 +136,10 @@ export const renderLayer = async ( * @returns A function to render all layers */ export const getRenderLayers = (manager: KonvaNodeManager) => { - const { getLayerEntityStates, getToolState, onPosChanged } = manager.stateApi; + const { getLayersState, getToolState, onPosChanged } = manager.stateApi; function renderLayers(): void { - const entities = getLayerEntityStates(); + const { entities } = getLayersState(); const tool = getToolState(); // Destroy nonexistent layers for (const adapter of manager.getAll('layer')) { diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/regions.ts b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/regions.ts index a6087b320b..70b8cfb51b 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/regions.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/regions.ts @@ -233,10 +233,10 @@ export const renderRegion = ( * @returns A function to render all regions */ export const getRenderRegions = (manager: KonvaNodeManager) => { - const { getRegionEntityStates, getMaskOpacity, getToolState, getSelectedEntity, onPosChanged } = manager.stateApi; + const { getRegionsState, getMaskOpacity, getToolState, getSelectedEntity, onPosChanged } = manager.stateApi; function renderRegions(): void { - const entities = getRegionEntityStates(); + const { entities } = getRegionsState(); const maskOpacity = getMaskOpacity(); const toolState = getToolState(); const selectedEntity = getSelectedEntity(); diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/renderer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/renderer.ts index dcd876de01..939d9d776c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/renderer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/renderer.ts @@ -32,17 +32,20 @@ import { imBboxChanged, imBrushLineAdded, imEraserLineAdded, + imImageCacheChanged, imLinePointAdded, imTranslated, layerBboxChanged, layerBrushLineAdded, layerEraserLineAdded, + layerImageCacheChanged, layerLinePointAdded, layerRectAdded, layerTranslated, rgBboxChanged, rgBrushLineAdded, rgEraserLineAdded, + rgImageCacheChanged, rgLinePointAdded, rgRectAdded, rgTranslated, @@ -65,6 +68,7 @@ import type { IRect, Vector2d } from 'konva/lib/types'; import { debounce } from 'lodash-es'; import { atom } from 'nanostores'; import type { RgbaColor } from 'react-colorful'; +import type { ImageDTO } from 'services/api/types'; export const $nodeManager = atom(null); @@ -175,6 +179,19 @@ export const initializeRenderer = ( logIfDebugging('Eraser width changed'); dispatch(eraserWidthChanged(width)); }; + const onRegionMaskImageCached = (id: string, imageDTO: ImageDTO) => { + logIfDebugging('Region mask image cached'); + dispatch(rgImageCacheChanged({ id, imageDTO })); + }; + const onInpaintMaskImageCached = (imageDTO: ImageDTO) => { + logIfDebugging('Inpaint mask image cached'); + dispatch(imImageCacheChanged({ imageDTO })); + }; + const onLayerImageCached = (imageDTO: ImageDTO) => { + logIfDebugging('Layer image cached'); + dispatch(layerImageCacheChanged({ imageDTO })); + }; + const setTool = (tool: Tool) => { logIfDebugging('Tool selection changed'); dispatch(toolChanged(tool)); @@ -240,11 +257,11 @@ export const initializeRenderer = ( const getDocument = () => canvasV2.document; const getToolState = () => canvasV2.tool; const getSettings = () => canvasV2.settings; - const getRegionEntityStates = () => canvasV2.regions.entities; - const getLayerEntityStates = () => canvasV2.layers.entities; - const getControlAdapterEntityStates = () => canvasV2.controlAdapters.entities; + const getRegionsState = () => canvasV2.regions; + const getLayersState = () => canvasV2.layers; + const getControlAdaptersState = () => canvasV2.controlAdapters; + const getInpaintMaskState = () => canvasV2.inpaintMask; const getMaskOpacity = () => canvasV2.settings.maskOpacity; - const getInpaintMaskEntityState = () => canvasV2.inpaintMask; // Read-write state, ephemeral interaction state let isDrawing = false; @@ -309,12 +326,12 @@ export const initializeRenderer = ( getCtrlKey: $ctrl.get, getMetaKey: $meta.get, getShiftKey: $shift.get, - getControlAdapterEntityStates, + getControlAdaptersState, getDocument, - getLayerEntityStates, - getRegionEntityStates, + getLayersState, + getRegionsState, getMaskOpacity, - getInpaintMaskEntityState, + getInpaintMaskState, // Read-write state setTool, @@ -342,6 +359,9 @@ export const initializeRenderer = ( onEraserWidthChanged, onPosChanged, onBboxTransformed, + onRegionMaskImageCached, + onInpaintMaskImageCached, + onLayerImageCached, }; const cleanupListeners = setStageEventHandlers(manager); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts index 20bbeade0a..60d7151815 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts @@ -24,7 +24,7 @@ import { DEFAULT_RGBA_COLOR } from './types'; const initialState: CanvasV2State = { _version: 3, selectedEntityIdentifier: { type: 'inpaint_mask', id: 'inpaint_mask' }, - layers: { entities: [], baseLayerImageCache: null }, + layers: { entities: [], imageCache: null }, controlAdapters: { entities: [] }, ipAdapters: { entities: [] }, regions: { entities: [] }, @@ -161,7 +161,7 @@ export const canvasV2Slice = createSlice({ allEntitiesDeleted: (state) => { state.regions.entities = []; state.layers.entities = []; - state.layers.baseLayerImageCache = null; + state.layers.imageCache = null; state.ipAdapters.entities = []; state.controlAdapters.entities = []; }, @@ -185,7 +185,6 @@ export const { scaledBboxChanged, bboxScaleMethodChanged, clipToBboxChanged, - baseLayerImageCacheChanged, // layers layerAdded, layerRecalled, @@ -205,6 +204,7 @@ export const { layerRectAdded, layerImageAdded, layerAllDeleted, + layerImageCacheChanged, // IP Adapters ipaAdded, ipaRecalled, @@ -255,7 +255,7 @@ export const { rgPositivePromptChanged, rgNegativePromptChanged, rgFillChanged, - rgMaskImageUploaded, + rgImageCacheChanged, rgAutoNegativeChanged, rgIPAdapterAdded, rgIPAdapterDeleted, diff --git a/invokeai/frontend/web/src/features/controlLayers/store/layersReducers.ts b/invokeai/frontend/web/src/features/controlLayers/store/layersReducers.ts index c7303f00cc..21cbbd4d85 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/layersReducers.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/layersReducers.ts @@ -40,7 +40,7 @@ export const layersReducers = { y: 0, }); state.selectedEntityIdentifier = { type: 'layer', id }; - state.layers.baseLayerImageCache = null; + state.layers.imageCache = null; }, prepare: () => ({ payload: { id: uuidv4() } }), }, @@ -48,7 +48,7 @@ export const layersReducers = { const { data } = action.payload; state.layers.entities.push(data); state.selectedEntityIdentifier = { type: 'layer', id: data.id }; - state.layers.baseLayerImageCache = null; + state.layers.imageCache = null; }, layerIsEnabledToggled: (state, action: PayloadAction<{ id: string }>) => { const { id } = action.payload; @@ -57,7 +57,7 @@ export const layersReducers = { return; } layer.isEnabled = !layer.isEnabled; - state.layers.baseLayerImageCache = null; + state.layers.imageCache = null; }, layerTranslated: (state, action: PayloadAction<{ id: string; x: number; y: number }>) => { const { id, x, y } = action.payload; @@ -67,7 +67,7 @@ export const layersReducers = { } layer.x = x; layer.y = y; - state.layers.baseLayerImageCache = null; + state.layers.imageCache = null; }, layerBboxChanged: (state, action: PayloadAction<{ id: string; bbox: IRect | null }>) => { const { id, bbox } = action.payload; @@ -93,16 +93,16 @@ export const layersReducers = { layer.objects = []; layer.bbox = null; layer.bboxNeedsUpdate = false; - state.layers.baseLayerImageCache = null; + state.layers.imageCache = null; }, layerDeleted: (state, action: PayloadAction<{ id: string }>) => { const { id } = action.payload; state.layers.entities = state.layers.entities.filter((l) => l.id !== id); - state.layers.baseLayerImageCache = null; + state.layers.imageCache = null; }, layerAllDeleted: (state) => { state.layers.entities = []; - state.layers.baseLayerImageCache = null; + state.layers.imageCache = null; }, layerOpacityChanged: (state, action: PayloadAction<{ id: string; opacity: number }>) => { const { id, opacity } = action.payload; @@ -111,7 +111,7 @@ export const layersReducers = { return; } layer.opacity = opacity; - state.layers.baseLayerImageCache = null; + state.layers.imageCache = null; }, layerMovedForwardOne: (state, action: PayloadAction<{ id: string }>) => { const { id } = action.payload; @@ -120,7 +120,7 @@ export const layersReducers = { return; } moveOneToEnd(state.layers.entities, layer); - state.layers.baseLayerImageCache = null; + state.layers.imageCache = null; }, layerMovedToFront: (state, action: PayloadAction<{ id: string }>) => { const { id } = action.payload; @@ -129,7 +129,7 @@ export const layersReducers = { return; } moveToEnd(state.layers.entities, layer); - state.layers.baseLayerImageCache = null; + state.layers.imageCache = null; }, layerMovedBackwardOne: (state, action: PayloadAction<{ id: string }>) => { const { id } = action.payload; @@ -138,7 +138,7 @@ export const layersReducers = { return; } moveOneToStart(state.layers.entities, layer); - state.layers.baseLayerImageCache = null; + state.layers.imageCache = null; }, layerMovedToBack: (state, action: PayloadAction<{ id: string }>) => { const { id } = action.payload; @@ -147,7 +147,7 @@ export const layersReducers = { return; } moveToStart(state.layers.entities, layer); - state.layers.baseLayerImageCache = null; + state.layers.imageCache = null; }, layerBrushLineAdded: { reducer: (state, action: PayloadAction) => { @@ -166,7 +166,7 @@ export const layersReducers = { clip, }); layer.bboxNeedsUpdate = true; - state.layers.baseLayerImageCache = null; + state.layers.imageCache = null; }, prepare: (payload: BrushLineAddedArg) => ({ payload: { ...payload, lineId: uuidv4() }, @@ -188,7 +188,7 @@ export const layersReducers = { clip, }); layer.bboxNeedsUpdate = true; - state.layers.baseLayerImageCache = null; + state.layers.imageCache = null; }, prepare: (payload: EraserLineAddedArg) => ({ payload: { ...payload, lineId: uuidv4() }, @@ -206,7 +206,7 @@ export const layersReducers = { } lastObject.points.push(...point); layer.bboxNeedsUpdate = true; - state.layers.baseLayerImageCache = null; + state.layers.imageCache = null; }, layerRectAdded: { reducer: (state, action: PayloadAction) => { @@ -226,7 +226,7 @@ export const layersReducers = { color, }); layer.bboxNeedsUpdate = true; - state.layers.baseLayerImageCache = null; + state.layers.imageCache = null; }, prepare: (payload: RectShapeAddedArg) => ({ payload: { ...payload, rectId: uuidv4() } }), }, @@ -239,11 +239,12 @@ export const layersReducers = { } layer.objects.push(imageDTOToImageObject(id, objectId, imageDTO)); layer.bboxNeedsUpdate = true; - state.layers.baseLayerImageCache = null; + state.layers.imageCache = null; }, prepare: (payload: ImageObjectAddedArg) => ({ payload: { ...payload, objectId: uuidv4() } }), }, - baseLayerImageCacheChanged: (state, action: PayloadAction) => { - state.layers.baseLayerImageCache = action.payload ? imageDTOToImageWithDims(action.payload) : null; + layerImageCacheChanged: (state, action: PayloadAction<{ imageDTO: ImageDTO | null }>) => { + const { imageDTO } = action.payload; + state.layers.imageCache = imageDTO ? imageDTOToImageWithDims(imageDTO) : null; }, } satisfies SliceCaseReducers; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/regionsReducers.ts b/invokeai/frontend/web/src/features/controlLayers/store/regionsReducers.ts index 752e6a0af1..49aa761246 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/regionsReducers.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/regionsReducers.ts @@ -1,11 +1,7 @@ import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit'; import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils'; import { getBrushLineId, getEraserLineId, getRectShapeId } from 'features/controlLayers/konva/naming'; -import type { - CanvasV2State, - CLIPVisionModelV2, - IPMethodV2, -} from 'features/controlLayers/store/types'; +import type { CanvasV2State, CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types'; import { imageDTOToImageObject, imageDTOToImageWithDims } from 'features/controlLayers/store/types'; import { zModelIdentifierField } from 'features/nodes/types/common'; import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas'; @@ -182,7 +178,7 @@ export const regionsReducers = { } rg.fill = fill; }, - rgMaskImageUploaded: (state, action: PayloadAction<{ id: string; imageDTO: ImageDTO }>) => { + rgImageCacheChanged: (state, action: PayloadAction<{ id: string; imageDTO: ImageDTO }>) => { const { id, imageDTO } = action.payload; const rg = selectRG(state, id); if (!rg) { diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index f8410bb4b9..1964216489 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -797,7 +797,7 @@ export type CanvasV2State = { selectedEntityIdentifier: CanvasEntityIdentifier | null; inpaintMask: InpaintMaskEntity; layers: { - baseLayerImageCache: ImageWithDims | null; + imageCache: ImageWithDims | null; entities: LayerEntity[]; }; controlAdapters: { entities: ControlAdapterEntity[] }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLayers.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLayers.ts index d665d53d25..64cf859383 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLayers.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLayers.ts @@ -1,96 +1,9 @@ -import { getStore } from 'app/store/nanostores/store'; -import openBase64ImageInTab from 'common/util/openBase64ImageInTab'; -import { $nodeManager } from 'features/controlLayers/konva/renderers/renderer'; -import { blobToDataURL } from 'features/controlLayers/konva/util'; -import { baseLayerImageCacheChanged } from 'features/controlLayers/store/canvasV2Slice'; import type { LayerEntity } from 'features/controlLayers/store/types'; -import type Konva from 'konva'; -import type { IRect } from 'konva/lib/types'; -import { getImageDTO, imagesApi } from 'services/api/endpoints/images'; -import type { ImageDTO } from 'services/api/types'; -import { assert } from 'tsafe'; -const isValidLayer = (entity: LayerEntity) => { +export const isValidLayer = (entity: LayerEntity) => { return ( entity.isEnabled && // Boolean(entity.bbox) && TODO(psyche): Re-enable this check when we have a way to calculate bbox for all layers entity.objects.length > 0 ); }; - -/** - * Get the blobs of all regional prompt layers. Only visible layers are returned. - * @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used. - * @param preview Whether to open a new tab displaying each layer. - * @returns A map of layer IDs to blobs. - */ - -const getBaseLayer = async (layers: LayerEntity[], bbox: IRect, preview: boolean = false): Promise => { - const manager = $nodeManager.get(); - assert(manager, 'Node manager is null'); - - const stage = manager.stage.clone(); - - stage.scaleX(1); - stage.scaleY(1); - stage.x(0); - stage.y(0); - - const validLayers = layers.filter(isValidLayer); - - // Konva bug (?) - when iterating over the array returned from `stage.getLayers()`, if you destroy a layer, the array - // is mutated in-place and the next iteration will skip the next layer. To avoid this, we first collect the layers - // to delete in a separate array and then destroy them. - // TODO(psyche): Maybe report this? - const toDelete: Konva.Layer[] = []; - - for (const konvaLayer of stage.getLayers()) { - const layer = validLayers.find((l) => l.id === konvaLayer.id()); - if (!layer) { - toDelete.push(konvaLayer); - } - } - - for (const konvaLayer of toDelete) { - konvaLayer.destroy(); - } - - const blob = await new Promise((resolve) => { - stage.toBlob({ - callback: (blob) => { - assert(blob, 'Blob is null'); - resolve(blob); - }, - ...bbox, - }); - }); - - if (preview) { - const base64 = await blobToDataURL(blob); - openBase64ImageInTab([{ base64, caption: 'base layer' }]); - } - - stage.destroy(); - - return blob; -}; - -export const getBaseLayerImage = async (): Promise => { - const { dispatch, getState } = getStore(); - const state = getState(); - if (state.canvasV2.layers.baseLayerImageCache) { - const imageDTO = await getImageDTO(state.canvasV2.layers.baseLayerImageCache.name); - if (imageDTO) { - return imageDTO; - } - } - const blob = await getBaseLayer(state.canvasV2.layers.entities, state.canvasV2.bbox, true); - const file = new File([blob], 'image.png', { type: 'image/png' }); - const req = dispatch( - imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'general', is_intermediate: true }) - ); - req.reset(); - const imageDTO = await req.unwrap(); - dispatch(baseLayerImageCacheChanged(imageDTO)); - return imageDTO; -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts index ae99dd6e32..e5290422da 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts @@ -1,10 +1,5 @@ -import { getStore } from 'app/store/nanostores/store'; import { deepClone } from 'common/util/deepClone'; -import openBase64ImageInTab from 'common/util/openBase64ImageInTab'; -import type { KonvaEntityAdapter } from 'features/controlLayers/konva/nodeManager'; -import { $nodeManager } from 'features/controlLayers/konva/renderers/renderer'; -import { blobToDataURL } from 'features/controlLayers/konva/util'; -import { rgMaskImageUploaded } from 'features/controlLayers/store/canvasV2Slice'; +import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager'; import type { Dimensions, IPAdapterEntity, RegionEntity } from 'features/controlLayers/store/types'; import { PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX, @@ -16,8 +11,7 @@ import { import { addIPAdapterCollectorSafe, isValidIPAdapter } from 'features/nodes/util/graph/generation/addIPAdapters'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { IRect } from 'konva/lib/types'; -import { getImageDTO, imagesApi } from 'services/api/endpoints/images'; -import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types'; +import type { BaseModelType, Invocation } from 'services/api/types'; import { assert } from 'tsafe'; /** @@ -34,6 +28,7 @@ import { assert } from 'tsafe'; */ export const addRegions = async ( + manager: KonvaNodeManager, regions: RegionEntity[], g: Graph, documentSize: Dimensions, @@ -51,7 +46,7 @@ export const addRegions = async ( for (const region of validRegions) { // Upload the mask image, or get the cached image if it exists - const { image_name } = await getRegionMaskImage(region, bbox, true); + const { image_name } = await manager.util.getRegionMaskImage({ id: region.id, bbox, preview: true }); // The main mask-to-tensor node const maskToTensor = g.addNode({ @@ -217,90 +212,3 @@ export const isValidRegion = (rg: RegionEntity, base: BaseModelType) => { const hasIPAdapter = rg.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base)).length > 0; return hasTextPrompt || hasIPAdapter; }; - -export const getMaskImage = async (rg: RegionEntity, blob: Blob): Promise => { - const { id, imageCache } = rg; - if (imageCache) { - const imageDTO = await getImageDTO(imageCache.name); - if (imageDTO) { - return imageDTO; - } - } - const { dispatch } = getStore(); - // No cached mask, or the cached image no longer exists - we need to upload the mask image - const file = new File([blob], `${rg.id}_mask.png`, { type: 'image/png' }); - const req = dispatch( - imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true }) - ); - req.reset(); - - const imageDTO = await req.unwrap(); - dispatch(rgMaskImageUploaded({ id, imageDTO })); - return imageDTO; -}; - -export const uploadMaskImage = async ({ id }: RegionEntity, blob: Blob): Promise => { - const { dispatch } = getStore(); - // No cached mask, or the cached image no longer exists - we need to upload the mask image - const file = new File([blob], `${id}_mask.png`, { type: 'image/png' }); - const req = dispatch( - imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true }) - ); - req.reset(); - - const imageDTO = await req.unwrap(); - dispatch(rgMaskImageUploaded({ id, imageDTO })); - return imageDTO; -}; - -/** - * Get the blobs of all regional prompt layers. Only visible layers are returned. - * @param layerIds The IDs of the layers to get blobs for. If not provided, all regional prompt layers are used. - * @param preview Whether to open a new tab displaying each layer. - * @returns A map of layer IDs to blobs. - */ - -export const getRegionMaskImage = async ( - region: RegionEntity, - bbox: IRect, - preview: boolean = false -): Promise => { - const manager = $nodeManager.get(); - assert(manager, 'Node manager is null'); - - // TODO(psyche): Why do I need to annotate this? TS must have some kind of circular ref w/ this type but I can't figure it out... - const adapter: KonvaEntityAdapter | undefined = manager.get(region.id); - assert(adapter, `Adapter for region ${region.id} not found`); - if (region.imageCache) { - const imageDTO = await getImageDTO(region.imageCache.name); - if (imageDTO) { - return imageDTO; - } - } - const layer = adapter.konvaLayer.clone(); - const objectGroup = adapter.konvaObjectGroup.clone(); - layer.destroyChildren(); - layer.add(objectGroup); - objectGroup.opacity(1); - objectGroup.cache(); - - const blob = await new Promise((resolve) => { - layer.toBlob({ - callback: (blob) => { - assert(blob, 'Blob is null'); - resolve(blob); - }, - ...bbox, - }); - }); - - if (preview) { - const base64 = await blobToDataURL(blob); - const caption = `${region.id}: ${region.positivePrompt} / ${region.negativePrompt}`; - openBase64ImageInTab([{ base64, caption }]); - } - - layer.destroy(); - - return await uploadMaskImage(region, blob); -}; diff --git a/invokeai/frontend/web/src/services/api/endpoints/images.ts b/invokeai/frontend/web/src/services/api/endpoints/images.ts index 9672acc149..9ba144ecd8 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/images.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/images.ts @@ -588,3 +588,16 @@ export const getImageDTO = async (image_name: string, forceRefetch?: boolean): P return null; } }; + +export const uploadImage = async ( + blob: Blob, + fileName: string, + image_category: ImageCategory, + is_intermediate: boolean +): Promise => { + const { dispatch } = getStore(); + const file = new File([blob], fileName, { type: 'image/png' }); + const req = dispatch(imagesApi.endpoints.uploadImage.initiate({ file, image_category, is_intermediate })); + req.reset(); + return await req.unwrap(); +};