fix(ui): newly-added entities are selected

This commit is contained in:
psychedelicious 2024-08-24 11:14:58 +10:00
parent 3f597a1c60
commit c3f7554053
6 changed files with 60 additions and 43 deletions

View File

@ -1,6 +1,5 @@
import { IconButton, Menu, MenuButton, MenuDivider, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useDefaultControlAdapter, useDefaultIPAdapter } from 'features/controlLayers/hooks/useLayerControlAdapter';
import {
allEntitiesDeleted,
controlLayerAdded,
@ -21,23 +20,21 @@ export const CanvasEntityListMenu = memo(() => {
const count = selectEntityCount(s);
return count > 0;
});
const defaultControlAdapter = useDefaultControlAdapter();
const defaultIPAdapter = useDefaultIPAdapter();
const addInpaintMask = useCallback(() => {
dispatch(inpaintMaskAdded());
dispatch(inpaintMaskAdded({ isSelected: true }));
}, [dispatch]);
const addRegionalGuidance = useCallback(() => {
dispatch(rgAdded());
dispatch(rgAdded({ isSelected: true }));
}, [dispatch]);
const addRasterLayer = useCallback(() => {
dispatch(rasterLayerAdded({ isSelected: true }));
}, [dispatch]);
const addControlLayer = useCallback(() => {
dispatch(controlLayerAdded({ isSelected: true, overrides: { controlAdapter: defaultControlAdapter } }));
}, [defaultControlAdapter, dispatch]);
dispatch(controlLayerAdded({ isSelected: true }));
}, [dispatch]);
const addIPAdapter = useCallback(() => {
dispatch(ipaAdded({ ipAdapter: defaultIPAdapter }));
}, [defaultIPAdapter, dispatch]);
dispatch(ipaAdded({ isSelected: true }));
}, [dispatch]);
const deleteAll = useCallback(() => {
dispatch(allEntitiesDeleted());
}, [dispatch]);

View File

@ -14,7 +14,7 @@ import type {
ControlNetConfig,
T2IAdapterConfig,
} from './types';
import { initialControlNet } from './types';
import { getEntityIdentifier, initialControlNet } from './types';
const selectControlLayerEntity = (state: CanvasV2State, id: string) =>
state.controlLayers.entities.find((entity) => entity.id === id);
@ -31,7 +31,7 @@ export const controlLayersReducers = {
action: PayloadAction<{ id: string; overrides?: Partial<CanvasControlLayerState>; isSelected?: boolean }>
) => {
const { id, overrides, isSelected } = action.payload;
const layer: CanvasControlLayerState = {
const entity: CanvasControlLayerState = {
id,
name: null,
type: 'control_layer',
@ -42,10 +42,10 @@ export const controlLayersReducers = {
position: { x: 0, y: 0 },
controlAdapter: deepClone(initialControlNet),
};
merge(layer, overrides);
state.controlLayers.entities.push(layer);
merge(entity, overrides);
state.controlLayers.entities.push(entity);
if (isSelected) {
state.selectedEntityIdentifier = { type: 'control_layer', id };
state.selectedEntityIdentifier = getEntityIdentifier(entity);
}
},
prepare: (payload: { overrides?: Partial<CanvasControlLayerState>; isSelected?: boolean }) => ({

View File

@ -1,11 +1,12 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type {
CanvasInpaintMaskState,
CanvasV2State,
EntityIdentifierPayload,
FillStyle,
RgbColor,
import {
type CanvasInpaintMaskState,
type CanvasV2State,
type EntityIdentifierPayload,
type FillStyle,
getEntityIdentifier,
type RgbColor,
} from 'features/controlLayers/store/types';
import { merge } from 'lodash-es';
import { assert } from 'tsafe';
@ -41,7 +42,7 @@ export const inpaintMaskReducers = {
merge(entity, overrides);
state.inpaintMasks.entities.push(entity);
if (isSelected) {
state.selectedEntityIdentifier = { type: 'inpaint_mask', id };
state.selectedEntityIdentifier = getEntityIdentifier(entity);
}
},
prepare: (payload?: { overrides?: Partial<CanvasInpaintMaskState>; isSelected?: boolean }) => ({

View File

@ -1,11 +1,14 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import { deepClone } from 'common/util/deepClone';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { merge } from 'lodash-es';
import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid';
import type { CanvasIPAdapterState, CanvasV2State, CLIPVisionModelV2, IPAdapterConfig, IPMethodV2 } from './types';
import { imageDTOToImageWithDims } from './types';
import type { CanvasIPAdapterState, CanvasV2State, CLIPVisionModelV2, IPMethodV2 } from './types';
import { getEntityIdentifier, imageDTOToImageWithDims, initialIPAdapter } from './types';
const selectIPAdapterEntity = (state: CanvasV2State, id: string) =>
state.ipAdapters.entities.find((ipa) => ipa.id === id);
@ -17,19 +20,27 @@ export const selectIPAdapterEntityOrThrow = (state: CanvasV2State, id: string) =
export const ipAdaptersReducers = {
ipaAdded: {
reducer: (state, action: PayloadAction<{ id: string; ipAdapter: IPAdapterConfig }>) => {
const { id, ipAdapter } = action.payload;
const layer: CanvasIPAdapterState = {
reducer: (
state,
action: PayloadAction<{ id: string; overrides?: Partial<CanvasIPAdapterState>; isSelected?: boolean }>
) => {
const { id, overrides, isSelected } = action.payload;
const entity: CanvasIPAdapterState = {
id,
type: 'ip_adapter',
name: null,
isEnabled: true,
ipAdapter,
ipAdapter: deepClone(initialIPAdapter),
};
state.ipAdapters.entities.push(layer);
state.selectedEntityIdentifier = { type: 'ip_adapter', id };
merge(entity, overrides);
state.ipAdapters.entities.push(entity);
if (isSelected) {
state.selectedEntityIdentifier = getEntityIdentifier(entity);
}
},
prepare: (payload: { ipAdapter: IPAdapterConfig }) => ({ payload: { id: uuidv4(), ...payload } }),
prepare: (payload?: { overrides?: Partial<CanvasIPAdapterState>; isSelected?: boolean }) => ({
payload: { ...payload, id: getPrefixedId('ip_adapter') },
}),
},
ipaRecalled: (state, action: PayloadAction<{ data: CanvasIPAdapterState }>) => {
const { data } = action.payload;

View File

@ -4,7 +4,7 @@ import { getPrefixedId } from 'features/controlLayers/konva/util';
import { merge } from 'lodash-es';
import type { CanvasControlLayerState, CanvasRasterLayerState, CanvasV2State } from './types';
import { initialControlNet } from './types';
import { getEntityIdentifier, initialControlNet } from './types';
const selectRasterLayerEntity = (state: CanvasV2State, id: string) =>
state.rasterLayers.entities.find((layer) => layer.id === id);
@ -16,7 +16,7 @@ export const rasterLayersReducers = {
action: PayloadAction<{ id: string; overrides?: Partial<CanvasRasterLayerState>; isSelected?: boolean }>
) => {
const { id, overrides, isSelected } = action.payload;
const layer: CanvasRasterLayerState = {
const entity: CanvasRasterLayerState = {
id,
name: null,
type: 'raster_layer',
@ -25,10 +25,10 @@ export const rasterLayersReducers = {
opacity: 1,
position: { x: 0, y: 0 },
};
merge(layer, overrides);
state.rasterLayers.entities.push(layer);
merge(entity, overrides);
state.rasterLayers.entities.push(entity);
if (isSelected) {
state.selectedEntityIdentifier = { type: 'raster_layer', id };
state.selectedEntityIdentifier = getEntityIdentifier(entity);
}
},
prepare: (payload: { overrides?: Partial<CanvasRasterLayerState>; isSelected?: boolean }) => ({

View File

@ -8,9 +8,9 @@ import type {
RegionalGuidanceIPAdapterConfig,
RgbColor,
} from 'features/controlLayers/store/types';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/types';
import { getEntityIdentifier, imageDTOToImageWithDims } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { isEqual } from 'lodash-es';
import { isEqual, merge } from 'lodash-es';
import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
@ -56,9 +56,12 @@ const getRGMaskFill = (state: CanvasV2State): RgbColor => {
export const regionsReducers = {
rgAdded: {
reducer: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
const rg: CanvasRegionalGuidanceState = {
reducer: (
state,
action: PayloadAction<{ id: string; overrides?: Partial<CanvasRegionalGuidanceState>; isSelected?: boolean }>
) => {
const { id, overrides, isSelected } = action.payload;
const entity: CanvasRegionalGuidanceState = {
id,
name: null,
type: 'regional_guidance',
@ -75,10 +78,15 @@ export const regionsReducers = {
negativePrompt: null,
ipAdapters: [],
};
state.regions.entities.push(rg);
state.selectedEntityIdentifier = { type: 'regional_guidance', id };
merge(entity, overrides);
state.regions.entities.push(entity);
if (isSelected) {
state.selectedEntityIdentifier = getEntityIdentifier(entity);
}
},
prepare: () => ({ payload: { id: getPrefixedId('regional_guidance') } }),
prepare: (payload?: { overrides?: Partial<CanvasRegionalGuidanceState>; isSelected?: boolean }) => ({
payload: { ...payload, id: getPrefixedId('regional_guidance') },
}),
},
rgRecalled: (state, action: PayloadAction<{ data: CanvasRegionalGuidanceState }>) => {
const { data } = action.payload;