fix(ui): select next entity in the list when deleting

This commit is contained in:
psychedelicious 2024-08-23 17:30:41 +10:00
parent 3d87adea60
commit 14cc5e2453

View File

@ -3,7 +3,6 @@ import { createAction, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { PersistConfig, RootState } from 'app/store/store';
import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils'; import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { exhaustiveCheck } from 'features/controlLayers/konva/util';
import { bboxReducers } from 'features/controlLayers/store/bboxReducers'; import { bboxReducers } from 'features/controlLayers/store/bboxReducers';
import { compositingReducers } from 'features/controlLayers/store/compositingReducers'; import { compositingReducers } from 'features/controlLayers/store/compositingReducers';
import { controlLayersReducers } from 'features/controlLayers/store/controlLayersReducers'; import { controlLayersReducers } from 'features/controlLayers/store/controlLayersReducers';
@ -157,21 +156,31 @@ export function selectEntity(state: CanvasV2State, { id, type }: CanvasEntityIde
} }
function selectAllEntitiesOfType(state: CanvasV2State, type: CanvasEntityState['type']): CanvasEntityState[] { function selectAllEntitiesOfType(state: CanvasV2State, type: CanvasEntityState['type']): CanvasEntityState[] {
if (type === 'raster_layer') { switch (type) {
return state.rasterLayers.entities; case 'raster_layer':
} else if (type === 'control_layer') { return state.rasterLayers.entities;
return state.controlLayers.entities; case 'control_layer':
} else if (type === 'inpaint_mask') { return state.controlLayers.entities;
return state.inpaintMasks.entities; case 'inpaint_mask':
} else if (type === 'regional_guidance') { return state.inpaintMasks.entities;
return state.regions.entities; case 'regional_guidance':
} else if (type === 'ip_adapter') { return state.regions.entities;
return state.ipAdapters.entities; case 'ip_adapter':
} else { return state.ipAdapters.entities;
assert(false, 'Not implemented');
} }
} }
function selectAllEntities(state: CanvasV2State): CanvasEntityState[] {
// These are in the same order as they are displayed in the list!
return [
...state.inpaintMasks.entities.toReversed(),
...state.regions.entities.toReversed(),
...state.ipAdapters.entities.toReversed(),
...state.controlLayers.entities.toReversed(),
...state.rasterLayers.entities.toReversed(),
];
}
export const canvasV2Slice = createSlice({ export const canvasV2Slice = createSlice({
name: 'canvasV2', name: 'canvasV2',
initialState, initialState,
@ -288,49 +297,33 @@ export const canvasV2Slice = createSlice({
entityDeleted: (state, action: PayloadAction<EntityIdentifierPayload>) => { entityDeleted: (state, action: PayloadAction<EntityIdentifierPayload>) => {
const { entityIdentifier } = action.payload; const { entityIdentifier } = action.payload;
const firstInpaintMaskEntity = state.inpaintMasks.entities[0]; let selectedEntityIdentifier: CanvasV2State['selectedEntityIdentifier'] = null;
const allEntities = selectAllEntities(state);
const index = allEntities.findIndex((entity) => entity.id === entityIdentifier.id);
const nextIndex = allEntities.length > 1 ? (index + 1) % allEntities.length : -1;
if (nextIndex !== -1) {
const nextEntity = allEntities[nextIndex];
if (nextEntity) {
selectedEntityIdentifier = getEntityIdentifier(nextEntity);
}
}
let selectedEntityIdentifier: CanvasV2State['selectedEntityIdentifier'] = firstInpaintMaskEntity switch (entityIdentifier.type) {
? getEntityIdentifier(firstInpaintMaskEntity) case 'raster_layer':
: null; state.rasterLayers.entities = state.rasterLayers.entities.filter((layer) => layer.id !== entityIdentifier.id);
break;
if (entityIdentifier.type === 'raster_layer') { case 'control_layer':
const index = state.rasterLayers.entities.findIndex((layer) => layer.id === entityIdentifier.id); state.controlLayers.entities = state.controlLayers.entities.filter((rg) => rg.id !== entityIdentifier.id);
state.rasterLayers.entities = state.rasterLayers.entities.filter((layer) => layer.id !== entityIdentifier.id); break;
const nextRasterLayer = state.rasterLayers.entities[index]; case 'regional_guidance':
if (nextRasterLayer) { state.regions.entities = state.regions.entities.filter((rg) => rg.id !== entityIdentifier.id);
selectedEntityIdentifier = { type: nextRasterLayer.type, id: nextRasterLayer.id }; break;
} case 'ip_adapter':
} else if (entityIdentifier.type === 'control_layer') { state.ipAdapters.entities = state.ipAdapters.entities.filter((rg) => rg.id !== entityIdentifier.id);
const index = state.controlLayers.entities.findIndex((layer) => layer.id === entityIdentifier.id); break;
state.controlLayers.entities = state.controlLayers.entities.filter((rg) => rg.id !== entityIdentifier.id); case 'inpaint_mask':
const nextControlLayer = state.controlLayers.entities[index]; state.inpaintMasks.entities = state.inpaintMasks.entities.filter((rg) => rg.id !== entityIdentifier.id);
if (nextControlLayer) { break;
selectedEntityIdentifier = { type: nextControlLayer.type, id: nextControlLayer.id };
}
} else if (entityIdentifier.type === 'regional_guidance') {
const index = state.regions.entities.findIndex((layer) => layer.id === entityIdentifier.id);
state.regions.entities = state.regions.entities.filter((rg) => rg.id !== entityIdentifier.id);
const region = state.regions.entities[index];
if (region) {
selectedEntityIdentifier = { type: region.type, id: region.id };
}
} else if (entityIdentifier.type === 'ip_adapter') {
const index = state.ipAdapters.entities.findIndex((layer) => layer.id === entityIdentifier.id);
state.ipAdapters.entities = state.ipAdapters.entities.filter((rg) => rg.id !== entityIdentifier.id);
const entity = state.ipAdapters.entities[index];
if (entity) {
selectedEntityIdentifier = { type: entity.type, id: entity.id };
}
} else if (entityIdentifier.type === 'inpaint_mask') {
const index = state.inpaintMasks.entities.findIndex((layer) => layer.id === entityIdentifier.id);
state.inpaintMasks.entities = state.inpaintMasks.entities.filter((rg) => rg.id !== entityIdentifier.id);
const entity = state.inpaintMasks.entities[index];
if (entity) {
selectedEntityIdentifier = { type: entity.type, id: entity.id };
}
} else {
assert(false, 'Not implemented');
} }
state.selectedEntityIdentifier = selectedEntityIdentifier; state.selectedEntityIdentifier = selectedEntityIdentifier;
@ -397,9 +390,6 @@ export const canvasV2Slice = createSlice({
case 'ip_adapter': case 'ip_adapter':
// no-op // no-op
break; break;
default: {
exhaustiveCheck(type);
}
} }
}, },
allEntitiesDeleted: (state) => { allEntitiesDeleted: (state) => {