feat(ui): staging area (rendering wip)

This commit is contained in:
psychedelicious 2024-06-26 10:48:16 +10:00
parent 62310e7929
commit d497da0e61
17 changed files with 428 additions and 54 deletions

View File

@ -1,6 +1,7 @@
import type { TypedStartListening } from '@reduxjs/toolkit'; import type { TypedStartListening } from '@reduxjs/toolkit';
import { createListenerMiddleware } from '@reduxjs/toolkit'; import { createListenerMiddleware } from '@reduxjs/toolkit';
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener'; import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { addStagingListeners } from 'app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener';
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued'; import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived'; import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted'; import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
@ -87,6 +88,7 @@ addBatchEnqueuedListener(startAppListening);
// addCanvasMergedListener(startAppListening); // addCanvasMergedListener(startAppListening);
// addStagingAreaImageSavedListener(startAppListening); // addStagingAreaImageSavedListener(startAppListening);
// addCommitStagingAreaImageListener(startAppListening); // addCommitStagingAreaImageListener(startAppListening);
addStagingListeners(startAppListening);
// Socket.IO // Socket.IO
addGeneratorProgressEventListener(startAppListening); addGeneratorProgressEventListener(startAppListening);

View File

@ -1,30 +1,38 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { import {
canvasBatchIdsReset, layerAdded,
commitStagingAreaImage, layerImageAdded,
discardStagedImages, stagingAreaImageAccepted,
resetCanvas, stagingAreaReset,
setInitialCanvasImage, } from 'features/controlLayers/store/canvasV2Slice';
} from 'features/canvas/store/canvasSlice';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue'; import { queueApi } from 'services/api/endpoints/queue';
import { assert } from 'tsafe';
const matcher = isAnyOf(commitStagingAreaImage, discardStagedImages, resetCanvas, setInitialCanvasImage); export const addStagingListeners = (startAppListening: AppStartListening) => {
export const addCommitStagingAreaImageListener = (startAppListening: AppStartListening) => {
startAppListening({ startAppListening({
matcher, actionCreator: stagingAreaReset,
effect: async (_, { dispatch, getState }) => { effect: async (_, { dispatch, getState }) => {
const log = logger('canvas'); const log = logger('canvas');
const state = getState(); const stagingArea = getState().canvasV2.stagingArea;
const { batchIds } = state.canvas;
if (!stagingArea) {
// Should not happen
return;
}
if (stagingArea.batchIds.length === 0) {
return;
}
try { try {
const req = dispatch( const req = dispatch(
queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: batchIds }, { fixedCacheKey: 'cancelByBatchIds' }) queueApi.endpoints.cancelByBatchIds.initiate(
{ batch_ids: stagingArea.batchIds },
{ fixedCacheKey: 'cancelByBatchIds' }
)
); );
const { canceled } = await req.unwrap(); const { canceled } = await req.unwrap();
req.reset(); req.reset();
@ -36,7 +44,6 @@ export const addCommitStagingAreaImageListener = (startAppListening: AppStartLis
status: 'success', status: 'success',
}); });
} }
dispatch(canvasBatchIdsReset());
} catch { } catch {
log.error('Failed to cancel canvas batches'); log.error('Failed to cancel canvas batches');
toast({ toast({
@ -47,4 +54,32 @@ export const addCommitStagingAreaImageListener = (startAppListening: AppStartLis
} }
}, },
}); });
startAppListening({
actionCreator: stagingAreaImageAccepted,
effect: async (action, api) => {
const { imageDTO } = action.payload;
const { layers, stagingArea, selectedEntityIdentifier } = api.getState().canvasV2;
let layer = layers.entities.find((layer) => layer.id === selectedEntityIdentifier?.id);
if (!layer) {
layer = layers.entities[0];
}
if (!layer) {
// We need to create a new layer to add the accepted image
api.dispatch(layerAdded());
layer = layers.entities[0];
}
assert(layer, 'No layer found to stage image');
assert(stagingArea, 'Staging should be defined');
const { x, y } = stagingArea.bbox;
const { id } = layer;
api.dispatch(layerImageAdded({ id, imageDTO, pos: { x, y } }));
api.dispatch(stagingAreaReset());
},
});
}; };

View File

@ -1,6 +1,7 @@
import { enqueueRequested } from 'app/store/actions'; import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { getNodeManager } from 'features/controlLayers/konva/nodeManager'; import { getNodeManager } from 'features/controlLayers/konva/nodeManager';
import { stagingAreaInitialized } from 'features/controlLayers/store/canvasV2Slice';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice'; import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph'; import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
@ -40,10 +41,19 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
}) })
); );
try { try {
await req.unwrap(); const enqueueResult = await req.unwrap();
if (shouldShowProgressInViewer) { if (shouldShowProgressInViewer) {
dispatch(isImageViewerOpenChanged(true)); dispatch(isImageViewerOpenChanged(true));
} }
// TODO(psyche): update the backend schema, this is always provided
const batchId = enqueueResult.batch.batch_id;
assert(batchId, 'No batch ID found in enqueue result');
dispatch(
stagingAreaInitialized({
batchIds: [batchId],
bbox: getState().canvasV2.bbox,
})
);
} finally { } finally {
req.reset(); req.reset();
} }

View File

@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import { stagingAreaImageAdded } from 'features/controlLayers/store/canvasV2Slice';
import { import {
boardIdSelected, boardIdSelected,
galleryViewChanged, galleryViewChanged,
@ -11,10 +12,12 @@ import {
} from 'features/gallery/store/gallerySlice'; } from 'features/gallery/store/gallerySlice';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation'; import { zNodeStatus } from 'features/nodes/types/invocation';
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
import { boardsApi } from 'services/api/endpoints/boards'; import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
import { getCategories, getListImagesUrl } from 'services/api/util'; import { getCategories, getListImagesUrl } from 'services/api/util';
import { socketInvocationComplete } from 'services/events/actions'; import { socketInvocationComplete } from 'services/events/actions';
import { assert } from 'tsafe';
// These nodes output an image, but do not actually *save* an image, so we don't want to handle the gallery logic on them // These nodes output an image, but do not actually *save* an image, so we don't want to handle the gallery logic on them
const nodeTypeDenylist = ['load_image', 'image']; const nodeTypeDenylist = ['load_image', 'image'];
@ -45,10 +48,11 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
imageDTORequest.unsubscribe(); imageDTORequest.unsubscribe();
// Add canvas images to the staging area // Add canvas images to the staging area
// TODO(psyche): canvas batchid processing if (canvasV2.stagingArea?.batchIds.includes(data.batch_id) && data.invocation_source_id === CANVAS_OUTPUT) {
// if (canvas.batchIds.includes(data.batch_id) && data.invocation_source_id === CANVAS_OUTPUT) { const stagingArea = getState().canvasV2.stagingArea;
// dispatch(addImageToStagingArea(imageDTO)); assert(stagingArea, 'Staging should be defined');
// } dispatch(stagingAreaImageAdded({ imageDTO }));
}
if (!imageDTO.is_intermediate) { if (!imageDTO.is_intermediate) {
// update the total images for the board // update the total images for the board

View File

@ -2,6 +2,7 @@
import { Flex } from '@invoke-ai/ui-library'; import { Flex } from '@invoke-ai/ui-library';
import { ControlLayersToolbar } from 'features/controlLayers/components/ControlLayersToolbar'; import { ControlLayersToolbar } from 'features/controlLayers/components/ControlLayersToolbar';
import { StageComponent } from 'features/controlLayers/components/StageComponent'; import { StageComponent } from 'features/controlLayers/components/StageComponent';
import { StagingAreaToolbar } from 'features/controlLayers/components/StagingArea/StagingAreaToolbar';
import { memo } from 'react'; import { memo } from 'react';
export const ControlLayersEditor = memo(() => { export const ControlLayersEditor = memo(() => {
@ -17,6 +18,9 @@ export const ControlLayersEditor = memo(() => {
> >
<ControlLayersToolbar /> <ControlLayersToolbar />
<StageComponent /> <StageComponent />
<Flex position="absolute" bottom={2} gap={2} align="center" justify="center">
<StagingAreaToolbar />
</Flex>
</Flex> </Flex>
); );
}); });

View File

@ -0,0 +1,151 @@
import { Button, ButtonGroup, IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
stagingAreaImageAccepted,
stagingAreaImageDiscarded,
stagingAreaNextImageSelected,
stagingAreaPreviousImageSelected,
stagingAreaReset,
} from 'features/controlLayers/store/canvasV2Slice';
import type { CanvasV2State } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { PiArrowLeftBold, PiArrowRightBold, PiCheckBold, PiTrashSimpleBold, PiXBold } from 'react-icons/pi';
export const StagingAreaToolbar = memo(() => {
const stagingArea = useAppSelector((s) => s.canvasV2.stagingArea);
if (!stagingArea || stagingArea.images.length === 0) {
return null;
}
return <StagingAreaToolbarContent stagingArea={stagingArea} />;
});
StagingAreaToolbar.displayName = 'StagingAreaToolbar';
type Props = {
stagingArea: NonNullable<CanvasV2State['stagingArea']>;
};
export const StagingAreaToolbarContent = memo(({ stagingArea }: Props) => {
const dispatch = useAppDispatch();
const images = useMemo(() => stagingArea.images, [stagingArea]);
const imageDTO = useMemo(() => {
if (stagingArea.selectedImageIndex === null) {
return null;
}
return images[stagingArea.selectedImageIndex] ?? null;
}, [images, stagingArea.selectedImageIndex]);
const { t } = useTranslation();
const onPrev = useCallback(() => {
dispatch(stagingAreaPreviousImageSelected());
}, [dispatch]);
const onNext = useCallback(() => {
dispatch(stagingAreaNextImageSelected());
}, [dispatch]);
const onAccept = useCallback(() => {
if (!imageDTO || !stagingArea) {
return;
}
dispatch(stagingAreaImageAccepted({ imageDTO }));
}, [dispatch, imageDTO, stagingArea]);
const onDiscardOne = useCallback(() => {
if (!imageDTO || !stagingArea) {
return;
}
if (images.length === 1) {
dispatch(stagingAreaReset());
} else {
dispatch(stagingAreaImageDiscarded({ imageDTO }));
}
}, [dispatch, imageDTO, images.length, stagingArea]);
const onDiscardAll = useCallback(() => {
if (!stagingArea) {
return;
}
dispatch(stagingAreaReset());
}, [dispatch, stagingArea]);
useHotkeys(['left'], onPrev, {
preventDefault: true,
});
useHotkeys(['right'], onNext, {
preventDefault: true,
});
useHotkeys(['enter'], onAccept, {
preventDefault: true,
});
return (
<>
<ButtonGroup borderRadius="base" shadow="dark-lg">
<IconButton
tooltip={`${t('unifiedCanvas.previous')} (Left)`}
aria-label={`${t('unifiedCanvas.previous')} (Left)`}
icon={<PiArrowLeftBold />}
onClick={onPrev}
colorScheme="invokeBlue"
/>
<Button
colorScheme="base"
pointerEvents="none"
minW={20}
>{`${(stagingArea.selectedImageIndex ?? 0) + 1}/${images.length}`}</Button>
<IconButton
tooltip={`${t('unifiedCanvas.next')} (Right)`}
aria-label={`${t('unifiedCanvas.next')} (Right)`}
icon={<PiArrowRightBold />}
onClick={onNext}
colorScheme="invokeBlue"
/>
</ButtonGroup>
<ButtonGroup borderRadius="base" shadow="dark-lg">
<IconButton
tooltip={`${t('unifiedCanvas.accept')} (Enter)`}
aria-label={`${t('unifiedCanvas.accept')} (Enter)`}
icon={<PiCheckBold />}
onClick={onAccept}
colorScheme="invokeBlue"
/>
{/* <IconButton
tooltip={`${t('unifiedCanvas.saveToGallery')} (Shift+S)`}
aria-label={t('unifiedCanvas.saveToGallery')}
isDisabled={!imageDTO || !imageDTO.is_intermediate}
icon={<PiFloppyDiskBold />}
onClick={handleSaveToGallery}
colorScheme="invokeBlue"
/> */}
<IconButton
tooltip={`${t('unifiedCanvas.discardCurrent')}`}
aria-label={t('unifiedCanvas.discardCurrent')}
icon={<PiXBold />}
onClick={onDiscardOne}
colorScheme="invokeBlue"
fontSize={16}
isDisabled={images.length <= 1}
/>
<IconButton
tooltip={`${t('unifiedCanvas.discardAll')} (Esc)`}
aria-label={t('unifiedCanvas.discardAll')}
icon={<PiTrashSimpleBold />}
onClick={onDiscardAll}
colorScheme="error"
fontSize={16}
isDisabled={images.length === 0}
/>
</ButtonGroup>
</>
);
});
StagingAreaToolbarContent.displayName = 'StagingAreaToolbarContent';

View File

@ -53,6 +53,7 @@ export type ImageObjectRecord = {
konvaPlaceholderGroup: Konva.Group; konvaPlaceholderGroup: Konva.Group;
konvaPlaceholderRect: Konva.Rect; konvaPlaceholderRect: Konva.Rect;
konvaPlaceholderText: Konva.Text; konvaPlaceholderText: Konva.Text;
imageName: string | null;
konvaImage: Konva.Image | null; // The image is loaded asynchronously, so it may not be available immediately konvaImage: Konva.Image | null; // The image is loaded asynchronously, so it may not be available immediately
isLoading: boolean; isLoading: boolean;
isError: boolean; isError: boolean;
@ -69,6 +70,7 @@ type KonvaApi = {
renderDocumentOverlay: () => void; renderDocumentOverlay: () => void;
renderBackground: () => void; renderBackground: () => void;
renderToolPreview: () => void; renderToolPreview: () => void;
renderStagingArea: () => void;
arrangeEntities: () => void; arrangeEntities: () => void;
fitDocumentToStage: () => void; fitDocumentToStage: () => void;
fitStageToContainer: () => void; fitStageToContainer: () => void;
@ -102,6 +104,10 @@ type PreviewLayer = {
innerRect: Konva.Rect; innerRect: Konva.Rect;
outerRect: Konva.Rect; outerRect: Konva.Rect;
}; };
stagingArea: {
group: Konva.Group;
image: ImageObjectRecord | null;
};
}; };
type StateApi = { type StateApi = {
@ -143,6 +149,7 @@ type StateApi = {
getControlAdaptersState: () => CanvasV2State['controlAdapters']; getControlAdaptersState: () => CanvasV2State['controlAdapters'];
getRegionsState: () => CanvasV2State['regions']; getRegionsState: () => CanvasV2State['regions'];
getInpaintMaskState: () => CanvasV2State['inpaintMask']; getInpaintMaskState: () => CanvasV2State['inpaintMask'];
getStagingAreaState: () => CanvasV2State['stagingArea'];
onInpaintMaskImageCached: (imageDTO: ImageDTO) => void; onInpaintMaskImageCached: (imageDTO: ImageDTO) => void;
onRegionMaskImageCached: (id: string, imageDTO: ImageDTO) => void; onRegionMaskImageCached: (id: string, imageDTO: ImageDTO) => void;
onLayerImageCached: (imageDTO: ImageDTO) => void; onLayerImageCached: (imageDTO: ImageDTO) => void;

View File

@ -153,6 +153,7 @@ export const updateImageSource = async (arg: {
const imageDTO = await getImageDTO(image.name); const imageDTO = await getImageDTO(image.name);
if (!imageDTO) { if (!imageDTO) {
objectRecord.imageName = null;
objectRecord.isLoading = false; objectRecord.isLoading = false;
objectRecord.isError = true; objectRecord.isError = true;
objectRecord.konvaPlaceholderGroup.visible(true); objectRecord.konvaPlaceholderGroup.visible(true);
@ -173,6 +174,7 @@ export const updateImageSource = async (arg: {
image: imageEl, image: imageEl,
}); });
objectRecord.konvaImageGroup.add(objectRecord.konvaImage); objectRecord.konvaImageGroup.add(objectRecord.konvaImage);
objectRecord.imageName = image.name;
} }
objectRecord.isLoading = false; objectRecord.isLoading = false;
objectRecord.isError = false; objectRecord.isError = false;
@ -180,6 +182,7 @@ export const updateImageSource = async (arg: {
onLoad?.(objectRecord.konvaImage); onLoad?.(objectRecord.konvaImage);
}; };
imageEl.onerror = () => { imageEl.onerror = () => {
objectRecord.imageName = null;
objectRecord.isLoading = false; objectRecord.isLoading = false;
objectRecord.isError = true; objectRecord.isError = true;
objectRecord.konvaPlaceholderGroup.visible(true); objectRecord.konvaPlaceholderGroup.visible(true);
@ -189,6 +192,7 @@ export const updateImageSource = async (arg: {
imageEl.id = image.name; imageEl.id = image.name;
imageEl.src = imageDTO.image_url; imageEl.src = imageDTO.image_url;
} catch { } catch {
objectRecord.imageName = null;
objectRecord.isLoading = false; objectRecord.isLoading = false;
objectRecord.isError = true; objectRecord.isError = true;
objectRecord.konvaPlaceholderGroup.visible(true); objectRecord.konvaPlaceholderGroup.visible(true);
@ -218,7 +222,7 @@ export const createImageObjectGroup = (arg: {
} }
const { id, image } = obj; const { id, image } = obj;
const { width, height } = obj; const { width, height } = obj;
const konvaImageGroup = new Konva.Group({ id, name, listening: false }); const konvaImageGroup = new Konva.Group({ id, name, listening: false, x: obj.x, y: obj.y });
const konvaPlaceholderGroup = new Konva.Group({ name: IMAGE_PLACEHOLDER_NAME, listening: false }); const konvaPlaceholderGroup = new Konva.Group({ name: IMAGE_PLACEHOLDER_NAME, listening: false });
const konvaPlaceholderRect = new Konva.Rect({ const konvaPlaceholderRect = new Konva.Rect({
fill: 'hsl(220 12% 45% / 1)', // 'base.500' fill: 'hsl(220 12% 45% / 1)', // 'base.500'
@ -246,6 +250,7 @@ export const createImageObjectGroup = (arg: {
konvaPlaceholderRect, konvaPlaceholderRect,
konvaPlaceholderText, konvaPlaceholderText,
konvaImage: null, konvaImage: null,
imageName: null,
isLoading: false, isLoading: false,
isError: false, isError: false,
}); });

View File

@ -22,6 +22,7 @@ import {
} from 'features/controlLayers/konva/renderers/preview'; } from 'features/controlLayers/konva/renderers/preview';
import { getRenderRegions } from 'features/controlLayers/konva/renderers/regions'; import { getRenderRegions } from 'features/controlLayers/konva/renderers/regions';
import { getFitDocumentToStage, getFitStageToContainer } from 'features/controlLayers/konva/renderers/stage'; import { getFitDocumentToStage, getFitStageToContainer } from 'features/controlLayers/konva/renderers/stage';
import { createStagingArea, getRenderStagingArea } from 'features/controlLayers/konva/renderers/stagingArea';
import { import {
$stageAttrs, $stageAttrs,
bboxChanged, bboxChanged,
@ -259,6 +260,7 @@ export const initializeRenderer = (
const getControlAdaptersState = () => canvasV2.controlAdapters; const getControlAdaptersState = () => canvasV2.controlAdapters;
const getInpaintMaskState = () => canvasV2.inpaintMask; const getInpaintMaskState = () => canvasV2.inpaintMask;
const getMaskOpacity = () => canvasV2.settings.maskOpacity; const getMaskOpacity = () => canvasV2.settings.maskOpacity;
const getStagingAreaState = () => canvasV2.stagingArea;
// Read-write state, ephemeral interaction state // Read-write state, ephemeral interaction state
let isDrawing = false; let isDrawing = false;
@ -307,6 +309,7 @@ export const initializeRenderer = (
bbox: createBboxNodes(stage, getBbox, onBboxTransformed, $shift.get, $ctrl.get, $meta.get, $alt.get), bbox: createBboxNodes(stage, getBbox, onBboxTransformed, $shift.get, $ctrl.get, $meta.get, $alt.get),
tool: createToolPreviewNodes(), tool: createToolPreviewNodes(),
documentOverlay: createDocumentOverlay(), documentOverlay: createDocumentOverlay(),
stagingArea: createStagingArea(),
}; };
manager.preview.layer.add(manager.preview.bbox.group); manager.preview.layer.add(manager.preview.bbox.group);
manager.preview.layer.add(manager.preview.tool.group); manager.preview.layer.add(manager.preview.tool.group);
@ -329,6 +332,7 @@ export const initializeRenderer = (
getRegionsState, getRegionsState,
getMaskOpacity, getMaskOpacity,
getInpaintMaskState, getInpaintMaskState,
getStagingAreaState,
// Read-write state // Read-write state
setTool, setTool,
@ -376,6 +380,7 @@ export const initializeRenderer = (
renderBbox: getRenderBbox(manager), renderBbox: getRenderBbox(manager),
renderToolPreview: getRenderToolPreview(manager), renderToolPreview: getRenderToolPreview(manager),
renderDocumentOverlay: getRenderDocumentOverlay(manager), renderDocumentOverlay: getRenderDocumentOverlay(manager),
renderStagingArea: getRenderStagingArea(manager),
renderBackground: getRenderBackground(manager), renderBackground: getRenderBackground(manager),
arrangeEntities: getArrangeEntities(manager), arrangeEntities: getArrangeEntities(manager),
fitDocumentToStage: getFitDocumentToStage(manager), fitDocumentToStage: getFitDocumentToStage(manager),

View File

@ -0,0 +1,41 @@
import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager';
import { createImageObjectGroup, updateImageSource } from 'features/controlLayers/konva/renderers/objects';
import { imageDTOToImageObject, imageDTOToImageWithDims } from 'features/controlLayers/store/types';
import Konva from 'konva';
import { assert } from 'tsafe';
export const createStagingArea = (): KonvaNodeManager['preview']['stagingArea'] => {
const group = new Konva.Group({ id: 'staging_area_group', listening: false });
return { group, image: null };
};
export const getRenderStagingArea = async (manager: KonvaNodeManager) => {
const { getStagingAreaState } = manager.stateApi;
const stagingArea = getStagingAreaState();
if (!stagingArea || stagingArea.selectedImageIndex === null) {
if (manager.preview.stagingArea.image) {
manager.preview.stagingArea.image.konvaImageGroup.visible(false);
manager.preview.stagingArea.image = null;
}
return;
}
if (stagingArea.selectedImageIndex) {
const imageDTO = stagingArea.images[stagingArea.selectedImageIndex];
assert(imageDTO, 'Image must exist');
if (manager.preview.stagingArea.image) {
if (manager.preview.stagingArea.image.imageName !== imageDTO.image_name) {
await updateImageSource({
objectRecord: manager.preview.stagingArea.image,
image: imageDTOToImageWithDims(imageDTO),
});
}
} else {
manager.preview.stagingArea.image = await createImageObjectGroup({
obj: imageDTOToImageObject(imageDTO),
name: imageDTO.image_name,
});
}
}
};

View File

@ -13,6 +13,7 @@ import { lorasReducers } from 'features/controlLayers/store/lorasReducers';
import { paramsReducers } from 'features/controlLayers/store/paramsReducers'; import { paramsReducers } from 'features/controlLayers/store/paramsReducers';
import { regionsReducers } from 'features/controlLayers/store/regionsReducers'; import { regionsReducers } from 'features/controlLayers/store/regionsReducers';
import { settingsReducers } from 'features/controlLayers/store/settingsReducers'; import { settingsReducers } from 'features/controlLayers/store/settingsReducers';
import { stagingAreaReducers } from 'features/controlLayers/store/stagingAreaReducers';
import { toolReducers } from 'features/controlLayers/store/toolReducers'; import { toolReducers } from 'features/controlLayers/store/toolReducers';
import { initialAspectRatioState } from 'features/parameters/components/ImageSize/constants'; import { initialAspectRatioState } from 'features/parameters/components/ImageSize/constants';
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types'; import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
@ -119,6 +120,7 @@ const initialState: CanvasV2State = {
refinerNegativeAestheticScore: 2.5, refinerNegativeAestheticScore: 2.5,
refinerStart: 0.8, refinerStart: 0.8,
}, },
stagingArea: null,
}; };
export const canvasV2Slice = createSlice({ export const canvasV2Slice = createSlice({
@ -136,6 +138,7 @@ export const canvasV2Slice = createSlice({
...toolReducers, ...toolReducers,
...bboxReducers, ...bboxReducers,
...inpaintMaskReducers, ...inpaintMaskReducers,
...stagingAreaReducers,
widthChanged: (state, action: PayloadAction<{ width: number; updateAspectRatio?: boolean; clamp?: boolean }>) => { widthChanged: (state, action: PayloadAction<{ width: number; updateAspectRatio?: boolean; clamp?: boolean }>) => {
const { width, updateAspectRatio, clamp } = action.payload; const { width, updateAspectRatio, clamp } = action.payload;
state.document.width = clamp ? Math.max(roundDownToMultiple(width, 8), 64) : width; state.document.width = clamp ? Math.max(roundDownToMultiple(width, 8), 64) : width;
@ -327,6 +330,15 @@ export const {
imEraserLineAdded, imEraserLineAdded,
imLinePointAdded, imLinePointAdded,
imRectAdded, imRectAdded,
// Staging
stagingAreaInitialized,
stagingAreaImageAdded,
stagingAreaBatchIdAdded,
stagingAreaImageDiscarded,
stagingAreaImageAccepted,
stagingAreaReset,
stagingAreaNextImageSelected,
stagingAreaPreviousImageSelected,
} = canvasV2Slice.actions; } = canvasV2Slice.actions;
export const selectCanvasV2Slice = (state: RootState) => state.canvasV2; export const selectCanvasV2Slice = (state: RootState) => state.canvasV2;

View File

@ -231,17 +231,27 @@ export const layersReducers = {
prepare: (payload: RectShapeAddedArg) => ({ payload: { ...payload, rectId: uuidv4() } }), prepare: (payload: RectShapeAddedArg) => ({ payload: { ...payload, rectId: uuidv4() } }),
}, },
layerImageAdded: { layerImageAdded: {
reducer: (state, action: PayloadAction<ImageObjectAddedArg & { objectId: string }>) => { reducer: (
const { id, objectId, imageDTO } = action.payload; state,
action: PayloadAction<ImageObjectAddedArg & { objectId: string; pos?: { x: number; y: number } }>
) => {
const { id, objectId, imageDTO, pos } = action.payload;
const layer = selectLayer(state, id); const layer = selectLayer(state, id);
if (!layer) { if (!layer) {
return; return;
} }
layer.objects.push(imageDTOToImageObject(id, objectId, imageDTO)); const imageObject = imageDTOToImageObject(id, objectId, imageDTO);
if (pos) {
imageObject.x = pos.x;
imageObject.y = pos.y;
}
layer.objects.push(imageObject);
layer.bboxNeedsUpdate = true; layer.bboxNeedsUpdate = true;
state.layers.imageCache = null; state.layers.imageCache = null;
}, },
prepare: (payload: ImageObjectAddedArg) => ({ payload: { ...payload, objectId: uuidv4() } }), prepare: (payload: ImageObjectAddedArg & { pos?: { x: number; y: number } }) => ({
payload: { ...payload, objectId: uuidv4() },
}),
}, },
layerImageCacheChanged: (state, action: PayloadAction<{ imageDTO: ImageDTO | null }>) => { layerImageCacheChanged: (state, action: PayloadAction<{ imageDTO: ImageDTO | null }>) => {
const { imageDTO } = action.payload; const { imageDTO } = action.payload;

View File

@ -0,0 +1,73 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import type { CanvasV2State, Rect } from 'features/controlLayers/store/types';
import type { ImageDTO } from 'services/api/types';
export const stagingAreaReducers = {
stagingAreaInitialized: (state, action: PayloadAction<{ bbox: Rect; batchIds: string[] }>) => {
const { bbox, batchIds } = action.payload;
state.stagingArea = {
bbox,
batchIds,
selectedImageIndex: null,
images: [],
};
},
stagingAreaImageAdded: (state, action: PayloadAction<{ imageDTO: ImageDTO }>) => {
const { imageDTO } = action.payload;
if (!state.stagingArea) {
// Should not happen
return;
}
state.stagingArea.images.push(imageDTO);
if (!state.stagingArea.selectedImageIndex) {
state.stagingArea.selectedImageIndex = state.stagingArea.images.length - 1;
}
},
stagingAreaNextImageSelected: (state) => {
if (!state.stagingArea) {
// Should not happen
return;
}
if (state.stagingArea.selectedImageIndex === null) {
if (state.stagingArea.images.length > 0) {
state.stagingArea.selectedImageIndex = 0;
}
return;
}
state.stagingArea.selectedImageIndex = (state.stagingArea.selectedImageIndex + 1) % state.stagingArea.images.length;
},
stagingAreaPreviousImageSelected: (state) => {
if (!state.stagingArea) {
// Should not happen
return;
}
if (state.stagingArea.selectedImageIndex === null) {
if (state.stagingArea.images.length > 0) {
state.stagingArea.selectedImageIndex = 0;
}
return;
}
state.stagingArea.selectedImageIndex =
(state.stagingArea.selectedImageIndex - 1 + state.stagingArea.images.length) % state.stagingArea.images.length;
},
stagingAreaBatchIdAdded: (state, action: PayloadAction<{ batchId: string }>) => {
const { batchId } = action.payload;
if (!state.stagingArea) {
// Should not happen
return;
}
state.stagingArea.batchIds.push(batchId);
},
stagingAreaImageDiscarded: (state, action: PayloadAction<{ imageDTO: ImageDTO }>) => {
const { imageDTO } = action.payload;
if (!state.stagingArea) {
// Should not happen
return;
}
state.stagingArea.images = state.stagingArea.images.filter((image) => image.image_name !== imageDTO.image_name);
},
stagingAreaImageAccepted: (state, _: PayloadAction<{ imageDTO: ImageDTO }>) => state,
stagingAreaReset: (state) => {
state.stagingArea = null;
},
} satisfies SliceCaseReducers<CanvasV2State>;

View File

@ -785,6 +785,11 @@ export type Size = {
height: number; height: number;
}; };
export type Position = {
x: number;
y: number;
};
export type LoRA = { export type LoRA = {
id: string; id: string;
isEnabled: boolean; isEnabled: boolean;
@ -877,6 +882,12 @@ export type CanvasV2State = {
refinerNegativeAestheticScore: number; refinerNegativeAestheticScore: number;
refinerStart: number; refinerStart: number;
}; };
stagingArea: {
bbox: Rect;
images: ImageDTO[];
selectedImageIndex: number | null;
batchIds: string[];
} | null;
}; };
export type StageAttrs = { x: number; y: number; width: number; height: number; scale: number }; export type StageAttrs = { x: number; y: number; width: number; height: number; scale: number };
@ -891,7 +902,7 @@ export type EraserLineAddedArg = {
export type BrushLineAddedArg = EraserLineAddedArg & { color: RgbaColor }; export type BrushLineAddedArg = EraserLineAddedArg & { color: RgbaColor };
export type PointAddedToLineArg = { id: string; point: [number, number] }; export type PointAddedToLineArg = { id: string; point: [number, number] };
export type RectShapeAddedArg = { id: string; rect: IRect; color: RgbaColor }; export type RectShapeAddedArg = { id: string; rect: IRect; color: RgbaColor };
export type ImageObjectAddedArg = { id: string; imageDTO: ImageDTO }; export type ImageObjectAddedArg = { id: string; imageDTO: ImageDTO; pos?: Position };
//#region Type guards //#region Type guards
export const isLine = (obj: RenderableObject): obj is BrushLine | EraserLine => { export const isLine = (obj: RenderableObject): obj is BrushLine | EraserLine => {

View File

@ -1,6 +1,6 @@
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager'; import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager';
import type { Dimensions, IPAdapterEntity, RegionEntity } from 'features/controlLayers/store/types'; import type { IPAdapterEntity, Rect, RegionEntity } from 'features/controlLayers/store/types';
import { import {
PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX, PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX,
PROMPT_REGION_MASK_TO_TENSOR_PREFIX, PROMPT_REGION_MASK_TO_TENSOR_PREFIX,
@ -10,7 +10,6 @@ import {
} from 'features/nodes/util/graph/constants'; } from 'features/nodes/util/graph/constants';
import { addIPAdapterCollectorSafe, isValidIPAdapter } from 'features/nodes/util/graph/generation/addIPAdapters'; import { addIPAdapterCollectorSafe, isValidIPAdapter } from 'features/nodes/util/graph/generation/addIPAdapters';
import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { IRect } from 'konva/lib/types';
import type { BaseModelType, Invocation } from 'services/api/types'; import type { BaseModelType, Invocation } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
@ -31,8 +30,7 @@ export const addRegions = async (
manager: KonvaNodeManager, manager: KonvaNodeManager,
regions: RegionEntity[], regions: RegionEntity[],
g: Graph, g: Graph,
documentSize: Dimensions, bbox: Rect,
bbox: IRect,
base: BaseModelType, base: BaseModelType,
denoise: Invocation<'denoise_latents'>, denoise: Invocation<'denoise_latents'>,
posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>, posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,

View File

@ -2,6 +2,7 @@ import type { RootState } from 'app/store/store';
import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager'; import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { import {
CANVAS_OUTPUT,
CLIP_SKIP, CLIP_SKIP,
CONTROL_LAYERS_GRAPH, CONTROL_LAYERS_GRAPH,
DENOISE_LATENTS, DENOISE_LATENTS,
@ -119,7 +120,7 @@ export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager)
}) })
: null; : null;
let imageOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_paste_back'> = l2i; let canvasOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_paste_back'> = l2i;
g.addEdge(modelLoader, 'unet', denoise, 'unet'); g.addEdge(modelLoader, 'unet', denoise, 'unet');
g.addEdge(modelLoader, 'clip', clipSkip, 'clip'); g.addEdge(modelLoader, 'clip', clipSkip, 'clip');
@ -163,9 +164,9 @@ export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager)
g.addEdge(vaeSource, 'vae', l2i, 'vae'); g.addEdge(vaeSource, 'vae', l2i, 'vae');
if (generationMode === 'txt2img') { if (generationMode === 'txt2img') {
imageOutput = addTextToImage(g, l2i, originalSize, scaledSize); canvasOutput = addTextToImage(g, l2i, originalSize, scaledSize);
} else if (generationMode === 'img2img') { } else if (generationMode === 'img2img') {
imageOutput = await addImageToImage( canvasOutput = await addImageToImage(
g, g,
manager, manager,
l2i, l2i,
@ -178,7 +179,7 @@ export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager)
); );
} else if (generationMode === 'inpaint') { } else if (generationMode === 'inpaint') {
const { compositing } = state.canvasV2; const { compositing } = state.canvasV2;
imageOutput = await addInpaint( canvasOutput = await addInpaint(
g, g,
manager, manager,
l2i, l2i,
@ -194,7 +195,7 @@ export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager)
); );
} else if (generationMode === 'outpaint') { } else if (generationMode === 'outpaint') {
const { compositing } = state.canvasV2; const { compositing } = state.canvasV2;
imageOutput = await addOutpaint( canvasOutput = await addOutpaint(
g, g,
manager, manager,
l2i, l2i,
@ -216,7 +217,6 @@ export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager)
manager, manager,
state.canvasV2.regions.entities, state.canvasV2.regions.entities,
g, g,
state.canvasV2.document,
state.canvasV2.bbox, state.canvasV2.bbox,
modelConfig.base, modelConfig.base,
denoise, denoise,
@ -232,18 +232,21 @@ export const buildSD1Graph = async (state: RootState, manager: KonvaNodeManager)
// } // }
if (state.system.shouldUseNSFWChecker) { if (state.system.shouldUseNSFWChecker) {
imageOutput = addNSFWChecker(g, imageOutput); canvasOutput = addNSFWChecker(g, canvasOutput);
} }
if (state.system.shouldUseWatermarker) { if (state.system.shouldUseWatermarker) {
imageOutput = addWatermarker(g, imageOutput); canvasOutput = addWatermarker(g, canvasOutput);
} }
// This is the terminal node and must always save to gallery. // This is the terminal node and must always save to gallery.
imageOutput.is_intermediate = false; g.updateNode(canvasOutput, {
imageOutput.use_cache = false; id: CANVAS_OUTPUT,
imageOutput.board = getBoardField(state); is_intermediate: false,
use_cache: false,
board: getBoardField(state),
});
g.setMetadataReceivingNode(imageOutput); g.setMetadataReceivingNode(canvasOutput);
return g.getGraph(); return g.getGraph();
}; };

View File

@ -2,6 +2,7 @@ import type { RootState } from 'app/store/store';
import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager'; import type { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { import {
CANVAS_OUTPUT,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NEGATIVE_CONDITIONING_COLLECT, NEGATIVE_CONDITIONING_COLLECT,
@ -117,7 +118,7 @@ export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager
}) })
: null; : null;
let imageOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_paste_back'> = l2i; let canvasOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_paste_back'> = l2i;
g.addEdge(modelLoader, 'unet', denoise, 'unet'); g.addEdge(modelLoader, 'unet', denoise, 'unet');
g.addEdge(modelLoader, 'clip', posCond, 'clip'); g.addEdge(modelLoader, 'clip', posCond, 'clip');
@ -166,9 +167,9 @@ export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager
} }
if (generationMode === 'txt2img') { if (generationMode === 'txt2img') {
imageOutput = addTextToImage(g, l2i, originalSize, scaledSize); canvasOutput = addTextToImage(g, l2i, originalSize, scaledSize);
} else if (generationMode === 'img2img') { } else if (generationMode === 'img2img') {
imageOutput = await addImageToImage( canvasOutput = await addImageToImage(
g, g,
manager, manager,
l2i, l2i,
@ -181,7 +182,7 @@ export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager
); );
} else if (generationMode === 'inpaint') { } else if (generationMode === 'inpaint') {
const { compositing } = state.canvasV2; const { compositing } = state.canvasV2;
imageOutput = await addInpaint( canvasOutput = await addInpaint(
g, g,
manager, manager,
l2i, l2i,
@ -197,7 +198,7 @@ export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager
); );
} else if (generationMode === 'outpaint') { } else if (generationMode === 'outpaint') {
const { compositing } = state.canvasV2; const { compositing } = state.canvasV2;
imageOutput = await addOutpaint( canvasOutput = await addOutpaint(
g, g,
manager, manager,
l2i, l2i,
@ -219,7 +220,6 @@ export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager
manager, manager,
state.canvasV2.regions.entities, state.canvasV2.regions.entities,
g, g,
state.canvasV2.document,
state.canvasV2.bbox, state.canvasV2.bbox,
modelConfig.base, modelConfig.base,
denoise, denoise,
@ -230,18 +230,21 @@ export const buildSDXLGraph = async (state: RootState, manager: KonvaNodeManager
); );
if (state.system.shouldUseNSFWChecker) { if (state.system.shouldUseNSFWChecker) {
imageOutput = addNSFWChecker(g, imageOutput); canvasOutput = addNSFWChecker(g, canvasOutput);
} }
if (state.system.shouldUseWatermarker) { if (state.system.shouldUseWatermarker) {
imageOutput = addWatermarker(g, imageOutput); canvasOutput = addWatermarker(g, canvasOutput);
} }
// This is the terminal node and must always save to gallery. // This is the terminal node and must always save to gallery.
imageOutput.is_intermediate = false; g.updateNode(canvasOutput, {
imageOutput.use_cache = false; id: CANVAS_OUTPUT,
imageOutput.board = getBoardField(state); is_intermediate: false,
use_cache: false,
board: getBoardField(state),
});
g.setMetadataReceivingNode(imageOutput); g.setMetadataReceivingNode(canvasOutput);
return g.getGraph(); return g.getGraph();
}; };