diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts index 3b32084905..801ccbe4aa 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts @@ -3,7 +3,6 @@ import { createAction, createSlice } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils'; import { deepClone } from 'common/util/deepClone'; -import { exhaustiveCheck } from 'features/controlLayers/konva/util'; import { bboxReducers } from 'features/controlLayers/store/bboxReducers'; import { compositingReducers } from 'features/controlLayers/store/compositingReducers'; import { controlLayersReducers } from 'features/controlLayers/store/controlLayersReducers'; @@ -157,21 +156,31 @@ export function selectEntity(state: CanvasV2State, { id, type }: CanvasEntityIde } function selectAllEntitiesOfType(state: CanvasV2State, type: CanvasEntityState['type']): CanvasEntityState[] { - if (type === 'raster_layer') { - return state.rasterLayers.entities; - } else if (type === 'control_layer') { - return state.controlLayers.entities; - } else if (type === 'inpaint_mask') { - return state.inpaintMasks.entities; - } else if (type === 'regional_guidance') { - return state.regions.entities; - } else if (type === 'ip_adapter') { - return state.ipAdapters.entities; - } else { - assert(false, 'Not implemented'); + switch (type) { + case 'raster_layer': + return state.rasterLayers.entities; + case 'control_layer': + return state.controlLayers.entities; + case 'inpaint_mask': + return state.inpaintMasks.entities; + case 'regional_guidance': + return state.regions.entities; + case 'ip_adapter': + return state.ipAdapters.entities; } } +function selectAllEntities(state: CanvasV2State): CanvasEntityState[] { + // These are in the same order as they are displayed in the list! + return [ + ...state.inpaintMasks.entities.toReversed(), + ...state.regions.entities.toReversed(), + ...state.ipAdapters.entities.toReversed(), + ...state.controlLayers.entities.toReversed(), + ...state.rasterLayers.entities.toReversed(), + ]; +} + export const canvasV2Slice = createSlice({ name: 'canvasV2', initialState, @@ -288,49 +297,33 @@ export const canvasV2Slice = createSlice({ entityDeleted: (state, action: PayloadAction) => { const { entityIdentifier } = action.payload; - const firstInpaintMaskEntity = state.inpaintMasks.entities[0]; + let selectedEntityIdentifier: CanvasV2State['selectedEntityIdentifier'] = null; + const allEntities = selectAllEntities(state); + const index = allEntities.findIndex((entity) => entity.id === entityIdentifier.id); + const nextIndex = allEntities.length > 1 ? (index + 1) % allEntities.length : -1; + if (nextIndex !== -1) { + const nextEntity = allEntities[nextIndex]; + if (nextEntity) { + selectedEntityIdentifier = getEntityIdentifier(nextEntity); + } + } - let selectedEntityIdentifier: CanvasV2State['selectedEntityIdentifier'] = firstInpaintMaskEntity - ? getEntityIdentifier(firstInpaintMaskEntity) - : null; - - if (entityIdentifier.type === 'raster_layer') { - const index = state.rasterLayers.entities.findIndex((layer) => layer.id === entityIdentifier.id); - state.rasterLayers.entities = state.rasterLayers.entities.filter((layer) => layer.id !== entityIdentifier.id); - const nextRasterLayer = state.rasterLayers.entities[index]; - if (nextRasterLayer) { - selectedEntityIdentifier = { type: nextRasterLayer.type, id: nextRasterLayer.id }; - } - } else if (entityIdentifier.type === 'control_layer') { - const index = state.controlLayers.entities.findIndex((layer) => layer.id === entityIdentifier.id); - state.controlLayers.entities = state.controlLayers.entities.filter((rg) => rg.id !== entityIdentifier.id); - const nextControlLayer = state.controlLayers.entities[index]; - if (nextControlLayer) { - selectedEntityIdentifier = { type: nextControlLayer.type, id: nextControlLayer.id }; - } - } else if (entityIdentifier.type === 'regional_guidance') { - const index = state.regions.entities.findIndex((layer) => layer.id === entityIdentifier.id); - state.regions.entities = state.regions.entities.filter((rg) => rg.id !== entityIdentifier.id); - const region = state.regions.entities[index]; - if (region) { - selectedEntityIdentifier = { type: region.type, id: region.id }; - } - } else if (entityIdentifier.type === 'ip_adapter') { - const index = state.ipAdapters.entities.findIndex((layer) => layer.id === entityIdentifier.id); - state.ipAdapters.entities = state.ipAdapters.entities.filter((rg) => rg.id !== entityIdentifier.id); - const entity = state.ipAdapters.entities[index]; - if (entity) { - selectedEntityIdentifier = { type: entity.type, id: entity.id }; - } - } else if (entityIdentifier.type === 'inpaint_mask') { - const index = state.inpaintMasks.entities.findIndex((layer) => layer.id === entityIdentifier.id); - state.inpaintMasks.entities = state.inpaintMasks.entities.filter((rg) => rg.id !== entityIdentifier.id); - const entity = state.inpaintMasks.entities[index]; - if (entity) { - selectedEntityIdentifier = { type: entity.type, id: entity.id }; - } - } else { - assert(false, 'Not implemented'); + switch (entityIdentifier.type) { + case 'raster_layer': + state.rasterLayers.entities = state.rasterLayers.entities.filter((layer) => layer.id !== entityIdentifier.id); + break; + case 'control_layer': + state.controlLayers.entities = state.controlLayers.entities.filter((rg) => rg.id !== entityIdentifier.id); + break; + case 'regional_guidance': + state.regions.entities = state.regions.entities.filter((rg) => rg.id !== entityIdentifier.id); + break; + case 'ip_adapter': + state.ipAdapters.entities = state.ipAdapters.entities.filter((rg) => rg.id !== entityIdentifier.id); + break; + case 'inpaint_mask': + state.inpaintMasks.entities = state.inpaintMasks.entities.filter((rg) => rg.id !== entityIdentifier.id); + break; } state.selectedEntityIdentifier = selectedEntityIdentifier; @@ -397,9 +390,6 @@ export const canvasV2Slice = createSlice({ case 'ip_adapter': // no-op break; - default: { - exhaustiveCheck(type); - } } }, allEntitiesDeleted: (state) => {