feat(ui): revised canvas progress & staging image handling

This commit is contained in:
psychedelicious 2024-07-01 19:28:42 +10:00
parent febea88b58
commit 02c4b28de5
18 changed files with 204 additions and 237 deletions

View File

@ -35,7 +35,7 @@ import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMi
import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted'; import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall'; import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad'; import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged'; import { addSocketQueueEventsListeners } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueEvents';
import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested'; import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested';
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested'; import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store'; import type { AppDispatch, RootState } from 'app/store/store';
@ -99,7 +99,7 @@ addSocketConnectedEventListener(startAppListening);
addSocketDisconnectedEventListener(startAppListening); addSocketDisconnectedEventListener(startAppListening);
addModelLoadEventListener(startAppListening); addModelLoadEventListener(startAppListening);
addModelInstallEventListener(startAppListening); addModelInstallEventListener(startAppListening);
addSocketQueueItemStatusChangedEventListener(startAppListening); addSocketQueueEventsListeners(startAppListening);
addBulkDownloadListeners(startAppListening); addBulkDownloadListeners(startAppListening);
// Boards // Boards

View File

@ -1,10 +1,11 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { import {
layerAdded, layerAdded,
layerImageAdded, layerImageAdded,
stagingAreaCanceledStaging,
stagingAreaImageAccepted, stagingAreaImageAccepted,
stagingAreaReset,
} from 'features/controlLayers/store/canvasV2Slice'; } from 'features/controlLayers/store/canvasV2Slice';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
@ -13,25 +14,15 @@ import { assert } from 'tsafe';
export const addStagingListeners = (startAppListening: AppStartListening) => { export const addStagingListeners = (startAppListening: AppStartListening) => {
startAppListening({ startAppListening({
actionCreator: stagingAreaReset, matcher: isAnyOf(stagingAreaCanceledStaging, stagingAreaImageAccepted),
effect: async (_, { dispatch, getState }) => { effect: async (_, { dispatch }) => {
const log = logger('canvas'); const log = logger('canvas');
const stagingArea = getState().canvasV2.stagingArea;
if (!stagingArea) {
// Should not happen
return;
}
if (stagingArea.batchIds.length === 0) {
return;
}
try { try {
const req = dispatch( const req = dispatch(
queueApi.endpoints.cancelByBatchIds.initiate( queueApi.endpoints.cancelByBatchOrigin.initiate(
{ batch_ids: stagingArea.batchIds }, { origin: 'canvas' },
{ fixedCacheKey: 'cancelByBatchIds' } { fixedCacheKey: 'cancelByBatchOrigin' }
) )
); );
const { canceled } = await req.unwrap(); const { canceled } = await req.unwrap();
@ -59,7 +50,7 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
actionCreator: stagingAreaImageAccepted, actionCreator: stagingAreaImageAccepted,
effect: async (action, api) => { effect: async (action, api) => {
const { imageDTO } = action.payload; const { imageDTO } = action.payload;
const { layers, stagingArea, selectedEntityIdentifier } = api.getState().canvasV2; const { layers, selectedEntityIdentifier, bbox } = api.getState().canvasV2;
let layer = layers.entities.find((layer) => layer.id === selectedEntityIdentifier?.id); let layer = layers.entities.find((layer) => layer.id === selectedEntityIdentifier?.id);
if (!layer) { if (!layer) {
@ -73,13 +64,11 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
} }
assert(layer, 'No layer found to stage image'); assert(layer, 'No layer found to stage image');
assert(stagingArea, 'Staging should be defined');
const { x, y } = stagingArea.bbox; const { x, y } = bbox;
const { id } = layer; const { id } = layer;
api.dispatch(layerImageAdded({ id, imageDTO, pos: { x, y } })); api.dispatch(layerImageAdded({ id, imageDTO, pos: { x, y } }));
api.dispatch(stagingAreaReset());
}, },
}); });
}; };

View File

@ -1,12 +1,7 @@
import { enqueueRequested } from 'app/store/actions'; import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { getNodeManager } from 'features/controlLayers/konva/nodeManager'; import { getNodeManager } from 'features/controlLayers/konva/nodeManager';
import { import { stagingAreaCanceledStaging, stagingAreaStartedStaging } from 'features/controlLayers/store/canvasV2Slice';
stagingAreaBatchIdAdded,
stagingAreaInitialized,
stagingAreaReset,
} from 'features/controlLayers/store/canvasV2Slice';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph'; import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph'; import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
@ -19,20 +14,13 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
enqueueRequested.match(action) && action.payload.tabName === 'generation', enqueueRequested.match(action) && action.payload.tabName === 'generation',
effect: async (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch }) => {
const state = getState(); const state = getState();
const { shouldShowProgressInViewer } = state.ui;
const model = state.canvasV2.params.model; const model = state.canvasV2.params.model;
const { prepend } = action.payload; const { prepend } = action.payload;
let didInitializeStagingArea = false; let didStartStaging = false;
if (!state.canvasV2.stagingArea.isStaging) {
if (state.canvasV2.stagingArea === null) { dispatch(stagingAreaStartedStaging());
dispatch( didStartStaging = true;
stagingAreaInitialized({
batchIds: [],
bbox: state.canvasV2.bbox,
})
);
didInitializeStagingArea = true;
} }
try { try {
@ -57,23 +45,11 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
fixedCacheKey: 'enqueueBatch', fixedCacheKey: 'enqueueBatch',
}) })
); );
const enqueueResult = await req.unwrap();
req.reset(); req.reset();
await req.unwrap();
if (shouldShowProgressInViewer) {
dispatch(isImageViewerOpenChanged(true));
}
// TODO(psyche): update the backend schema, this is always provided
const batchId = enqueueResult.batch.batch_id;
assert(batchId, 'No batch ID found in enqueue result');
dispatch(stagingAreaBatchIdAdded({ batchId }));
} catch { } catch {
if (didInitializeStagingArea) { if (didStartStaging && getState().canvasV2.stagingArea.isStaging) {
// We initialized the staging area in this listener, and there was a problem at some point. This means dispatch(stagingAreaCanceledStaging());
// there only possible canvas batch id is the one we just added, so we can reset the staging area without
// losing any data.
dispatch(stagingAreaReset());
} }
} }
}, },

View File

@ -30,6 +30,7 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
graph, graph,
workflow: builtWorkflow, workflow: builtWorkflow,
runs: state.canvasV2.params.iterations, runs: state.canvasV2.params.iterations,
origin: 'workflows',
}, },
prepend: action.payload.prepend, prepend: action.payload.prepend,
}; };

View File

@ -12,9 +12,11 @@ const log = logger('socketio');
export const addGeneratorProgressEventListener = (startAppListening: AppStartListening) => { export const addGeneratorProgressEventListener = (startAppListening: AppStartListening) => {
startAppListening({ startAppListening({
actionCreator: socketGeneratorProgress, actionCreator: socketGeneratorProgress,
effect: (action, { getState }) => { effect: (action) => {
log.trace(parseify(action.payload), `Generator progress`); log.trace(parseify(action.payload), `Generator progress`);
const { invocation_source_id, step, total_steps, progress_image, batch_id } = action.payload.data; const { invocation_source_id, step, total_steps, progress_image, origin } = action.payload.data;
if (origin === 'workflows') {
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) { if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS; nes.status = zNodeStatus.enum.IN_PROGRESS;
@ -22,9 +24,9 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
nes.progressImage = progress_image ?? null; nes.progressImage = progress_image ?? null;
upsertExecutionState(nes.nodeId, nes); upsertExecutionState(nes.nodeId, nes);
} }
}
const isCanvasQueueItem = getState().canvasV2.stagingArea?.batchIds.includes(batch_id); if (origin === 'canvas') {
if (isCanvasQueueItem) {
$lastProgressEvent.set(action.payload.data); $lastProgressEvent.set(action.payload.data);
} }
}, },

View File

@ -3,13 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import { stagingAreaImageAdded } from 'features/controlLayers/store/canvasV2Slice'; import { stagingAreaImageAdded } from 'features/controlLayers/store/canvasV2Slice';
import { import { boardIdSelected, galleryViewChanged, imageSelected, offsetChanged } from 'features/gallery/store/gallerySlice';
boardIdSelected,
galleryViewChanged,
imageSelected,
isImageViewerOpenChanged,
offsetChanged,
} from 'features/gallery/store/gallerySlice';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation'; import { zNodeStatus } from 'features/nodes/types/invocation';
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants'; import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
@ -17,7 +11,6 @@ import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
import { getCategories, getListImagesUrl } from 'services/api/util'; import { getCategories, getListImagesUrl } from 'services/api/util';
import { socketInvocationComplete } from 'services/events/actions'; import { socketInvocationComplete } from 'services/events/actions';
import { assert } from 'tsafe';
// These nodes output an image, but do not actually *save* an image, so we don't want to handle the gallery logic on them // These nodes output an image, but do not actually *save* an image, so we don't want to handle the gallery logic on them
const nodeTypeDenylist = ['load_image', 'image']; const nodeTypeDenylist = ['load_image', 'image'];
@ -35,7 +28,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
// This complete event has an associated image output // This complete event has an associated image output
if (data.result.type === 'image_output' && !nodeTypeDenylist.includes(data.invocation.type)) { if (data.result.type === 'image_output' && !nodeTypeDenylist.includes(data.invocation.type)) {
const { image_name } = data.result.image; const { image_name } = data.result.image;
const { canvasV2, gallery } = getState(); const { gallery, canvasV2 } = getState();
// This populates the `getImageDTO` cache // This populates the `getImageDTO` cache
const imageDTORequest = dispatch( const imageDTORequest = dispatch(
@ -47,12 +40,22 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
const imageDTO = await imageDTORequest.unwrap(); const imageDTO = await imageDTORequest.unwrap();
imageDTORequest.unsubscribe(); imageDTORequest.unsubscribe();
// Add canvas images to the staging area // handle tab-specific logic
if (canvasV2.stagingArea?.batchIds.includes(data.batch_id) && data.invocation_source_id === CANVAS_OUTPUT) { if (data.origin === 'canvas') {
const stagingArea = getState().canvasV2.stagingArea; if (data.invocation_source_id === CANVAS_OUTPUT && canvasV2.stagingArea.isStaging) {
assert(stagingArea, 'Staging should be defined');
dispatch(stagingAreaImageAdded({ imageDTO })); dispatch(stagingAreaImageAdded({ imageDTO }));
} }
} else if (data.origin === 'workflows') {
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.COMPLETED;
if (nes.progress !== null) {
nes.progress = 1;
}
nes.outputs.push(result);
upsertExecutionState(nes.nodeId, nes);
}
}
if (!imageDTO.is_intermediate) { if (!imageDTO.is_intermediate) {
// update the total images for the board // update the total images for the board
@ -106,20 +109,9 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
} }
dispatch(imageSelected(imageDTO)); dispatch(imageSelected(imageDTO));
dispatch(isImageViewerOpenChanged(true));
} }
} }
} }
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.COMPLETED;
if (nes.progress !== null) {
nes.progress = 1;
}
nes.outputs.push(result);
upsertExecutionState(nes.nodeId, nes);
}
}, },
}); });
}; };

View File

@ -12,7 +12,15 @@ import { socketQueueItemStatusChanged } from 'services/events/actions';
const log = logger('socketio'); const log = logger('socketio');
export const addSocketQueueItemStatusChangedEventListener = (startAppListening: AppStartListening) => { export const addSocketQueueEventsListeners = (startAppListening: AppStartListening) => {
// When the queue is cleared or canvas batch is canceled, we should clear the last canvas progress event
startAppListening({
matcher: queueApi.endpoints.clearQueue.matchFulfilled,
effect: () => {
$lastProgressEvent.set(null);
},
});
startAppListening({ startAppListening({
actionCreator: socketQueueItemStatusChanged, actionCreator: socketQueueItemStatusChanged,
effect: async (action, { dispatch, getState }) => { effect: async (action, { dispatch, getState }) => {
@ -29,13 +37,11 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
error_type, error_type,
error_message, error_message,
error_traceback, error_traceback,
batch_id, origin,
} = action.payload.data; } = action.payload.data;
log.debug(action.payload, `Queue item ${item_id} status updated: ${status}`); log.debug(action.payload, `Queue item ${item_id} status updated: ${status}`);
const isCanvasQueueItem = getState().canvasV2.stagingArea?.batchIds.includes(batch_id);
// Update this specific queue item in the list of queue items (this is the queue item DTO, without the session) // Update this specific queue item in the list of queue items (this is the queue item DTO, without the session)
dispatch( dispatch(
queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => { queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => {
@ -96,7 +102,7 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
} else if (status === 'failed' && error_type) { } else if (status === 'failed' && error_type) {
const isLocal = getState().config.isLocal ?? true; const isLocal = getState().config.isLocal ?? true;
const sessionId = session_id; const sessionId = session_id;
if (isCanvasQueueItem) { if (origin === 'canvas') {
$lastProgressEvent.set(null); $lastProgressEvent.set(null);
} }
@ -115,9 +121,7 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
/> />
), ),
}); });
} else if (status === 'completed' && isCanvasQueueItem) { } else if (status === 'canceled' && origin === 'canvas') {
$lastProgressEvent.set(null);
} else if (status === 'canceled' && isCanvasQueueItem) {
$lastProgressEvent.set(null); $lastProgressEvent.set(null);
} }
}, },

View File

@ -3,13 +3,12 @@ import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { import {
$shouldShowStagedImage, $shouldShowStagedImage,
stagingAreaCanceledStaging,
stagingAreaImageAccepted, stagingAreaImageAccepted,
stagingAreaImageDiscarded, stagingAreaImageDiscarded,
stagingAreaNextImageSelected, stagingAreaNextImageSelected,
stagingAreaPreviousImageSelected, stagingAreaPreviousImageSelected,
stagingAreaReset,
} from 'features/controlLayers/store/canvasV2Slice'; } from 'features/controlLayers/store/canvasV2Slice';
import type { CanvasV2State } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -25,29 +24,23 @@ import {
} from 'react-icons/pi'; } from 'react-icons/pi';
export const StagingAreaToolbar = memo(() => { export const StagingAreaToolbar = memo(() => {
const stagingArea = useAppSelector((s) => s.canvasV2.stagingArea); const isStaging = useAppSelector((s) => s.canvasV2.stagingArea.isStaging);
if (!stagingArea || stagingArea.images.length === 0) { if (!isStaging) {
return null; return null;
} }
return <StagingAreaToolbarContent stagingArea={stagingArea} />; return <StagingAreaToolbarContent />;
}); });
StagingAreaToolbar.displayName = 'StagingAreaToolbar'; StagingAreaToolbar.displayName = 'StagingAreaToolbar';
type Props = { export const StagingAreaToolbarContent = memo(() => {
stagingArea: NonNullable<CanvasV2State['stagingArea']>;
};
export const StagingAreaToolbarContent = memo(({ stagingArea }: Props) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const stagingArea = useAppSelector((s) => s.canvasV2.stagingArea);
const shouldShowStagedImage = useStore($shouldShowStagedImage); const shouldShowStagedImage = useStore($shouldShowStagedImage);
const images = useMemo(() => stagingArea.images, [stagingArea]); const images = useMemo(() => stagingArea.images, [stagingArea]);
const imageDTO = useMemo(() => { const selectedImageDTO = useMemo(() => {
if (stagingArea.selectedImageIndex === null) {
return null;
}
return images[stagingArea.selectedImageIndex] ?? null; return images[stagingArea.selectedImageIndex] ?? null;
}, [images, stagingArea.selectedImageIndex]); }, [images, stagingArea.selectedImageIndex]);
@ -64,29 +57,26 @@ export const StagingAreaToolbarContent = memo(({ stagingArea }: Props) => {
}, [dispatch]); }, [dispatch]);
const onAccept = useCallback(() => { const onAccept = useCallback(() => {
if (!imageDTO || !stagingArea) { if (!selectedImageDTO) {
return; return;
} }
dispatch(stagingAreaImageAccepted({ imageDTO })); dispatch(stagingAreaImageAccepted({ imageDTO: selectedImageDTO }));
}, [dispatch, imageDTO, stagingArea]); }, [dispatch, selectedImageDTO]);
const onDiscardOne = useCallback(() => { const onDiscardOne = useCallback(() => {
if (!imageDTO || !stagingArea) { if (!selectedImageDTO) {
return; return;
} }
if (images.length === 1) { if (images.length === 1) {
dispatch(stagingAreaReset()); dispatch(stagingAreaCanceledStaging());
} else { } else {
dispatch(stagingAreaImageDiscarded({ imageDTO })); dispatch(stagingAreaImageDiscarded({ imageDTO: selectedImageDTO }));
} }
}, [dispatch, imageDTO, images.length, stagingArea]); }, [dispatch, selectedImageDTO, images.length]);
const onDiscardAll = useCallback(() => { const onDiscardAll = useCallback(() => {
if (!stagingArea) { dispatch(stagingAreaCanceledStaging());
return; }, [dispatch]);
}
dispatch(stagingAreaReset());
}, [dispatch, stagingArea]);
const onToggleShouldShowStagedImage = useCallback(() => { const onToggleShouldShowStagedImage = useCallback(() => {
$shouldShowStagedImage.set(!shouldShowStagedImage); $shouldShowStagedImage.set(!shouldShowStagedImage);
@ -117,6 +107,14 @@ export const StagingAreaToolbarContent = memo(({ stagingArea }: Props) => {
preventDefault: true, preventDefault: true,
}); });
const counterText = useMemo(() => {
if (images.length > 0) {
return `${(stagingArea.selectedImageIndex ?? 0) + 1} of ${images.length}`;
} else {
return `0 of 0`;
}
}, [images.length, stagingArea.selectedImageIndex]);
return ( return (
<> <>
<ButtonGroup borderRadius="base" shadow="dark-lg"> <ButtonGroup borderRadius="base" shadow="dark-lg">
@ -128,11 +126,9 @@ export const StagingAreaToolbarContent = memo(({ stagingArea }: Props) => {
colorScheme="invokeBlue" colorScheme="invokeBlue"
isDisabled={images.length <= 1 || !shouldShowStagedImage} isDisabled={images.length <= 1 || !shouldShowStagedImage}
/> />
<Button <Button colorScheme="base" pointerEvents="none" minW={28}>
colorScheme="base" {counterText}
pointerEvents="none" </Button>
minW={20}
>{`${(stagingArea.selectedImageIndex ?? 0) + 1}/${images.length}`}</Button>
<IconButton <IconButton
tooltip={`${t('unifiedCanvas.next')} (Right)`} tooltip={`${t('unifiedCanvas.next')} (Right)`}
aria-label={`${t('unifiedCanvas.next')} (Right)`} aria-label={`${t('unifiedCanvas.next')} (Right)`}
@ -149,6 +145,7 @@ export const StagingAreaToolbarContent = memo(({ stagingArea }: Props) => {
icon={<PiCheckBold />} icon={<PiCheckBold />}
onClick={onAccept} onClick={onAccept}
colorScheme="invokeBlue" colorScheme="invokeBlue"
isDisabled={!selectedImageDTO}
/> />
<IconButton <IconButton
tooltip={shouldShowStagedImage ? t('unifiedCanvas.showResultsOn') : t('unifiedCanvas.showResultsOff')} tooltip={shouldShowStagedImage ? t('unifiedCanvas.showResultsOn') : t('unifiedCanvas.showResultsOff')}
@ -161,10 +158,10 @@ export const StagingAreaToolbarContent = memo(({ stagingArea }: Props) => {
<IconButton <IconButton
tooltip={`${t('unifiedCanvas.saveToGallery')} (Shift+S)`} tooltip={`${t('unifiedCanvas.saveToGallery')} (Shift+S)`}
aria-label={t('unifiedCanvas.saveToGallery')} aria-label={t('unifiedCanvas.saveToGallery')}
isDisabled={!imageDTO || !imageDTO.is_intermediate}
icon={<PiFloppyDiskBold />} icon={<PiFloppyDiskBold />}
onClick={onSaveStagingImage} onClick={onSaveStagingImage}
colorScheme="invokeBlue" colorScheme="invokeBlue"
isDisabled={!selectedImageDTO || !selectedImageDTO.is_intermediate}
/> />
<IconButton <IconButton
tooltip={`${t('unifiedCanvas.discardCurrent')}`} tooltip={`${t('unifiedCanvas.discardCurrent')}`}
@ -173,7 +170,7 @@ export const StagingAreaToolbarContent = memo(({ stagingArea }: Props) => {
onClick={onDiscardOne} onClick={onDiscardOne}
colorScheme="invokeBlue" colorScheme="invokeBlue"
fontSize={16} fontSize={16}
isDisabled={images.length <= 1} isDisabled={!selectedImageDTO}
/> />
<IconButton <IconButton
tooltip={`${t('unifiedCanvas.discardAll')} (Esc)`} tooltip={`${t('unifiedCanvas.discardAll')} (Esc)`}
@ -182,7 +179,6 @@ export const StagingAreaToolbarContent = memo(({ stagingArea }: Props) => {
onClick={onDiscardAll} onClick={onDiscardAll}
colorScheme="error" colorScheme="error"
fontSize={16} fontSize={16}
isDisabled={images.length === 0}
/> />
</ButtonGroup> </ButtonGroup>
</> </>

View File

@ -43,7 +43,7 @@ export const ToolChooser: React.FC = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier); const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
const isStaging = useAppSelector((s) => s.canvasV2.stagingArea !== null); const isStaging = useAppSelector((s) => s.canvasV2.stagingArea.isStaging);
const isDrawingToolDisabled = useMemo( const isDrawingToolDisabled = useMemo(
() => !getIsDrawingToolEnabled(selectedEntityIdentifier), () => !getIsDrawingToolEnabled(selectedEntityIdentifier),
[selectedEntityIdentifier] [selectedEntityIdentifier]

View File

@ -76,6 +76,7 @@ export type StateApi = {
getInpaintMaskState: () => CanvasV2State['inpaintMask']; getInpaintMaskState: () => CanvasV2State['inpaintMask'];
getStagingAreaState: () => CanvasV2State['stagingArea']; getStagingAreaState: () => CanvasV2State['stagingArea'];
getLastProgressEvent: () => InvocationDenoiseProgressEvent | null; getLastProgressEvent: () => InvocationDenoiseProgressEvent | null;
resetLastProgressEvent: () => void;
onInpaintMaskImageCached: (imageDTO: ImageDTO) => void; onInpaintMaskImageCached: (imageDTO: ImageDTO) => void;
onRegionMaskImageCached: (id: string, imageDTO: ImageDTO) => void; onRegionMaskImageCached: (id: string, imageDTO: ImageDTO) => void;
onLayerImageCached: (imageDTO: ImageDTO) => void; onLayerImageCached: (imageDTO: ImageDTO) => void;
@ -280,8 +281,10 @@ export class KonvaNodeManager {
renderStagingArea() { renderStagingArea() {
this.preview.stagingArea.render( this.preview.stagingArea.render(
this.stateApi.getStagingAreaState(), this.stateApi.getStagingAreaState(),
this.stateApi.getBbox(),
this.stateApi.getShouldShowStagedImage(), this.stateApi.getShouldShowStagedImage(),
this.stateApi.getLastProgressEvent() this.stateApi.getLastProgressEvent(),
this.stateApi.resetLastProgressEvent
); );
} }

View File

@ -18,18 +18,18 @@ export class CanvasPreview {
documentSizeOverlay: CanvasDocumentSizeOverlay, documentSizeOverlay: CanvasDocumentSizeOverlay,
stagingArea: CanvasStagingArea stagingArea: CanvasStagingArea
) { ) {
this.layer = new Konva.Layer({ listening: true }); this.layer = new Konva.Layer({ listening: true, imageSmoothingEnabled: false });
this.bbox = bbox;
this.layer.add(this.bbox.group);
this.tool = tool;
this.layer.add(this.tool.group);
this.documentSizeOverlay = documentSizeOverlay; this.documentSizeOverlay = documentSizeOverlay;
this.layer.add(this.documentSizeOverlay.group); this.layer.add(this.documentSizeOverlay.group);
this.stagingArea = stagingArea; this.stagingArea = stagingArea;
this.layer.add(this.stagingArea.group); this.layer.add(this.stagingArea.group);
this.bbox = bbox;
this.layer.add(this.bbox.group);
this.tool = tool;
this.layer.add(this.tool.group);
} }
} }

View File

@ -307,6 +307,9 @@ export const initializeRenderer = (
getStagingAreaState, getStagingAreaState,
getShouldShowStagedImage: $shouldShowStagedImage.get, getShouldShowStagedImage: $shouldShowStagedImage.get,
getLastProgressEvent: $lastProgressEvent.get, getLastProgressEvent: $lastProgressEvent.get,
resetLastProgressEvent: () => {
$lastProgressEvent.set(null);
},
// Read-write state // Read-write state
setTool, setTool,

View File

@ -2,7 +2,6 @@ import { KonvaImage, KonvaProgressImage } from 'features/controlLayers/konva/ren
import type { CanvasV2State } from 'features/controlLayers/store/types'; import type { CanvasV2State } from 'features/controlLayers/store/types';
import Konva from 'konva'; import Konva from 'konva';
import type { InvocationDenoiseProgressEvent } from 'services/events/types'; import type { InvocationDenoiseProgressEvent } from 'services/events/types';
import { assert } from 'tsafe';
export class CanvasStagingArea { export class CanvasStagingArea {
group: Konva.Group; group: Konva.Group;
@ -17,13 +16,54 @@ export class CanvasStagingArea {
async render( async render(
stagingArea: CanvasV2State['stagingArea'], stagingArea: CanvasV2State['stagingArea'],
bbox: CanvasV2State['bbox'],
shouldShowStagedImage: boolean, shouldShowStagedImage: boolean,
lastProgressEvent: InvocationDenoiseProgressEvent | null lastProgressEvent: InvocationDenoiseProgressEvent | null,
resetLastProgressEvent: () => void
) { ) {
if (stagingArea && lastProgressEvent) { const imageDTO = stagingArea.images[stagingArea.selectedImageIndex];
if (imageDTO) {
if (this.image) {
if (!this.image.isLoading && !this.image.isError && this.image.imageName !== imageDTO.image_name) {
await this.image.updateImageSource(imageDTO.image_name);
}
this.image.konvaImageGroup.x(bbox.x);
this.image.konvaImageGroup.y(bbox.y);
this.image.konvaImageGroup.visible(shouldShowStagedImage);
this.progressImage?.konvaImageGroup.visible(false);
} else {
const { image_name, width, height } = imageDTO;
this.image = new KonvaImage({
imageObject: {
id: 'staging-area-image',
type: 'image',
x: bbox.x,
y: bbox.y,
width,
height,
filters: [],
image: {
name: image_name,
width,
height,
},
},
onLoad: () => {
resetLastProgressEvent();
},
});
this.group.add(this.image.konvaImageGroup);
await this.image.updateImageSource(imageDTO.image_name);
this.image.konvaImageGroup.visible(shouldShowStagedImage);
this.progressImage?.konvaImageGroup.visible(false);
}
}
if (stagingArea.isStaging && lastProgressEvent) {
const { invocation, step, progress_image } = lastProgressEvent; const { invocation, step, progress_image } = lastProgressEvent;
const { dataURL } = progress_image; const { dataURL } = progress_image;
const { x, y, width, height } = stagingArea.bbox; const { x, y, width, height } = bbox;
const progressImageId = `${invocation.id}_${step}`; const progressImageId = `${invocation.id}_${step}`;
if (this.progressImage) { if (this.progressImage) {
if ( if (
@ -42,47 +82,16 @@ export class CanvasStagingArea {
this.image?.konvaImageGroup.visible(false); this.image?.konvaImageGroup.visible(false);
this.progressImage.konvaImageGroup.visible(true); this.progressImage.konvaImageGroup.visible(true);
} }
} else if (stagingArea && stagingArea.selectedImageIndex !== null) {
const imageDTO = stagingArea.images[stagingArea.selectedImageIndex];
assert(imageDTO, 'Image must exist');
if (this.image) {
if (!this.image.isLoading && !this.image.isError && this.image.imageName !== imageDTO.image_name) {
await this.image.updateImageSource(imageDTO.image_name);
} }
this.image.konvaImageGroup.x(stagingArea.bbox.x);
this.image.konvaImageGroup.y(stagingArea.bbox.y); if (!imageDTO && !lastProgressEvent) {
this.image.konvaImageGroup.visible(shouldShowStagedImage);
this.progressImage?.konvaImageGroup.visible(false);
} else {
const { image_name, width, height } = imageDTO;
this.image = new KonvaImage({
imageObject: {
id: 'staging-area-image',
type: 'image',
x: stagingArea.bbox.x,
y: stagingArea.bbox.y,
width,
height,
filters: [],
image: {
name: image_name,
width,
height,
},
},
});
this.group.add(this.image.konvaImageGroup);
await this.image.updateImageSource(imageDTO.image_name);
this.image.konvaImageGroup.visible(shouldShowStagedImage);
this.progressImage?.konvaImageGroup.visible(false);
}
} else {
if (this.image) { if (this.image) {
this.image.konvaImageGroup.visible(false); this.image.konvaImageGroup.visible(false);
} }
if (this.progressImage) { if (this.progressImage) {
this.progressImage.konvaImageGroup.visible(false); this.progressImage.konvaImageGroup.visible(false);
} }
resetLastProgressEvent();
} }
} }
} }

View File

@ -121,7 +121,11 @@ const initialState: CanvasV2State = {
refinerNegativeAestheticScore: 2.5, refinerNegativeAestheticScore: 2.5,
refinerStart: 0.8, refinerStart: 0.8,
}, },
stagingArea: null, stagingArea: {
isStaging: false,
images: [],
selectedImageIndex: 0,
},
}; };
export const canvasV2Slice = createSlice({ export const canvasV2Slice = createSlice({
@ -332,12 +336,11 @@ export const {
imLinePointAdded, imLinePointAdded,
imRectAdded, imRectAdded,
// Staging // Staging
stagingAreaInitialized, stagingAreaStartedStaging,
stagingAreaImageAdded, stagingAreaImageAdded,
stagingAreaBatchIdAdded,
stagingAreaImageDiscarded, stagingAreaImageDiscarded,
stagingAreaImageAccepted, stagingAreaImageAccepted,
stagingAreaReset, stagingAreaCanceledStaging,
stagingAreaNextImageSelected, stagingAreaNextImageSelected,
stagingAreaPreviousImageSelected, stagingAreaPreviousImageSelected,
} = canvasV2Slice.actions; } = canvasV2Slice.actions;

View File

@ -1,16 +1,11 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit'; import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import type { CanvasV2State, Rect } from 'features/controlLayers/store/types'; import type { CanvasV2State } from 'features/controlLayers/store/types';
import type { ImageDTO } from 'services/api/types'; import type { ImageDTO } from 'services/api/types';
export const stagingAreaReducers = { export const stagingAreaReducers = {
stagingAreaInitialized: (state, action: PayloadAction<{ bbox: Rect; batchIds: string[] }>) => { stagingAreaStartedStaging: (state) => {
const { bbox, batchIds } = action.payload; state.stagingArea.isStaging = true;
state.stagingArea = { state.stagingArea.selectedImageIndex = 0;
bbox,
batchIds,
selectedImageIndex: null,
images: [],
};
// When we start staging, the user should not be interacting with the stage except to move it around. Set the tool // When we start staging, the user should not be interacting with the stage except to move it around. Set the tool
// to view. // to view.
state.tool.selectedBuffer = state.tool.selected; state.tool.selectedBuffer = state.tool.selected;
@ -18,67 +13,41 @@ export const stagingAreaReducers = {
}, },
stagingAreaImageAdded: (state, action: PayloadAction<{ imageDTO: ImageDTO }>) => { stagingAreaImageAdded: (state, action: PayloadAction<{ imageDTO: ImageDTO }>) => {
const { imageDTO } = action.payload; const { imageDTO } = action.payload;
if (!state.stagingArea) {
// Should not happen
return;
}
state.stagingArea.images.push(imageDTO); state.stagingArea.images.push(imageDTO);
if (!state.stagingArea.selectedImageIndex) {
state.stagingArea.selectedImageIndex = state.stagingArea.images.length - 1; state.stagingArea.selectedImageIndex = state.stagingArea.images.length - 1;
}
}, },
stagingAreaNextImageSelected: (state) => { stagingAreaNextImageSelected: (state) => {
if (!state.stagingArea) {
// Should not happen
return;
}
if (state.stagingArea.selectedImageIndex === null) {
if (state.stagingArea.images.length > 0) {
state.stagingArea.selectedImageIndex = 0;
}
return;
}
state.stagingArea.selectedImageIndex = (state.stagingArea.selectedImageIndex + 1) % state.stagingArea.images.length; state.stagingArea.selectedImageIndex = (state.stagingArea.selectedImageIndex + 1) % state.stagingArea.images.length;
}, },
stagingAreaPreviousImageSelected: (state) => { stagingAreaPreviousImageSelected: (state) => {
if (!state.stagingArea) {
// Should not happen
return;
}
if (state.stagingArea.selectedImageIndex === null) {
if (state.stagingArea.images.length > 0) {
state.stagingArea.selectedImageIndex = 0;
}
return;
}
state.stagingArea.selectedImageIndex = state.stagingArea.selectedImageIndex =
(state.stagingArea.selectedImageIndex - 1 + state.stagingArea.images.length) % state.stagingArea.images.length; (state.stagingArea.selectedImageIndex - 1 + state.stagingArea.images.length) % state.stagingArea.images.length;
}, },
stagingAreaBatchIdAdded: (state, action: PayloadAction<{ batchId: string }>) => {
const { batchId } = action.payload;
if (!state.stagingArea) {
// Should not happen
return;
}
state.stagingArea.batchIds.push(batchId);
},
stagingAreaImageDiscarded: (state, action: PayloadAction<{ imageDTO: ImageDTO }>) => { stagingAreaImageDiscarded: (state, action: PayloadAction<{ imageDTO: ImageDTO }>) => {
const { imageDTO } = action.payload; const { imageDTO } = action.payload;
if (!state.stagingArea) {
// Should not happen
return;
}
state.stagingArea.images = state.stagingArea.images.filter((image) => image.image_name !== imageDTO.image_name); state.stagingArea.images = state.stagingArea.images.filter((image) => image.image_name !== imageDTO.image_name);
state.stagingArea.selectedImageIndex = Math.min(
state.stagingArea.selectedImageIndex,
state.stagingArea.images.length - 1
);
if (state.stagingArea.images.length === 0) {
state.stagingArea.isStaging = false;
}
}, },
stagingAreaImageAccepted: (state, _: PayloadAction<{ imageDTO: ImageDTO }>) => { stagingAreaImageAccepted: (state, _: PayloadAction<{ imageDTO: ImageDTO }>) => {
// When we finish staging, reset the tool back to the previous selection. // When we finish staging, reset the tool back to the previous selection.
state.stagingArea.isStaging = false;
state.stagingArea.images = [];
state.stagingArea.selectedImageIndex = 0;
if (state.tool.selectedBuffer) { if (state.tool.selectedBuffer) {
state.tool.selected = state.tool.selectedBuffer; state.tool.selected = state.tool.selectedBuffer;
state.tool.selectedBuffer = null; state.tool.selectedBuffer = null;
} }
}, },
stagingAreaReset: (state) => { stagingAreaCanceledStaging: (state) => {
state.stagingArea = null; state.stagingArea.isStaging = false;
state.stagingArea.images = [];
state.stagingArea.selectedImageIndex = 0;
// When we finish staging, reset the tool back to the previous selection. // When we finish staging, reset the tool back to the previous selection.
if (state.tool.selectedBuffer) { if (state.tool.selectedBuffer) {
state.tool.selected = state.tool.selectedBuffer; state.tool.selected = state.tool.selectedBuffer;

View File

@ -883,11 +883,10 @@ export type CanvasV2State = {
refinerStart: number; refinerStart: number;
}; };
stagingArea: { stagingArea: {
bbox: Rect; isStaging: boolean;
images: ImageDTO[]; images: ImageDTO[];
selectedImageIndex: number | null; selectedImageIndex: number;
batchIds: string[]; };
} | null;
}; };
export type StageAttrs = { x: number; y: number; width: number; height: number; scale: number }; export type StageAttrs = { x: number; y: number; width: number; height: number; scale: number };

View File

@ -107,6 +107,7 @@ export const prepareLinearUIBatch = (state: RootState, g: Graph, prepend: boolea
graph: g.getGraph(), graph: g.getGraph(),
runs: 1, runs: 1,
data, data,
origin: 'canvas',
}, },
}; };

View File

@ -276,6 +276,26 @@ export const queueApi = api.injectEndpoints({
}, },
invalidatesTags: ['SessionQueueStatus', 'BatchStatus'], invalidatesTags: ['SessionQueueStatus', 'BatchStatus'],
}), }),
cancelByBatchOrigin: build.mutation<
paths['/api/v1/queue/{queue_id}/cancel_by_origin']['put']['responses']['200']['content']['application/json'],
paths['/api/v1/queue/{queue_id}/cancel_by_origin']['put']['parameters']['query']
>({
query: (params) => ({
url: buildQueueUrl('cancel_by_origin'),
method: 'PUT',
params,
}),
onQueryStarted: async (arg, api) => {
const { dispatch, queryFulfilled } = api;
try {
await queryFulfilled;
resetListQueryData(dispatch);
} catch {
// no-op
}
},
invalidatesTags: ['SessionQueueStatus', 'BatchStatus'],
}),
listQueueItems: build.query< listQueueItems: build.query<
EntityState<components['schemas']['SessionQueueItemDTO'], string> & { EntityState<components['schemas']['SessionQueueItemDTO'], string> & {
has_more: boolean; has_more: boolean;