feat(ui): split out loras state from canvas rendering state

This commit is contained in:
psychedelicious 2024-08-26 22:29:28 +10:00
parent 100832c66d
commit 52202e45de
16 changed files with 101 additions and 85 deletions

View File

@ -1,6 +1,6 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { loraDeleted } from 'features/controlLayers/store/canvasV2Slice'; import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import { modelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice'; import { modelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { modelSelected } from 'features/parameters/store/actions'; import { modelSelected } from 'features/parameters/store/actions';
import { zParameterModel } from 'features/parameters/types/parameterSchemas'; import { zParameterModel } from 'features/parameters/types/parameterSchemas';
@ -31,7 +31,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
let modelsCleared = 0; let modelsCleared = 0;
// handle incompatible loras // handle incompatible loras
state.canvasV2.loras.forEach((lora) => { state.loras.loras.forEach((lora) => {
if (lora.model.base !== newBaseModel) { if (lora.model.base !== newBaseModel) {
dispatch(loraDeleted({ id: lora.id })); dispatch(loraDeleted({ id: lora.id }));
modelsCleared += 1; modelsCleared += 1;

View File

@ -7,9 +7,9 @@ import {
bboxWidthChanged, bboxWidthChanged,
controlLayerModelChanged, controlLayerModelChanged,
ipaModelChanged, ipaModelChanged,
loraDeleted,
rgIPAdapterModelChanged, rgIPAdapterModelChanged,
} from 'features/controlLayers/store/canvasV2Slice'; } from 'features/controlLayers/store/canvasV2Slice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import { modelChanged, refinerModelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice'; import { modelChanged, refinerModelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { getEntityIdentifier } from 'features/controlLayers/store/types'; import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { calculateNewSize } from 'features/parameters/components/DocumentSize/calculateNewSize'; import { calculateNewSize } from 'features/parameters/components/DocumentSize/calculateNewSize';
@ -161,7 +161,7 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => { const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => {
const loraModels = models.filter(isLoRAModelConfig); const loraModels = models.filter(isLoRAModelConfig);
state.canvasV2.loras.forEach((lora) => { state.loras.loras.forEach((lora) => {
const isLoRAAvailable = loraModels.some((m) => m.key === lora.model.key); const isLoRAAvailable = loraModels.some((m) => m.key === lora.model.key);
if (isLoRAAvailable) { if (isLoRAAvailable) {
return; return;

View File

@ -9,6 +9,7 @@ import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
import { canvasSessionPersistConfig, canvasSessionSlice } from 'features/controlLayers/store/canvasSessionSlice'; import { canvasSessionPersistConfig, canvasSessionSlice } from 'features/controlLayers/store/canvasSessionSlice';
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice'; import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { canvasV2PersistConfig, canvasV2Slice } from 'features/controlLayers/store/canvasV2Slice'; import { canvasV2PersistConfig, canvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
import { lorasPersistConfig, lorasSlice } from 'features/controlLayers/store/lorasSlice';
import { paramsPersistConfig, paramsSlice } from 'features/controlLayers/store/paramsSlice'; import { paramsPersistConfig, paramsSlice } from 'features/controlLayers/store/paramsSlice';
import { toolPersistConfig, toolSlice } from 'features/controlLayers/store/toolSlice'; import { toolPersistConfig, toolSlice } from 'features/controlLayers/store/toolSlice';
import { deleteImageModalSlice } from 'features/deleteImageModal/store/slice'; import { deleteImageModalSlice } from 'features/deleteImageModal/store/slice';
@ -65,6 +66,7 @@ const allReducers = {
[toolSlice.name]: toolSlice.reducer, [toolSlice.name]: toolSlice.reducer,
[canvasSettingsSlice.name]: canvasSettingsSlice.reducer, [canvasSettingsSlice.name]: canvasSettingsSlice.reducer,
[canvasSessionSlice.name]: canvasSessionSlice.reducer, [canvasSessionSlice.name]: canvasSessionSlice.reducer,
[lorasSlice.name]: lorasSlice.reducer,
}; };
const rootReducer = combineReducers(allReducers); const rootReducer = combineReducers(allReducers);
@ -110,6 +112,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[toolPersistConfig.name]: toolPersistConfig, [toolPersistConfig.name]: toolPersistConfig,
[canvasSettingsPersistConfig.name]: canvasSettingsPersistConfig, [canvasSettingsPersistConfig.name]: canvasSettingsPersistConfig,
[canvasSessionPersistConfig.name]: canvasSessionPersistConfig, [canvasSessionPersistConfig.name]: canvasSessionPersistConfig,
[lorasPersistConfig.name]: lorasPersistConfig,
}; };
const unserialize: UnserializeFunction = (data, key) => { const unserialize: UnserializeFunction = (data, key) => {

View File

@ -8,7 +8,6 @@ import { bboxReducers } from 'features/controlLayers/store/bboxReducers';
import { controlLayersReducers } from 'features/controlLayers/store/controlLayersReducers'; import { controlLayersReducers } from 'features/controlLayers/store/controlLayersReducers';
import { inpaintMaskReducers } from 'features/controlLayers/store/inpaintMaskReducers'; import { inpaintMaskReducers } from 'features/controlLayers/store/inpaintMaskReducers';
import { ipAdaptersReducers } from 'features/controlLayers/store/ipAdaptersReducers'; import { ipAdaptersReducers } from 'features/controlLayers/store/ipAdaptersReducers';
import { lorasReducers } from 'features/controlLayers/store/lorasReducers';
import { modelChanged } from 'features/controlLayers/store/paramsSlice'; import { modelChanged } from 'features/controlLayers/store/paramsSlice';
import { rasterLayersReducers } from 'features/controlLayers/store/rasterLayersReducers'; import { rasterLayersReducers } from 'features/controlLayers/store/rasterLayersReducers';
import { regionsReducers } from 'features/controlLayers/store/regionsReducers'; import { regionsReducers } from 'features/controlLayers/store/regionsReducers';
@ -52,7 +51,6 @@ const initialState: CanvasV2State = {
isHidden: false, isHidden: false,
entities: [], entities: [],
}, },
loras: [],
ipAdapters: { entities: [] }, ipAdapters: { entities: [] },
bbox: { bbox: {
rect: { x: 0, y: 0, width: 512, height: 512 }, rect: { x: 0, y: 0, width: 512, height: 512 },
@ -77,8 +75,6 @@ export const canvasV2Slice = createSlice({
...regionsReducers, ...regionsReducers,
...inpaintMaskReducers, ...inpaintMaskReducers,
...bboxReducers, ...bboxReducers,
// move out
...lorasReducers,
entitySelected: (state, action: PayloadAction<EntityIdentifierPayload>) => { entitySelected: (state, action: PayloadAction<EntityIdentifierPayload>) => {
const { entityIdentifier } = action.payload; const { entityIdentifier } = action.payload;
state.selectedEntityIdentifier = entityIdentifier; state.selectedEntityIdentifier = entityIdentifier;
@ -437,13 +433,6 @@ export const {
rgIPAdapterMethodChanged, rgIPAdapterMethodChanged,
rgIPAdapterModelChanged, rgIPAdapterModelChanged,
rgIPAdapterCLIPVisionModelChanged, rgIPAdapterCLIPVisionModelChanged,
// LoRAs
loraAdded,
loraRecalled,
loraDeleted,
loraWeightChanged,
loraIsEnabledChanged,
loraAllDeleted,
// Inpaint mask // Inpaint mask
inpaintMaskAdded, inpaintMaskAdded,
// inpaintMaskRecalled, // inpaintMaskRecalled,

View File

@ -1,50 +0,0 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import type { CanvasV2State, LoRA } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { LoRAModelConfig } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
export const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
weight: 0.75,
isEnabled: true,
};
const selectLoRA = (state: CanvasV2State, id: string) => state.loras.find((lora) => lora.id === id);
export const lorasReducers = {
loraAdded: {
reducer: (state, action: PayloadAction<{ model: LoRAModelConfig; id: string }>) => {
const { model, id } = action.payload;
const parsedModel = zModelIdentifierField.parse(model);
state.loras.push({ ...defaultLoRAConfig, model: parsedModel, id });
},
prepare: (payload: { model: LoRAModelConfig }) => ({ payload: { ...payload, id: uuidv4() } }),
},
loraRecalled: (state, action: PayloadAction<{ lora: LoRA }>) => {
const { lora } = action.payload;
state.loras.push(lora);
},
loraDeleted: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
state.loras = state.loras.filter((lora) => lora.id !== id);
},
loraWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
const { id, weight } = action.payload;
const lora = selectLoRA(state, id);
if (!lora) {
return;
}
lora.weight = weight;
},
loraIsEnabledChanged: (state, action: PayloadAction<{ id: string; isEnabled: boolean }>) => {
const { id, isEnabled } = action.payload;
const lora = selectLoRA(state, id);
if (!lora) {
return;
}
lora.isEnabled = isEnabled;
},
loraAllDeleted: (state) => {
state.loras = [];
},
} satisfies SliceCaseReducers<CanvasV2State>;

View File

@ -0,0 +1,80 @@
import { createSlice, type PayloadAction } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { LoRA } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { LoRAModelConfig } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
type LoRAsState = {
loras: LoRA[];
};
export const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
weight: 0.75,
isEnabled: true,
};
const initialState: LoRAsState = {
loras: [],
};
const selectLoRA = (state: LoRAsState, id: string) => state.loras.find((lora) => lora.id === id);
export const lorasSlice = createSlice({
name: 'loras',
initialState,
reducers: {
loraAdded: {
reducer: (state, action: PayloadAction<{ model: LoRAModelConfig; id: string }>) => {
const { model, id } = action.payload;
const parsedModel = zModelIdentifierField.parse(model);
state.loras.push({ ...defaultLoRAConfig, model: parsedModel, id });
},
prepare: (payload: { model: LoRAModelConfig }) => ({ payload: { ...payload, id: uuidv4() } }),
},
loraRecalled: (state, action: PayloadAction<{ lora: LoRA }>) => {
const { lora } = action.payload;
state.loras.push(lora);
},
loraDeleted: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
state.loras = state.loras.filter((lora) => lora.id !== id);
},
loraWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
const { id, weight } = action.payload;
const lora = selectLoRA(state, id);
if (!lora) {
return;
}
lora.weight = weight;
},
loraIsEnabledChanged: (state, action: PayloadAction<{ id: string; isEnabled: boolean }>) => {
const { id, isEnabled } = action.payload;
const lora = selectLoRA(state, id);
if (!lora) {
return;
}
lora.isEnabled = isEnabled;
},
loraAllDeleted: (state) => {
state.loras = [];
},
},
});
export const { loraAdded, loraRecalled, loraDeleted, loraWeightChanged, loraIsEnabledChanged, loraAllDeleted } =
lorasSlice.actions;
export const selectLoRAsSlice = (state: RootState) => state.loras;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const lorasPersistConfig: PersistConfig<LoRAsState> = {
name: lorasSlice.name,
initialState,
migrate,
persistDenylist: [],
};

View File

@ -714,7 +714,6 @@ export type CanvasV2State = {
ipAdapters: { ipAdapters: {
entities: CanvasIPAdapterState[]; entities: CanvasIPAdapterState[];
}; };
loras: LoRA[];
bbox: { bbox: {
rect: { rect: {
x: number; x: number;

View File

@ -11,7 +11,7 @@ import {
} from '@invoke-ai/ui-library'; } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { loraDeleted, loraIsEnabledChanged, loraWeightChanged } from 'features/controlLayers/store/canvasV2Slice'; import { loraDeleted, loraIsEnabledChanged, loraWeightChanged } from 'features/controlLayers/store/lorasSlice';
import type { LoRA } from 'features/controlLayers/store/types'; import type { LoRA } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { PiTrashSimpleBold } from 'react-icons/pi'; import { PiTrashSimpleBold } from 'react-icons/pi';

View File

@ -1,11 +1,11 @@
import { Flex } from '@invoke-ai/ui-library'; import { Flex } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { selectCanvasV2Slice } from 'features/controlLayers/store/selectors'; import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
import { LoRACard } from 'features/lora/components/LoRACard'; import { LoRACard } from 'features/lora/components/LoRACard';
import { memo } from 'react'; import { memo } from 'react';
const selectLoRAsArray = createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => canvasV2.loras); const selectLoRAsArray = createMemoizedSelector(selectLoRAsSlice, (loras) => loras.loras);
export const LoRAList = memo(() => { export const LoRAList = memo(() => {
const lorasArray = useAppSelector(selectLoRAsArray); const lorasArray = useAppSelector(selectLoRAsArray);

View File

@ -4,14 +4,13 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { loraAdded } from 'features/controlLayers/store/canvasV2Slice'; import { loraAdded, selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
import { selectCanvasV2Slice } from 'features/controlLayers/store/selectors';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useLoRAModels } from 'services/api/hooks/modelsByType'; import { useLoRAModels } from 'services/api/hooks/modelsByType';
import type { LoRAModelConfig } from 'services/api/types'; import type { LoRAModelConfig } from 'services/api/types';
const selectLoRAs = createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => canvasV2.loras); const selectLoRAs = createMemoizedSelector(selectLoRAsSlice, (loras) => loras.loras);
const LoRASelect = () => { const LoRASelect = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();

View File

@ -1,5 +1,5 @@
import { getPrefixedId } from 'features/controlLayers/konva/util'; import { getPrefixedId } from 'features/controlLayers/konva/util';
import { defaultLoRAConfig } from 'features/controlLayers/store/lorasReducers'; import { defaultLoRAConfig } from 'features/controlLayers/store/lorasSlice';
import type { import type {
CanvasControlLayerState, CanvasControlLayerState,
CanvasInpaintMaskState, CanvasInpaintMaskState,

View File

@ -1,10 +1,6 @@
import { getStore } from 'app/store/nanostores/store'; import { getStore } from 'app/store/nanostores/store';
import { import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasV2Slice';
bboxHeightChanged, import { loraAllDeleted, loraRecalled } from 'features/controlLayers/store/lorasSlice';
bboxWidthChanged,
loraAllDeleted,
loraRecalled,
} from 'features/controlLayers/store/canvasV2Slice';
import { import {
negativePrompt2Changed, negativePrompt2Changed,
negativePromptChanged, negativePromptChanged,

View File

@ -14,7 +14,7 @@ export const addLoRAs = (
posCond: Invocation<'compel'>, posCond: Invocation<'compel'>,
negCond: Invocation<'compel'> negCond: Invocation<'compel'>
): void => { ): void => {
const enabledLoRAs = state.canvasV2.loras.filter( const enabledLoRAs = state.loras.loras.filter(
(l) => l.isEnabled && (l.model.base === 'sd-1' || l.model.base === 'sd-2') (l) => l.isEnabled && (l.model.base === 'sd-1' || l.model.base === 'sd-2')
); );
const loraCount = enabledLoRAs.length; const loraCount = enabledLoRAs.length;

View File

@ -13,7 +13,7 @@ export const addSDXLLoRAs = (
posCond: Invocation<'sdxl_compel_prompt'>, posCond: Invocation<'sdxl_compel_prompt'>,
negCond: Invocation<'sdxl_compel_prompt'> negCond: Invocation<'sdxl_compel_prompt'>
): void => { ): void => {
const enabledLoRAs = state.canvasV2.loras.filter((l) => l.isEnabled && l.model.base === 'sdxl'); const enabledLoRAs = state.loras.loras.filter((l) => l.isEnabled && l.model.base === 'sdxl');
const loraCount = enabledLoRAs.length; const loraCount = enabledLoRAs.length;
if (loraCount === 0) { if (loraCount === 0) {

View File

@ -18,7 +18,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
const { t } = useTranslation(); const { t } = useTranslation();
const mainModel = useAppSelector((s) => s.params.model); const mainModel = useAppSelector((s) => s.params.model);
const addedLoRAs = useAppSelector((s) => s.canvasV2.loras); const addedLoRAs = useAppSelector((s) => s.loras.loras);
const { data: mainModelConfig, isLoading: isLoadingMainModelConfig } = useGetModelConfigQuery( const { data: mainModelConfig, isLoading: isLoadingMainModelConfig } = useGetModelConfigQuery(
mainModel?.key ?? skipToken mainModel?.key ?? skipToken
); );

View File

@ -3,7 +3,7 @@ import { Box, Expander, Flex, FormControlGroup, StandaloneAccordion } from '@inv
import { EMPTY_ARRAY } from 'app/store/constants'; import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { selectCanvasV2Slice } from 'features/controlLayers/store/selectors'; import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
import { LoRAList } from 'features/lora/components/LoRAList'; import { LoRAList } from 'features/lora/components/LoRAList';
import LoRASelect from 'features/lora/components/LoRASelect'; import LoRASelect from 'features/lora/components/LoRASelect';
import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale'; import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
@ -29,8 +29,8 @@ export const GenerationSettingsAccordion = memo(() => {
const activeTabName = useAppSelector(selectActiveTab); const activeTabName = useAppSelector(selectActiveTab);
const selectBadges = useMemo( const selectBadges = useMemo(
() => () =>
createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => { createMemoizedSelector(selectLoRAsSlice, (loras) => {
const enabledLoRAsCount = canvasV2.loras.filter((l) => l.isEnabled).length; const enabledLoRAsCount = loras.loras.filter((l) => l.isEnabled).length;
const loraTabBadges = enabledLoRAsCount ? [`${enabledLoRAsCount} ${t('models.concepts')}`] : EMPTY_ARRAY; const loraTabBadges = enabledLoRAsCount ? [`${enabledLoRAsCount} ${t('models.concepts')}`] : EMPTY_ARRAY;
const accordionBadges = modelConfig ? [modelConfig.name, modelConfig.base] : EMPTY_ARRAY; const accordionBadges = modelConfig ? [modelConfig.name, modelConfig.base] : EMPTY_ARRAY;
return { loraTabBadges, accordionBadges }; return { loraTabBadges, accordionBadges };