feat(ui): switch to view tool when staging

This commit is contained in:
psychedelicious 2024-06-28 19:39:04 +10:00
parent b55378c63c
commit 9c77023a11
6 changed files with 108 additions and 49 deletions

View File

@ -1,7 +1,11 @@
import { enqueueRequested } from 'app/store/actions'; import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { getNodeManager } from 'features/controlLayers/konva/nodeManager'; import { getNodeManager } from 'features/controlLayers/konva/nodeManager';
import { stagingAreaBatchIdAdded, stagingAreaInitialized } from 'features/controlLayers/store/canvasV2Slice'; import {
stagingAreaBatchIdAdded,
stagingAreaInitialized,
stagingAreaReset,
} from 'features/controlLayers/store/canvasV2Slice';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice'; import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph'; import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
@ -19,47 +23,58 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
const model = state.canvasV2.params.model; const model = state.canvasV2.params.model;
const { prepend } = action.payload; const { prepend } = action.payload;
let g; let didInitializeStagingArea = false;
const manager = getNodeManager(); if (state.canvasV2.stagingArea === null) {
assert(model, 'No model found in state'); dispatch(
const base = model.base; stagingAreaInitialized({
batchIds: [],
if (base === 'sdxl') { bbox: state.canvasV2.bbox,
g = await buildSDXLGraph(state, manager); })
} else if (base === 'sd-1' || base === 'sd-2') { );
g = await buildSD1Graph(state, manager); didInitializeStagingArea = true;
} else {
assert(false, `No graph builders for base ${base}`);
} }
const batchConfig = prepareLinearUIBatch(state, g, prepend);
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
fixedCacheKey: 'enqueueBatch',
})
);
try { try {
let g;
const manager = getNodeManager();
assert(model, 'No model found in state');
const base = model.base;
if (base === 'sdxl') {
g = await buildSDXLGraph(state, manager);
} else if (base === 'sd-1' || base === 'sd-2') {
g = await buildSD1Graph(state, manager);
} else {
assert(false, `No graph builders for base ${base}`);
}
const batchConfig = prepareLinearUIBatch(state, g, prepend);
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
fixedCacheKey: 'enqueueBatch',
})
);
const enqueueResult = await req.unwrap(); const enqueueResult = await req.unwrap();
req.reset();
if (shouldShowProgressInViewer) { if (shouldShowProgressInViewer) {
dispatch(isImageViewerOpenChanged(true)); dispatch(isImageViewerOpenChanged(true));
} }
// TODO(psyche): update the backend schema, this is always provided // TODO(psyche): update the backend schema, this is always provided
const batchId = enqueueResult.batch.batch_id; const batchId = enqueueResult.batch.batch_id;
assert(batchId, 'No batch ID found in enqueue result'); assert(batchId, 'No batch ID found in enqueue result');
if (!state.canvasV2.stagingArea) { dispatch(stagingAreaBatchIdAdded({ batchId }));
dispatch( } catch {
stagingAreaInitialized({ if (didInitializeStagingArea) {
batchIds: [batchId], // We initialized the staging area in this listener, and there was a problem at some point. This means
bbox: state.canvasV2.bbox, // there only possible canvas batch id is the one we just added, so we can reset the staging area without
}) // losing any data.
); dispatch(stagingAreaReset());
} else {
dispatch(stagingAreaBatchIdAdded({ batchId }));
} }
} finally {
req.reset();
} }
}, },
}); });

View File

@ -43,6 +43,7 @@ export const ToolChooser: React.FC = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier); const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
const isStaging = useAppSelector((s) => s.canvasV2.stagingArea !== null);
const isDrawingToolDisabled = useMemo( const isDrawingToolDisabled = useMemo(
() => !getIsDrawingToolEnabled(selectedEntityIdentifier), () => !getIsDrawingToolEnabled(selectedEntityIdentifier),
[selectedEntityIdentifier] [selectedEntityIdentifier]
@ -53,19 +54,35 @@ export const ToolChooser: React.FC = () => {
const setToolToBrush = useCallback(() => { const setToolToBrush = useCallback(() => {
dispatch(toolChanged('brush')); dispatch(toolChanged('brush'));
}, [dispatch]); }, [dispatch]);
useHotkeys('b', setToolToBrush, { enabled: !isDrawingToolDisabled }, [isDrawingToolDisabled, setToolToBrush]); useHotkeys('b', setToolToBrush, { enabled: !isDrawingToolDisabled && !isStaging }, [
isDrawingToolDisabled,
isStaging,
setToolToBrush,
]);
const setToolToEraser = useCallback(() => { const setToolToEraser = useCallback(() => {
dispatch(toolChanged('eraser')); dispatch(toolChanged('eraser'));
}, [dispatch]); }, [dispatch]);
useHotkeys('e', setToolToEraser, { enabled: !isDrawingToolDisabled }, [isDrawingToolDisabled, setToolToEraser]); useHotkeys('e', setToolToEraser, { enabled: !isDrawingToolDisabled && !isStaging }, [
isDrawingToolDisabled,
isStaging,
setToolToEraser,
]);
const setToolToRect = useCallback(() => { const setToolToRect = useCallback(() => {
dispatch(toolChanged('rect')); dispatch(toolChanged('rect'));
}, [dispatch]); }, [dispatch]);
useHotkeys('u', setToolToRect, { enabled: !isDrawingToolDisabled }, [isDrawingToolDisabled, setToolToRect]); useHotkeys('u', setToolToRect, { enabled: !isDrawingToolDisabled && !isStaging }, [
isDrawingToolDisabled,
isStaging,
setToolToRect,
]);
const setToolToMove = useCallback(() => { const setToolToMove = useCallback(() => {
dispatch(toolChanged('move')); dispatch(toolChanged('move'));
}, [dispatch]); }, [dispatch]);
useHotkeys('v', setToolToMove, { enabled: !isMoveToolDisabled }, [isMoveToolDisabled, setToolToMove]); useHotkeys('v', setToolToMove, { enabled: !isMoveToolDisabled && !isStaging }, [
isMoveToolDisabled,
isStaging,
setToolToMove,
]);
const setToolToView = useCallback(() => { const setToolToView = useCallback(() => {
dispatch(toolChanged('view')); dispatch(toolChanged('view'));
}, [dispatch]); }, [dispatch]);
@ -92,12 +109,16 @@ export const ToolChooser: React.FC = () => {
}, [dispatch, selectedEntityIdentifier]); }, [dispatch, selectedEntityIdentifier]);
const isResetEnabled = useMemo( const isResetEnabled = useMemo(
() => () =>
selectedEntityIdentifier?.type === 'layer' || (!isStaging && selectedEntityIdentifier?.type === 'layer') ||
selectedEntityIdentifier?.type === 'regional_guidance' || selectedEntityIdentifier?.type === 'regional_guidance' ||
selectedEntityIdentifier?.type === 'inpaint_mask', selectedEntityIdentifier?.type === 'inpaint_mask',
[selectedEntityIdentifier] [isStaging, selectedEntityIdentifier?.type]
); );
useHotkeys('shift+c', resetSelectedLayer, { enabled: isResetEnabled }, [isResetEnabled, resetSelectedLayer]); useHotkeys('shift+c', resetSelectedLayer, { enabled: isResetEnabled }, [
isResetEnabled,
isStaging,
resetSelectedLayer,
]);
const deleteSelectedLayer = useCallback(() => { const deleteSelectedLayer = useCallback(() => {
if (selectedEntityIdentifier === null) { if (selectedEntityIdentifier === null) {
@ -117,7 +138,10 @@ export const ToolChooser: React.FC = () => {
dispatch(ipaDeleted({ id })); dispatch(ipaDeleted({ id }));
} }
}, [dispatch, selectedEntityIdentifier]); }, [dispatch, selectedEntityIdentifier]);
const isDeleteEnabled = useMemo(() => selectedEntityIdentifier !== null, [selectedEntityIdentifier]); const isDeleteEnabled = useMemo(
() => selectedEntityIdentifier !== null && !isStaging,
[selectedEntityIdentifier, isStaging]
);
useHotkeys('shift+d', deleteSelectedLayer, { enabled: isDeleteEnabled }, [isDeleteEnabled, deleteSelectedLayer]); useHotkeys('shift+d', deleteSelectedLayer, { enabled: isDeleteEnabled }, [isDeleteEnabled, deleteSelectedLayer]);
return ( return (
@ -128,7 +152,7 @@ export const ToolChooser: React.FC = () => {
icon={<PiPaintBrushBold />} icon={<PiPaintBrushBold />}
variant={tool === 'brush' ? 'solid' : 'outline'} variant={tool === 'brush' ? 'solid' : 'outline'}
onClick={setToolToBrush} onClick={setToolToBrush}
isDisabled={isDrawingToolDisabled} isDisabled={isDrawingToolDisabled || isStaging}
/> />
<IconButton <IconButton
aria-label={`${t('unifiedCanvas.eraser')} (E)`} aria-label={`${t('unifiedCanvas.eraser')} (E)`}
@ -136,7 +160,7 @@ export const ToolChooser: React.FC = () => {
icon={<PiEraserBold />} icon={<PiEraserBold />}
variant={tool === 'eraser' ? 'solid' : 'outline'} variant={tool === 'eraser' ? 'solid' : 'outline'}
onClick={setToolToEraser} onClick={setToolToEraser}
isDisabled={isDrawingToolDisabled} isDisabled={isDrawingToolDisabled || isStaging}
/> />
<IconButton <IconButton
aria-label={`${t('controlLayers.rectangle')} (U)`} aria-label={`${t('controlLayers.rectangle')} (U)`}
@ -144,7 +168,7 @@ export const ToolChooser: React.FC = () => {
icon={<PiRectangleBold />} icon={<PiRectangleBold />}
variant={tool === 'rect' ? 'solid' : 'outline'} variant={tool === 'rect' ? 'solid' : 'outline'}
onClick={setToolToRect} onClick={setToolToRect}
isDisabled={isDrawingToolDisabled} isDisabled={isDrawingToolDisabled || isStaging}
/> />
<IconButton <IconButton
aria-label={`${t('unifiedCanvas.move')} (V)`} aria-label={`${t('unifiedCanvas.move')} (V)`}
@ -152,7 +176,7 @@ export const ToolChooser: React.FC = () => {
icon={<PiArrowsOutCardinalBold />} icon={<PiArrowsOutCardinalBold />}
variant={tool === 'move' ? 'solid' : 'outline'} variant={tool === 'move' ? 'solid' : 'outline'}
onClick={setToolToMove} onClick={setToolToMove}
isDisabled={isMoveToolDisabled} isDisabled={isMoveToolDisabled || isStaging}
/> />
<IconButton <IconButton
aria-label={`${t('unifiedCanvas.view')} (H)`} aria-label={`${t('unifiedCanvas.view')} (H)`}
@ -160,6 +184,7 @@ export const ToolChooser: React.FC = () => {
icon={<PiHandBold />} icon={<PiHandBold />}
variant={tool === 'view' ? 'solid' : 'outline'} variant={tool === 'view' ? 'solid' : 'outline'}
onClick={setToolToView} onClick={setToolToView}
isDisabled={isStaging}
/> />
<IconButton <IconButton
aria-label={`${t('controlLayers.bbox')} (Q)`} aria-label={`${t('controlLayers.bbox')} (Q)`}
@ -167,6 +192,7 @@ export const ToolChooser: React.FC = () => {
icon={<PiBoundingBoxBold />} icon={<PiBoundingBoxBold />}
variant={tool === 'bbox' ? 'solid' : 'outline'} variant={tool === 'bbox' ? 'solid' : 'outline'}
onClick={setToolToBbox} onClick={setToolToBbox}
isDisabled={isStaging}
/> />
</ButtonGroup> </ButtonGroup>
); );

View File

@ -128,8 +128,6 @@ export const setStageEventHandlers = (manager: KonvaNodeManager): (() => void) =
//#region mouseenter //#region mouseenter
stage.on('mouseenter', () => { stage.on('mouseenter', () => {
const tool = getToolState().selected;
stage.findOne<Konva.Layer>(`#${PREVIEW_TOOL_GROUP_ID}`)?.visible(tool === 'brush' || tool === 'eraser');
manager.renderToolPreview(); manager.renderToolPreview();
}); });

View File

@ -446,7 +446,12 @@ export const initializeRenderer = (
const unsubscribeRenderer = subscribe(renderCanvas); const unsubscribeRenderer = subscribe(renderCanvas);
// When we this flag, we need to render the staging area // When we this flag, we need to render the staging area
$shouldShowStagedImage.subscribe(manager.renderStagingArea.bind(manager)); $shouldShowStagedImage.subscribe((shouldShowStagedImage, prevShouldShowStagedImage) => {
logIfDebugging('Rendering staging area');
if (shouldShowStagedImage !== prevShouldShowStagedImage) {
manager.renderStagingArea();
}
});
logIfDebugging('First render of konva stage'); logIfDebugging('First render of konva stage');
// On first render, the document should be fit to the stage. // On first render, the document should be fit to the stage.

View File

@ -2,13 +2,12 @@ import { rgbaColorToString } from 'common/util/colorCodeTransformers';
import { import {
BRUSH_BORDER_INNER_COLOR, BRUSH_BORDER_INNER_COLOR,
BRUSH_BORDER_OUTER_COLOR, BRUSH_BORDER_OUTER_COLOR,
BRUSH_ERASER_BORDER_WIDTH BRUSH_ERASER_BORDER_WIDTH,
} from 'features/controlLayers/konva/constants'; } from 'features/controlLayers/konva/constants';
import { PREVIEW_RECT_ID } from 'features/controlLayers/konva/naming'; import { PREVIEW_RECT_ID } from 'features/controlLayers/konva/naming';
import type { CanvasEntity, CanvasV2State, Position, RgbaColor } from 'features/controlLayers/store/types'; import type { CanvasEntity, CanvasV2State, Position, RgbaColor } from 'features/controlLayers/store/types';
import Konva from 'konva'; import Konva from 'konva';
export class CanvasTool { export class CanvasTool {
group: Konva.Group; group: Konva.Group;
brush: { brush: {
@ -125,7 +124,8 @@ export class CanvasTool {
isMouseDown: boolean isMouseDown: boolean
) { ) {
const tool = toolState.selected; const tool = toolState.selected;
const isDrawableEntity = selectedEntity?.type === 'regional_guidance' || const isDrawableEntity =
selectedEntity?.type === 'regional_guidance' ||
selectedEntity?.type === 'layer' || selectedEntity?.type === 'layer' ||
selectedEntity?.type === 'inpaint_mask'; selectedEntity?.type === 'inpaint_mask';

View File

@ -11,6 +11,10 @@ export const stagingAreaReducers = {
selectedImageIndex: null, selectedImageIndex: null,
images: [], images: [],
}; };
// When we start staging, the user should not be interacting with the stage except to move it around. Set the tool
// to view.
state.tool.selectedBuffer = state.tool.selected;
state.tool.selected = 'view';
}, },
stagingAreaImageAdded: (state, action: PayloadAction<{ imageDTO: ImageDTO }>) => { stagingAreaImageAdded: (state, action: PayloadAction<{ imageDTO: ImageDTO }>) => {
const { imageDTO } = action.payload; const { imageDTO } = action.payload;
@ -66,8 +70,19 @@ export const stagingAreaReducers = {
} }
state.stagingArea.images = state.stagingArea.images.filter((image) => image.image_name !== imageDTO.image_name); state.stagingArea.images = state.stagingArea.images.filter((image) => image.image_name !== imageDTO.image_name);
}, },
stagingAreaImageAccepted: (state, _: PayloadAction<{ imageDTO: ImageDTO }>) => state, stagingAreaImageAccepted: (state, _: PayloadAction<{ imageDTO: ImageDTO }>) => {
// When we finish staging, reset the tool back to the previous selection.
if (state.tool.selectedBuffer) {
state.tool.selected = state.tool.selectedBuffer;
state.tool.selectedBuffer = null;
}
},
stagingAreaReset: (state) => { stagingAreaReset: (state) => {
state.stagingArea = null; state.stagingArea = null;
// When we finish staging, reset the tool back to the previous selection.
if (state.tool.selectedBuffer) {
state.tool.selected = state.tool.selectedBuffer;
state.tool.selectedBuffer = null;
}
}, },
} satisfies SliceCaseReducers<CanvasV2State>; } satisfies SliceCaseReducers<CanvasV2State>;