diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 7709770d81..6bb27c0eaf 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -1,6 +1,7 @@ import type { TypedStartListening } from '@reduxjs/toolkit'; import { createListenerMiddleware } from '@reduxjs/toolkit'; import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener'; +import { addStagingListeners } from 'app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener'; import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued'; import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived'; import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted'; @@ -87,6 +88,7 @@ addBatchEnqueuedListener(startAppListening); // addCanvasMergedListener(startAppListening); // addStagingAreaImageSavedListener(startAppListening); // addCommitStagingAreaImageListener(startAppListening); +addStagingListeners(startAppListening); // Socket.IO addGeneratorProgressEventListener(startAppListening); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts index 9095a08431..48e52a46aa 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts @@ -1,30 +1,38 @@ -import { isAnyOf } from '@reduxjs/toolkit'; import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { - canvasBatchIdsReset, - commitStagingAreaImage, - discardStagedImages, - resetCanvas, - setInitialCanvasImage, -} from 'features/canvas/store/canvasSlice'; + layerAdded, + layerImageAdded, + stagingAreaImageAccepted, + stagingAreaReset, +} from 'features/controlLayers/store/canvasV2Slice'; import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { queueApi } from 'services/api/endpoints/queue'; +import { assert } from 'tsafe'; -const matcher = isAnyOf(commitStagingAreaImage, discardStagedImages, resetCanvas, setInitialCanvasImage); - -export const addCommitStagingAreaImageListener = (startAppListening: AppStartListening) => { +export const addStagingListeners = (startAppListening: AppStartListening) => { startAppListening({ - matcher, + actionCreator: stagingAreaReset, effect: async (_, { dispatch, getState }) => { const log = logger('canvas'); - const state = getState(); - const { batchIds } = state.canvas; + const stagingArea = getState().canvasV2.stagingArea; + + if (!stagingArea) { + // Should not happen + return; + } + + if (stagingArea.batchIds.length === 0) { + return; + } try { const req = dispatch( - queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: batchIds }, { fixedCacheKey: 'cancelByBatchIds' }) + queueApi.endpoints.cancelByBatchIds.initiate( + { batch_ids: stagingArea.batchIds }, + { fixedCacheKey: 'cancelByBatchIds' } + ) ); const { canceled } = await req.unwrap(); req.reset(); @@ -36,7 +44,6 @@ export const addCommitStagingAreaImageListener = (startAppListening: AppStartLis status: 'success', }); } - dispatch(canvasBatchIdsReset()); } catch { log.error('Failed to cancel canvas batches'); toast({ @@ -47,4 +54,32 @@ export const addCommitStagingAreaImageListener = (startAppListening: AppStartLis } }, }); + + startAppListening({ + actionCreator: stagingAreaImageAccepted, + effect: async (action, api) => { + const { imageDTO } = action.payload; + const { layers, stagingArea, selectedEntityIdentifier } = api.getState().canvasV2; + let layer = layers.entities.find((layer) => layer.id === selectedEntityIdentifier?.id); + + if (!layer) { + layer = layers.entities[0]; + } + + if (!layer) { + // We need to create a new layer to add the accepted image + api.dispatch(layerAdded()); + layer = layers.entities[0]; + } + + assert(layer, 'No layer found to stage image'); + assert(stagingArea, 'Staging should be defined'); + + const { x, y } = stagingArea.bbox; + const { id } = layer; + + api.dispatch(layerImageAdded({ id, imageDTO, pos: { x, y } })); + api.dispatch(stagingAreaReset()); + }, + }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts index 0e1544b17b..bbac10e2a1 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts @@ -1,6 +1,7 @@ import { enqueueRequested } from 'app/store/actions'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { getNodeManager } from 'features/controlLayers/konva/nodeManager'; +import { stagingAreaInitialized } from 'features/controlLayers/store/canvasV2Slice'; import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice'; import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph'; @@ -40,10 +41,19 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) }) ); try { - await req.unwrap(); + const enqueueResult = await req.unwrap(); if (shouldShowProgressInViewer) { dispatch(isImageViewerOpenChanged(true)); } + // TODO(psyche): update the backend schema, this is always provided + const batchId = enqueueResult.batch.batch_id; + assert(batchId, 'No batch ID found in enqueue result'); + dispatch( + stagingAreaInitialized({ + batchIds: [batchId], + bbox: getState().canvasV2.bbox, + }) + ); } finally { req.reset(); } diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts index a82f3f265c..53aa9acf0e 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts @@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { deepClone } from 'common/util/deepClone'; import { parseify } from 'common/util/serialize'; +import { stagingAreaImageAdded } from 'features/controlLayers/store/canvasV2Slice'; import { boardIdSelected, galleryViewChanged, @@ -11,10 +12,12 @@ import { } from 'features/gallery/store/gallerySlice'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { zNodeStatus } from 'features/nodes/types/invocation'; +import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants'; import { boardsApi } from 'services/api/endpoints/boards'; import { imagesApi } from 'services/api/endpoints/images'; import { getCategories, getListImagesUrl } from 'services/api/util'; import { socketInvocationComplete } from 'services/events/actions'; +import { assert } from 'tsafe'; // These nodes output an image, but do not actually *save* an image, so we don't want to handle the gallery logic on them const nodeTypeDenylist = ['load_image', 'image']; @@ -45,10 +48,11 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi imageDTORequest.unsubscribe(); // Add canvas images to the staging area - // TODO(psyche): canvas batchid processing - // if (canvas.batchIds.includes(data.batch_id) && data.invocation_source_id === CANVAS_OUTPUT) { - // dispatch(addImageToStagingArea(imageDTO)); - // } + if (canvasV2.stagingArea?.batchIds.includes(data.batch_id) && data.invocation_source_id === CANVAS_OUTPUT) { + const stagingArea = getState().canvasV2.stagingArea; + assert(stagingArea, 'Staging should be defined'); + dispatch(stagingAreaImageAdded({ imageDTO })); + } if (!imageDTO.is_intermediate) { // update the total images for the board diff --git a/invokeai/frontend/web/src/features/controlLayers/components/ControlLayersEditor.tsx b/invokeai/frontend/web/src/features/controlLayers/components/ControlLayersEditor.tsx index 41ee961b59..1a3b0c20f9 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/ControlLayersEditor.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/ControlLayersEditor.tsx @@ -2,6 +2,7 @@ import { Flex } from '@invoke-ai/ui-library'; import { ControlLayersToolbar } from 'features/controlLayers/components/ControlLayersToolbar'; import { StageComponent } from 'features/controlLayers/components/StageComponent'; +import { StagingAreaToolbar } from 'features/controlLayers/components/StagingArea/StagingAreaToolbar'; import { memo } from 'react'; export const ControlLayersEditor = memo(() => { @@ -17,6 +18,9 @@ export const ControlLayersEditor = memo(() => { > + + + ); }); diff --git a/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/StagingAreaToolbar.tsx b/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/StagingAreaToolbar.tsx new file mode 100644 index 0000000000..e7faac86e8 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/StagingAreaToolbar.tsx @@ -0,0 +1,151 @@ +import { Button, ButtonGroup, IconButton } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { + stagingAreaImageAccepted, + stagingAreaImageDiscarded, + stagingAreaNextImageSelected, + stagingAreaPreviousImageSelected, + stagingAreaReset, +} from 'features/controlLayers/store/canvasV2Slice'; +import type { CanvasV2State } from 'features/controlLayers/store/types'; +import { memo, useCallback, useMemo } from 'react'; +import { useHotkeys } from 'react-hotkeys-hook'; +import { useTranslation } from 'react-i18next'; +import { PiArrowLeftBold, PiArrowRightBold, PiCheckBold, PiTrashSimpleBold, PiXBold } from 'react-icons/pi'; + +export const StagingAreaToolbar = memo(() => { + const stagingArea = useAppSelector((s) => s.canvasV2.stagingArea); + + if (!stagingArea || stagingArea.images.length === 0) { + return null; + } + + return ; +}); + +StagingAreaToolbar.displayName = 'StagingAreaToolbar'; + +type Props = { + stagingArea: NonNullable; +}; + +export const StagingAreaToolbarContent = memo(({ stagingArea }: Props) => { + const dispatch = useAppDispatch(); + const images = useMemo(() => stagingArea.images, [stagingArea]); + const imageDTO = useMemo(() => { + if (stagingArea.selectedImageIndex === null) { + return null; + } + return images[stagingArea.selectedImageIndex] ?? null; + }, [images, stagingArea.selectedImageIndex]); + + const { t } = useTranslation(); + + const onPrev = useCallback(() => { + dispatch(stagingAreaPreviousImageSelected()); + }, [dispatch]); + + const onNext = useCallback(() => { + dispatch(stagingAreaNextImageSelected()); + }, [dispatch]); + + const onAccept = useCallback(() => { + if (!imageDTO || !stagingArea) { + return; + } + dispatch(stagingAreaImageAccepted({ imageDTO })); + }, [dispatch, imageDTO, stagingArea]); + + const onDiscardOne = useCallback(() => { + if (!imageDTO || !stagingArea) { + return; + } + if (images.length === 1) { + dispatch(stagingAreaReset()); + } else { + dispatch(stagingAreaImageDiscarded({ imageDTO })); + } + }, [dispatch, imageDTO, images.length, stagingArea]); + + const onDiscardAll = useCallback(() => { + if (!stagingArea) { + return; + } + dispatch(stagingAreaReset()); + }, [dispatch, stagingArea]); + + useHotkeys(['left'], onPrev, { + preventDefault: true, + }); + + useHotkeys(['right'], onNext, { + preventDefault: true, + }); + + useHotkeys(['enter'], onAccept, { + preventDefault: true, + }); + + return ( + <> + + } + onClick={onPrev} + colorScheme="invokeBlue" + /> + + } + onClick={onNext} + colorScheme="invokeBlue" + /> + + + } + onClick={onAccept} + colorScheme="invokeBlue" + /> + {/* } + onClick={handleSaveToGallery} + colorScheme="invokeBlue" + /> */} + } + onClick={onDiscardOne} + colorScheme="invokeBlue" + fontSize={16} + isDisabled={images.length <= 1} + /> + } + onClick={onDiscardAll} + colorScheme="error" + fontSize={16} + isDisabled={images.length === 0} + /> + + + ); +}); + +StagingAreaToolbarContent.displayName = 'StagingAreaToolbarContent'; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/nodeManager.ts b/invokeai/frontend/web/src/features/controlLayers/konva/nodeManager.ts index 7e64696ad4..c927dd5edb 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/nodeManager.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/nodeManager.ts @@ -53,6 +53,7 @@ export type ImageObjectRecord = { konvaPlaceholderGroup: Konva.Group; konvaPlaceholderRect: Konva.Rect; konvaPlaceholderText: Konva.Text; + imageName: string | null; konvaImage: Konva.Image | null; // The image is loaded asynchronously, so it may not be available immediately isLoading: boolean; isError: boolean; @@ -69,6 +70,7 @@ type KonvaApi = { renderDocumentOverlay: () => void; renderBackground: () => void; renderToolPreview: () => void; + renderStagingArea: () => void; arrangeEntities: () => void; fitDocumentToStage: () => void; fitStageToContainer: () => void; @@ -102,6 +104,10 @@ type PreviewLayer = { innerRect: Konva.Rect; outerRect: Konva.Rect; }; + stagingArea: { + group: Konva.Group; + image: ImageObjectRecord | null; + }; }; type StateApi = { @@ -143,6 +149,7 @@ type StateApi = { getControlAdaptersState: () => CanvasV2State['controlAdapters']; getRegionsState: () => CanvasV2State['regions']; getInpaintMaskState: () => CanvasV2State['inpaintMask']; + getStagingAreaState: () => CanvasV2State['stagingArea']; onInpaintMaskImageCached: (imageDTO: ImageDTO) => void; onRegionMaskImageCached: (id: string, imageDTO: ImageDTO) => void; onLayerImageCached: (imageDTO: ImageDTO) => void; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/objects.ts b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/objects.ts index 5e5df4e96a..60d5038dbf 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/objects.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/objects.ts @@ -153,6 +153,7 @@ export const updateImageSource = async (arg: { const imageDTO = await getImageDTO(image.name); if (!imageDTO) { + objectRecord.imageName = null; objectRecord.isLoading = false; objectRecord.isError = true; objectRecord.konvaPlaceholderGroup.visible(true); @@ -173,6 +174,7 @@ export const updateImageSource = async (arg: { image: imageEl, }); objectRecord.konvaImageGroup.add(objectRecord.konvaImage); + objectRecord.imageName = image.name; } objectRecord.isLoading = false; objectRecord.isError = false; @@ -180,6 +182,7 @@ export const updateImageSource = async (arg: { onLoad?.(objectRecord.konvaImage); }; imageEl.onerror = () => { + objectRecord.imageName = null; objectRecord.isLoading = false; objectRecord.isError = true; objectRecord.konvaPlaceholderGroup.visible(true); @@ -189,6 +192,7 @@ export const updateImageSource = async (arg: { imageEl.id = image.name; imageEl.src = imageDTO.image_url; } catch { + objectRecord.imageName = null; objectRecord.isLoading = false; objectRecord.isError = true; objectRecord.konvaPlaceholderGroup.visible(true); @@ -218,7 +222,7 @@ export const createImageObjectGroup = (arg: { } const { id, image } = obj; const { width, height } = obj; - const konvaImageGroup = new Konva.Group({ id, name, listening: false }); + const konvaImageGroup = new Konva.Group({ id, name, listening: false, x: obj.x, y: obj.y }); const konvaPlaceholderGroup = new Konva.Group({ name: IMAGE_PLACEHOLDER_NAME, listening: false }); const konvaPlaceholderRect = new Konva.Rect({ fill: 'hsl(220 12% 45% / 1)', // 'base.500' @@ -246,6 +250,7 @@ export const createImageObjectGroup = (arg: { konvaPlaceholderRect, konvaPlaceholderText, konvaImage: null, + imageName: null, isLoading: false, isError: false, }); diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/renderer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/renderer.ts index 11fb09e57f..890c93a580 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/renderer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/renderer.ts @@ -22,6 +22,7 @@ import { } from 'features/controlLayers/konva/renderers/preview'; import { getRenderRegions } from 'features/controlLayers/konva/renderers/regions'; import { getFitDocumentToStage, getFitStageToContainer } from 'features/controlLayers/konva/renderers/stage'; +import { createStagingArea, getRenderStagingArea } from 'features/controlLayers/konva/renderers/stagingArea'; import { $stageAttrs, bboxChanged, @@ -259,6 +260,7 @@ export const initializeRenderer = ( const getControlAdaptersState = () => canvasV2.controlAdapters; const getInpaintMaskState = () => canvasV2.inpaintMask; const getMaskOpacity = () => canvasV2.settings.maskOpacity; + const getStagingAreaState = () => canvasV2.stagingArea; // Read-write state, ephemeral interaction state let isDrawing = false; @@ -307,6 +309,7 @@ export const initializeRenderer = ( bbox: createBboxNodes(stage, getBbox, onBboxTransformed, $shift.get, $ctrl.get, $meta.get, $alt.get), tool: createToolPreviewNodes(), documentOverlay: createDocumentOverlay(), + stagingArea: createStagingArea(), }; manager.preview.layer.add(manager.preview.bbox.group); manager.preview.layer.add(manager.preview.tool.group); @@ -329,6 +332,7 @@ export const initializeRenderer = ( getRegionsState, getMaskOpacity, getInpaintMaskState, + getStagingAreaState, // Read-write state setTool, @@ -376,6 +380,7 @@ export const initializeRenderer = ( renderBbox: getRenderBbox(manager), renderToolPreview: getRenderToolPreview(manager), renderDocumentOverlay: getRenderDocumentOverlay(manager), + renderStagingArea: getRenderStagingArea(manager), renderBackground: getRenderBackground(manager), arrangeEntities: getArrangeEntities(manager), fitDocumentToStage: getFitDocumentToStage(manager), diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/renderers/stagingArea.ts b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/stagingArea.ts new file mode 100644 index 0000000000..d4178e6739 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/konva/renderers/stagingArea.ts @@ -0,0 +1,41 @@ +import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager'; +import { createImageObjectGroup, updateImageSource } from 'features/controlLayers/konva/renderers/objects'; +import { imageDTOToImageObject, imageDTOToImageWithDims } from 'features/controlLayers/store/types'; +import Konva from 'konva'; +import { assert } from 'tsafe'; + +export const createStagingArea = (): KonvaNodeManager['preview']['stagingArea'] => { + const group = new Konva.Group({ id: 'staging_area_group', listening: false }); + return { group, image: null }; +}; + +export const getRenderStagingArea = async (manager: KonvaNodeManager) => { + const { getStagingAreaState } = manager.stateApi; + const stagingArea = getStagingAreaState(); + + if (!stagingArea || stagingArea.selectedImageIndex === null) { + if (manager.preview.stagingArea.image) { + manager.preview.stagingArea.image.konvaImageGroup.visible(false); + manager.preview.stagingArea.image = null; + } + return; + } + + if (stagingArea.selectedImageIndex) { + const imageDTO = stagingArea.images[stagingArea.selectedImageIndex]; + assert(imageDTO, 'Image must exist'); + if (manager.preview.stagingArea.image) { + if (manager.preview.stagingArea.image.imageName !== imageDTO.image_name) { + await updateImageSource({ + objectRecord: manager.preview.stagingArea.image, + image: imageDTOToImageWithDims(imageDTO), + }); + } + } else { + manager.preview.stagingArea.image = await createImageObjectGroup({ + obj: imageDTOToImageObject(imageDTO), + name: imageDTO.image_name, + }); + } + } +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts index 2e28fe6370..8c579c1cb4 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts @@ -13,6 +13,7 @@ import { lorasReducers } from 'features/controlLayers/store/lorasReducers'; import { paramsReducers } from 'features/controlLayers/store/paramsReducers'; import { regionsReducers } from 'features/controlLayers/store/regionsReducers'; import { settingsReducers } from 'features/controlLayers/store/settingsReducers'; +import { stagingAreaReducers } from 'features/controlLayers/store/stagingAreaReducers'; import { toolReducers } from 'features/controlLayers/store/toolReducers'; import { initialAspectRatioState } from 'features/parameters/components/ImageSize/constants'; import type { AspectRatioState } from 'features/parameters/components/ImageSize/types'; @@ -119,6 +120,7 @@ const initialState: CanvasV2State = { refinerNegativeAestheticScore: 2.5, refinerStart: 0.8, }, + stagingArea: null, }; export const canvasV2Slice = createSlice({ @@ -136,6 +138,7 @@ export const canvasV2Slice = createSlice({ ...toolReducers, ...bboxReducers, ...inpaintMaskReducers, + ...stagingAreaReducers, widthChanged: (state, action: PayloadAction<{ width: number; updateAspectRatio?: boolean; clamp?: boolean }>) => { const { width, updateAspectRatio, clamp } = action.payload; state.document.width = clamp ? Math.max(roundDownToMultiple(width, 8), 64) : width; @@ -327,6 +330,15 @@ export const { imEraserLineAdded, imLinePointAdded, imRectAdded, + // Staging + stagingAreaInitialized, + stagingAreaImageAdded, + stagingAreaBatchIdAdded, + stagingAreaImageDiscarded, + stagingAreaImageAccepted, + stagingAreaReset, + stagingAreaNextImageSelected, + stagingAreaPreviousImageSelected, } = canvasV2Slice.actions; export const selectCanvasV2Slice = (state: RootState) => state.canvasV2; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/layersReducers.ts b/invokeai/frontend/web/src/features/controlLayers/store/layersReducers.ts index 21cbbd4d85..9edabb1acf 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/layersReducers.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/layersReducers.ts @@ -231,17 +231,27 @@ export const layersReducers = { prepare: (payload: RectShapeAddedArg) => ({ payload: { ...payload, rectId: uuidv4() } }), }, layerImageAdded: { - reducer: (state, action: PayloadAction) => { - const { id, objectId, imageDTO } = action.payload; + reducer: ( + state, + action: PayloadAction + ) => { + const { id, objectId, imageDTO, pos } = action.payload; const layer = selectLayer(state, id); if (!layer) { return; } - layer.objects.push(imageDTOToImageObject(id, objectId, imageDTO)); + const imageObject = imageDTOToImageObject(id, objectId, imageDTO); + if (pos) { + imageObject.x = pos.x; + imageObject.y = pos.y; + } + layer.objects.push(imageObject); layer.bboxNeedsUpdate = true; state.layers.imageCache = null; }, - prepare: (payload: ImageObjectAddedArg) => ({ payload: { ...payload, objectId: uuidv4() } }), + prepare: (payload: ImageObjectAddedArg & { pos?: { x: number; y: number } }) => ({ + payload: { ...payload, objectId: uuidv4() }, + }), }, layerImageCacheChanged: (state, action: PayloadAction<{ imageDTO: ImageDTO | null }>) => { const { imageDTO } = action.payload; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/stagingAreaReducers.ts b/invokeai/frontend/web/src/features/controlLayers/store/stagingAreaReducers.ts new file mode 100644 index 0000000000..8b22d56b65 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/store/stagingAreaReducers.ts @@ -0,0 +1,73 @@ +import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit'; +import type { CanvasV2State, Rect } from 'features/controlLayers/store/types'; +import type { ImageDTO } from 'services/api/types'; + +export const stagingAreaReducers = { + stagingAreaInitialized: (state, action: PayloadAction<{ bbox: Rect; batchIds: string[] }>) => { + const { bbox, batchIds } = action.payload; + state.stagingArea = { + bbox, + batchIds, + selectedImageIndex: null, + images: [], + }; + }, + stagingAreaImageAdded: (state, action: PayloadAction<{ imageDTO: ImageDTO }>) => { + const { imageDTO } = action.payload; + if (!state.stagingArea) { + // Should not happen + return; + } + state.stagingArea.images.push(imageDTO); + if (!state.stagingArea.selectedImageIndex) { + state.stagingArea.selectedImageIndex = state.stagingArea.images.length - 1; + } + }, + stagingAreaNextImageSelected: (state) => { + if (!state.stagingArea) { + // Should not happen + return; + } + if (state.stagingArea.selectedImageIndex === null) { + if (state.stagingArea.images.length > 0) { + state.stagingArea.selectedImageIndex = 0; + } + return; + } + state.stagingArea.selectedImageIndex = (state.stagingArea.selectedImageIndex + 1) % state.stagingArea.images.length; + }, + stagingAreaPreviousImageSelected: (state) => { + if (!state.stagingArea) { + // Should not happen + return; + } + if (state.stagingArea.selectedImageIndex === null) { + if (state.stagingArea.images.length > 0) { + state.stagingArea.selectedImageIndex = 0; + } + return; + } + state.stagingArea.selectedImageIndex = + (state.stagingArea.selectedImageIndex - 1 + state.stagingArea.images.length) % state.stagingArea.images.length; + }, + stagingAreaBatchIdAdded: (state, action: PayloadAction<{ batchId: string }>) => { + const { batchId } = action.payload; + if (!state.stagingArea) { + // Should not happen + return; + } + state.stagingArea.batchIds.push(batchId); + }, + stagingAreaImageDiscarded: (state, action: PayloadAction<{ imageDTO: ImageDTO }>) => { + const { imageDTO } = action.payload; + if (!state.stagingArea) { + // Should not happen + return; + } + state.stagingArea.images = state.stagingArea.images.filter((image) => image.image_name !== imageDTO.image_name); + }, + stagingAreaImageAccepted: (state, _: PayloadAction<{ imageDTO: ImageDTO }>) => state, + stagingAreaReset: (state) => { + state.stagingArea = null; + }, +} satisfies SliceCaseReducers; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 0f8b048a6c..4f27d72186 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -785,6 +785,11 @@ export type Size = { height: number; }; +export type Position = { + x: number; + y: number; +}; + export type LoRA = { id: string; isEnabled: boolean; @@ -877,6 +882,12 @@ export type CanvasV2State = { refinerNegativeAestheticScore: number; refinerStart: number; }; + stagingArea: { + bbox: Rect; + images: ImageDTO[]; + selectedImageIndex: number | null; + batchIds: string[]; + } | null; }; export type StageAttrs = { x: number; y: number; width: number; height: number; scale: number }; @@ -891,7 +902,7 @@ export type EraserLineAddedArg = { export type BrushLineAddedArg = EraserLineAddedArg & { color: RgbaColor }; export type PointAddedToLineArg = { id: string; point: [number, number] }; export type RectShapeAddedArg = { id: string; rect: IRect; color: RgbaColor }; -export type ImageObjectAddedArg = { id: string; imageDTO: ImageDTO }; +export type ImageObjectAddedArg = { id: string; imageDTO: ImageDTO; pos?: Position }; //#region Type guards export const isLine = (obj: RenderableObject): obj is BrushLine | EraserLine => { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts index e5290422da..c7a8a3a002 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts @@ -1,6 +1,6 @@ import { deepClone } from 'common/util/deepClone'; import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager'; -import type { Dimensions, IPAdapterEntity, RegionEntity } from 'features/controlLayers/store/types'; +import type { IPAdapterEntity, Rect, RegionEntity } from 'features/controlLayers/store/types'; import { PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX, PROMPT_REGION_MASK_TO_TENSOR_PREFIX, @@ -10,7 +10,6 @@ import { } from 'features/nodes/util/graph/constants'; import { addIPAdapterCollectorSafe, isValidIPAdapter } from 'features/nodes/util/graph/generation/addIPAdapters'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; -import type { IRect } from 'konva/lib/types'; import type { BaseModelType, Invocation } from 'services/api/types'; import { assert } from 'tsafe'; @@ -31,8 +30,7 @@ export const addRegions = async ( manager: KonvaNodeManager, regions: RegionEntity[], g: Graph, - documentSize: Dimensions, - bbox: IRect, + bbox: Rect, base: BaseModelType, denoise: Invocation<'denoise_latents'>, posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts index 894ea19d84..a954938c15 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts @@ -2,6 +2,7 @@ import type { RootState } from 'app/store/store'; import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { + CANVAS_OUTPUT, CLIP_SKIP, CONTROL_LAYERS_GRAPH, DENOISE_LATENTS, @@ -119,7 +120,7 @@ export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager) }) : null; - let imageOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_paste_back'> = l2i; + let canvasOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_paste_back'> = l2i; g.addEdge(modelLoader, 'unet', denoise, 'unet'); g.addEdge(modelLoader, 'clip', clipSkip, 'clip'); @@ -163,9 +164,9 @@ export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager) g.addEdge(vaeSource, 'vae', l2i, 'vae'); if (generationMode === 'txt2img') { - imageOutput = addTextToImage(g, l2i, originalSize, scaledSize); + canvasOutput = addTextToImage(g, l2i, originalSize, scaledSize); } else if (generationMode === 'img2img') { - imageOutput = await addImageToImage( + canvasOutput = await addImageToImage( g, manager, l2i, @@ -178,7 +179,7 @@ export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager) ); } else if (generationMode === 'inpaint') { const { compositing } = state.canvasV2; - imageOutput = await addInpaint( + canvasOutput = await addInpaint( g, manager, l2i, @@ -194,7 +195,7 @@ export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager) ); } else if (generationMode === 'outpaint') { const { compositing } = state.canvasV2; - imageOutput = await addOutpaint( + canvasOutput = await addOutpaint( g, manager, l2i, @@ -216,7 +217,6 @@ export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager) manager, state.canvasV2.regions.entities, g, - state.canvasV2.document, state.canvasV2.bbox, modelConfig.base, denoise, @@ -232,18 +232,21 @@ export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager) // } if (state.system.shouldUseNSFWChecker) { - imageOutput = addNSFWChecker(g, imageOutput); + canvasOutput = addNSFWChecker(g, canvasOutput); } if (state.system.shouldUseWatermarker) { - imageOutput = addWatermarker(g, imageOutput); + canvasOutput = addWatermarker(g, canvasOutput); } // This is the terminal node and must always save to gallery. - imageOutput.is_intermediate = false; - imageOutput.use_cache = false; - imageOutput.board = getBoardField(state); + g.updateNode(canvasOutput, { + id: CANVAS_OUTPUT, + is_intermediate: false, + use_cache: false, + board: getBoardField(state), + }); - g.setMetadataReceivingNode(imageOutput); + g.setMetadataReceivingNode(canvasOutput); return g.getGraph(); }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts index e8591a9b95..3d780a2f3d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts @@ -2,6 +2,7 @@ import type { RootState } from 'app/store/store'; import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { + CANVAS_OUTPUT, LATENTS_TO_IMAGE, NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING_COLLECT, @@ -117,7 +118,7 @@ export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager }) : null; - let imageOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_paste_back'> = l2i; + let canvasOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_paste_back'> = l2i; g.addEdge(modelLoader, 'unet', denoise, 'unet'); g.addEdge(modelLoader, 'clip', posCond, 'clip'); @@ -166,9 +167,9 @@ export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager } if (generationMode === 'txt2img') { - imageOutput = addTextToImage(g, l2i, originalSize, scaledSize); + canvasOutput = addTextToImage(g, l2i, originalSize, scaledSize); } else if (generationMode === 'img2img') { - imageOutput = await addImageToImage( + canvasOutput = await addImageToImage( g, manager, l2i, @@ -181,7 +182,7 @@ export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager ); } else if (generationMode === 'inpaint') { const { compositing } = state.canvasV2; - imageOutput = await addInpaint( + canvasOutput = await addInpaint( g, manager, l2i, @@ -197,7 +198,7 @@ export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager ); } else if (generationMode === 'outpaint') { const { compositing } = state.canvasV2; - imageOutput = await addOutpaint( + canvasOutput = await addOutpaint( g, manager, l2i, @@ -219,7 +220,6 @@ export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager manager, state.canvasV2.regions.entities, g, - state.canvasV2.document, state.canvasV2.bbox, modelConfig.base, denoise, @@ -230,18 +230,21 @@ export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager ); if (state.system.shouldUseNSFWChecker) { - imageOutput = addNSFWChecker(g, imageOutput); + canvasOutput = addNSFWChecker(g, canvasOutput); } if (state.system.shouldUseWatermarker) { - imageOutput = addWatermarker(g, imageOutput); + canvasOutput = addWatermarker(g, canvasOutput); } // This is the terminal node and must always save to gallery. - imageOutput.is_intermediate = false; - imageOutput.use_cache = false; - imageOutput.board = getBoardField(state); + g.updateNode(canvasOutput, { + id: CANVAS_OUTPUT, + is_intermediate: false, + use_cache: false, + board: getBoardField(state), + }); - g.setMetadataReceivingNode(imageOutput); + g.setMetadataReceivingNode(canvasOutput); return g.getGraph(); };