fix(ui): newly-added entities are selected

This commit is contained in:
psychedelicious 2024-08-24 11:14:58 +10:00
parent 3f597a1c60
commit c3f7554053
6 changed files with 60 additions and 43 deletions

View File

@ -1,6 +1,5 @@
import { IconButton, Menu, MenuButton, MenuDivider, MenuItem, MenuList } from '@invoke-ai/ui-library'; import { IconButton, Menu, MenuButton, MenuDivider, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useDefaultControlAdapter, useDefaultIPAdapter } from 'features/controlLayers/hooks/useLayerControlAdapter';
import { import {
allEntitiesDeleted, allEntitiesDeleted,
controlLayerAdded, controlLayerAdded,
@ -21,23 +20,21 @@ export const CanvasEntityListMenu = memo(() => {
const count = selectEntityCount(s); const count = selectEntityCount(s);
return count > 0; return count > 0;
}); });
const defaultControlAdapter = useDefaultControlAdapter();
const defaultIPAdapter = useDefaultIPAdapter();
const addInpaintMask = useCallback(() => { const addInpaintMask = useCallback(() => {
dispatch(inpaintMaskAdded()); dispatch(inpaintMaskAdded({ isSelected: true }));
}, [dispatch]); }, [dispatch]);
const addRegionalGuidance = useCallback(() => { const addRegionalGuidance = useCallback(() => {
dispatch(rgAdded()); dispatch(rgAdded({ isSelected: true }));
}, [dispatch]); }, [dispatch]);
const addRasterLayer = useCallback(() => { const addRasterLayer = useCallback(() => {
dispatch(rasterLayerAdded({ isSelected: true })); dispatch(rasterLayerAdded({ isSelected: true }));
}, [dispatch]); }, [dispatch]);
const addControlLayer = useCallback(() => { const addControlLayer = useCallback(() => {
dispatch(controlLayerAdded({ isSelected: true, overrides: { controlAdapter: defaultControlAdapter } })); dispatch(controlLayerAdded({ isSelected: true }));
}, [defaultControlAdapter, dispatch]); }, [dispatch]);
const addIPAdapter = useCallback(() => { const addIPAdapter = useCallback(() => {
dispatch(ipaAdded({ ipAdapter: defaultIPAdapter })); dispatch(ipaAdded({ isSelected: true }));
}, [defaultIPAdapter, dispatch]); }, [dispatch]);
const deleteAll = useCallback(() => { const deleteAll = useCallback(() => {
dispatch(allEntitiesDeleted()); dispatch(allEntitiesDeleted());
}, [dispatch]); }, [dispatch]);

View File

@ -14,7 +14,7 @@ import type {
ControlNetConfig, ControlNetConfig,
T2IAdapterConfig, T2IAdapterConfig,
} from './types'; } from './types';
import { initialControlNet } from './types'; import { getEntityIdentifier, initialControlNet } from './types';
const selectControlLayerEntity = (state: CanvasV2State, id: string) => const selectControlLayerEntity = (state: CanvasV2State, id: string) =>
state.controlLayers.entities.find((entity) => entity.id === id); state.controlLayers.entities.find((entity) => entity.id === id);
@ -31,7 +31,7 @@ export const controlLayersReducers = {
action: PayloadAction<{ id: string; overrides?: Partial<CanvasControlLayerState>; isSelected?: boolean }> action: PayloadAction<{ id: string; overrides?: Partial<CanvasControlLayerState>; isSelected?: boolean }>
) => { ) => {
const { id, overrides, isSelected } = action.payload; const { id, overrides, isSelected } = action.payload;
const layer: CanvasControlLayerState = { const entity: CanvasControlLayerState = {
id, id,
name: null, name: null,
type: 'control_layer', type: 'control_layer',
@ -42,10 +42,10 @@ export const controlLayersReducers = {
position: { x: 0, y: 0 }, position: { x: 0, y: 0 },
controlAdapter: deepClone(initialControlNet), controlAdapter: deepClone(initialControlNet),
}; };
merge(layer, overrides); merge(entity, overrides);
state.controlLayers.entities.push(layer); state.controlLayers.entities.push(entity);
if (isSelected) { if (isSelected) {
state.selectedEntityIdentifier = { type: 'control_layer', id }; state.selectedEntityIdentifier = getEntityIdentifier(entity);
} }
}, },
prepare: (payload: { overrides?: Partial<CanvasControlLayerState>; isSelected?: boolean }) => ({ prepare: (payload: { overrides?: Partial<CanvasControlLayerState>; isSelected?: boolean }) => ({

View File

@ -1,11 +1,12 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit'; import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import { getPrefixedId } from 'features/controlLayers/konva/util'; import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { import {
CanvasInpaintMaskState, type CanvasInpaintMaskState,
CanvasV2State, type CanvasV2State,
EntityIdentifierPayload, type EntityIdentifierPayload,
FillStyle, type FillStyle,
RgbColor, getEntityIdentifier,
type RgbColor,
} from 'features/controlLayers/store/types'; } from 'features/controlLayers/store/types';
import { merge } from 'lodash-es'; import { merge } from 'lodash-es';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
@ -41,7 +42,7 @@ export const inpaintMaskReducers = {
merge(entity, overrides); merge(entity, overrides);
state.inpaintMasks.entities.push(entity); state.inpaintMasks.entities.push(entity);
if (isSelected) { if (isSelected) {
state.selectedEntityIdentifier = { type: 'inpaint_mask', id }; state.selectedEntityIdentifier = getEntityIdentifier(entity);
} }
}, },
prepare: (payload?: { overrides?: Partial<CanvasInpaintMaskState>; isSelected?: boolean }) => ({ prepare: (payload?: { overrides?: Partial<CanvasInpaintMaskState>; isSelected?: boolean }) => ({

View File

@ -1,11 +1,14 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit'; import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import { deepClone } from 'common/util/deepClone';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { zModelIdentifierField } from 'features/nodes/types/common'; import { zModelIdentifierField } from 'features/nodes/types/common';
import { merge } from 'lodash-es';
import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types'; import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
import type { CanvasIPAdapterState, CanvasV2State, CLIPVisionModelV2, IPAdapterConfig, IPMethodV2 } from './types'; import type { CanvasIPAdapterState, CanvasV2State, CLIPVisionModelV2, IPMethodV2 } from './types';
import { imageDTOToImageWithDims } from './types'; import { getEntityIdentifier, imageDTOToImageWithDims, initialIPAdapter } from './types';
const selectIPAdapterEntity = (state: CanvasV2State, id: string) => const selectIPAdapterEntity = (state: CanvasV2State, id: string) =>
state.ipAdapters.entities.find((ipa) => ipa.id === id); state.ipAdapters.entities.find((ipa) => ipa.id === id);
@ -17,19 +20,27 @@ export const selectIPAdapterEntityOrThrow = (state: CanvasV2State, id: string) =
export const ipAdaptersReducers = { export const ipAdaptersReducers = {
ipaAdded: { ipaAdded: {
reducer: (state, action: PayloadAction<{ id: string; ipAdapter: IPAdapterConfig }>) => { reducer: (
const { id, ipAdapter } = action.payload; state,
const layer: CanvasIPAdapterState = { action: PayloadAction<{ id: string; overrides?: Partial<CanvasIPAdapterState>; isSelected?: boolean }>
) => {
const { id, overrides, isSelected } = action.payload;
const entity: CanvasIPAdapterState = {
id, id,
type: 'ip_adapter', type: 'ip_adapter',
name: null, name: null,
isEnabled: true, isEnabled: true,
ipAdapter, ipAdapter: deepClone(initialIPAdapter),
}; };
state.ipAdapters.entities.push(layer); merge(entity, overrides);
state.selectedEntityIdentifier = { type: 'ip_adapter', id }; state.ipAdapters.entities.push(entity);
if (isSelected) {
state.selectedEntityIdentifier = getEntityIdentifier(entity);
}
}, },
prepare: (payload: { ipAdapter: IPAdapterConfig }) => ({ payload: { id: uuidv4(), ...payload } }), prepare: (payload?: { overrides?: Partial<CanvasIPAdapterState>; isSelected?: boolean }) => ({
payload: { ...payload, id: getPrefixedId('ip_adapter') },
}),
}, },
ipaRecalled: (state, action: PayloadAction<{ data: CanvasIPAdapterState }>) => { ipaRecalled: (state, action: PayloadAction<{ data: CanvasIPAdapterState }>) => {
const { data } = action.payload; const { data } = action.payload;

View File

@ -4,7 +4,7 @@ import { getPrefixedId } from 'features/controlLayers/konva/util';
import { merge } from 'lodash-es'; import { merge } from 'lodash-es';
import type { CanvasControlLayerState, CanvasRasterLayerState, CanvasV2State } from './types'; import type { CanvasControlLayerState, CanvasRasterLayerState, CanvasV2State } from './types';
import { initialControlNet } from './types'; import { getEntityIdentifier, initialControlNet } from './types';
const selectRasterLayerEntity = (state: CanvasV2State, id: string) => const selectRasterLayerEntity = (state: CanvasV2State, id: string) =>
state.rasterLayers.entities.find((layer) => layer.id === id); state.rasterLayers.entities.find((layer) => layer.id === id);
@ -16,7 +16,7 @@ export const rasterLayersReducers = {
action: PayloadAction<{ id: string; overrides?: Partial<CanvasRasterLayerState>; isSelected?: boolean }> action: PayloadAction<{ id: string; overrides?: Partial<CanvasRasterLayerState>; isSelected?: boolean }>
) => { ) => {
const { id, overrides, isSelected } = action.payload; const { id, overrides, isSelected } = action.payload;
const layer: CanvasRasterLayerState = { const entity: CanvasRasterLayerState = {
id, id,
name: null, name: null,
type: 'raster_layer', type: 'raster_layer',
@ -25,10 +25,10 @@ export const rasterLayersReducers = {
opacity: 1, opacity: 1,
position: { x: 0, y: 0 }, position: { x: 0, y: 0 },
}; };
merge(layer, overrides); merge(entity, overrides);
state.rasterLayers.entities.push(layer); state.rasterLayers.entities.push(entity);
if (isSelected) { if (isSelected) {
state.selectedEntityIdentifier = { type: 'raster_layer', id }; state.selectedEntityIdentifier = getEntityIdentifier(entity);
} }
}, },
prepare: (payload: { overrides?: Partial<CanvasRasterLayerState>; isSelected?: boolean }) => ({ prepare: (payload: { overrides?: Partial<CanvasRasterLayerState>; isSelected?: boolean }) => ({

View File

@ -8,9 +8,9 @@ import type {
RegionalGuidanceIPAdapterConfig, RegionalGuidanceIPAdapterConfig,
RgbColor, RgbColor,
} from 'features/controlLayers/store/types'; } from 'features/controlLayers/store/types';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/types'; import { getEntityIdentifier, imageDTOToImageWithDims } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common'; import { zModelIdentifierField } from 'features/nodes/types/common';
import { isEqual } from 'lodash-es'; import { isEqual, merge } from 'lodash-es';
import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types'; import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
@ -56,9 +56,12 @@ const getRGMaskFill = (state: CanvasV2State): RgbColor => {
export const regionsReducers = { export const regionsReducers = {
rgAdded: { rgAdded: {
reducer: (state, action: PayloadAction<{ id: string }>) => { reducer: (
const { id } = action.payload; state,
const rg: CanvasRegionalGuidanceState = { action: PayloadAction<{ id: string; overrides?: Partial<CanvasRegionalGuidanceState>; isSelected?: boolean }>
) => {
const { id, overrides, isSelected } = action.payload;
const entity: CanvasRegionalGuidanceState = {
id, id,
name: null, name: null,
type: 'regional_guidance', type: 'regional_guidance',
@ -75,10 +78,15 @@ export const regionsReducers = {
negativePrompt: null, negativePrompt: null,
ipAdapters: [], ipAdapters: [],
}; };
state.regions.entities.push(rg); merge(entity, overrides);
state.selectedEntityIdentifier = { type: 'regional_guidance', id }; state.regions.entities.push(entity);
if (isSelected) {
state.selectedEntityIdentifier = getEntityIdentifier(entity);
}
}, },
prepare: () => ({ payload: { id: getPrefixedId('regional_guidance') } }), prepare: (payload?: { overrides?: Partial<CanvasRegionalGuidanceState>; isSelected?: boolean }) => ({
payload: { ...payload, id: getPrefixedId('regional_guidance') },
}),
}, },
rgRecalled: (state, action: PayloadAction<{ data: CanvasRegionalGuidanceState }>) => { rgRecalled: (state, action: PayloadAction<{ data: CanvasRegionalGuidanceState }>) => {
const { data } = action.payload; const { data } = action.payload;