From 52202e45def16fea601d1cd7a3df4ceb706d7e4f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 26 Aug 2024 22:29:28 +1000 Subject: [PATCH] feat(ui): split out loras state from canvas rendering state --- .../listeners/modelSelected.ts | 4 +- .../listeners/modelsLoaded.ts | 4 +- invokeai/frontend/web/src/app/store/store.ts | 3 + .../controlLayers/store/canvasV2Slice.ts | 11 --- .../controlLayers/store/lorasReducers.ts | 50 ------------ .../controlLayers/store/lorasSlice.ts | 80 +++++++++++++++++++ .../src/features/controlLayers/store/types.ts | 1 - .../src/features/lora/components/LoRACard.tsx | 2 +- .../src/features/lora/components/LoRAList.tsx | 4 +- .../features/lora/components/LoRASelect.tsx | 5 +- .../web/src/features/metadata/util/parsers.ts | 2 +- .../src/features/metadata/util/recallers.ts | 8 +- .../nodes/util/graph/generation/addLoRAs.ts | 2 +- .../util/graph/generation/addSDXLLoRAs.ts | 2 +- .../features/prompt/PromptTriggerSelect.tsx | 2 +- .../GenerationSettingsAccordion.tsx | 6 +- 16 files changed, 101 insertions(+), 85 deletions(-) delete mode 100644 invokeai/frontend/web/src/features/controlLayers/store/lorasReducers.ts create mode 100644 invokeai/frontend/web/src/features/controlLayers/store/lorasSlice.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts index 41bd9d6712..fefd88800e 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts @@ -1,6 +1,6 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; -import { loraDeleted } from 'features/controlLayers/store/canvasV2Slice'; +import { loraDeleted } from 'features/controlLayers/store/lorasSlice'; import { modelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice'; import { modelSelected } from 'features/parameters/store/actions'; import { zParameterModel } from 'features/parameters/types/parameterSchemas'; @@ -31,7 +31,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) = let modelsCleared = 0; // handle incompatible loras - state.canvasV2.loras.forEach((lora) => { + state.loras.loras.forEach((lora) => { if (lora.model.base !== newBaseModel) { dispatch(loraDeleted({ id: lora.id })); modelsCleared += 1; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts index f6b8818e13..cc452a2152 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts @@ -7,9 +7,9 @@ import { bboxWidthChanged, controlLayerModelChanged, ipaModelChanged, - loraDeleted, rgIPAdapterModelChanged, } from 'features/controlLayers/store/canvasV2Slice'; +import { loraDeleted } from 'features/controlLayers/store/lorasSlice'; import { modelChanged, refinerModelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice'; import { getEntityIdentifier } from 'features/controlLayers/store/types'; import { calculateNewSize } from 'features/parameters/components/DocumentSize/calculateNewSize'; @@ -161,7 +161,7 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => { const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => { const loraModels = models.filter(isLoRAModelConfig); - state.canvasV2.loras.forEach((lora) => { + state.loras.loras.forEach((lora) => { const isLoRAAvailable = loraModels.some((m) => m.key === lora.model.key); if (isLoRAAvailable) { return; diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index aabc77829a..8c8dbb9681 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -9,6 +9,7 @@ import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice'; import { canvasSessionPersistConfig, canvasSessionSlice } from 'features/controlLayers/store/canvasSessionSlice'; import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice'; import { canvasV2PersistConfig, canvasV2Slice } from 'features/controlLayers/store/canvasV2Slice'; +import { lorasPersistConfig, lorasSlice } from 'features/controlLayers/store/lorasSlice'; import { paramsPersistConfig, paramsSlice } from 'features/controlLayers/store/paramsSlice'; import { toolPersistConfig, toolSlice } from 'features/controlLayers/store/toolSlice'; import { deleteImageModalSlice } from 'features/deleteImageModal/store/slice'; @@ -65,6 +66,7 @@ const allReducers = { [toolSlice.name]: toolSlice.reducer, [canvasSettingsSlice.name]: canvasSettingsSlice.reducer, [canvasSessionSlice.name]: canvasSessionSlice.reducer, + [lorasSlice.name]: lorasSlice.reducer, }; const rootReducer = combineReducers(allReducers); @@ -110,6 +112,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = { [toolPersistConfig.name]: toolPersistConfig, [canvasSettingsPersistConfig.name]: canvasSettingsPersistConfig, [canvasSessionPersistConfig.name]: canvasSessionPersistConfig, + [lorasPersistConfig.name]: lorasPersistConfig, }; const unserialize: UnserializeFunction = (data, key) => { diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts index 8951797f77..92ac8162d2 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts @@ -8,7 +8,6 @@ import { bboxReducers } from 'features/controlLayers/store/bboxReducers'; import { controlLayersReducers } from 'features/controlLayers/store/controlLayersReducers'; import { inpaintMaskReducers } from 'features/controlLayers/store/inpaintMaskReducers'; import { ipAdaptersReducers } from 'features/controlLayers/store/ipAdaptersReducers'; -import { lorasReducers } from 'features/controlLayers/store/lorasReducers'; import { modelChanged } from 'features/controlLayers/store/paramsSlice'; import { rasterLayersReducers } from 'features/controlLayers/store/rasterLayersReducers'; import { regionsReducers } from 'features/controlLayers/store/regionsReducers'; @@ -52,7 +51,6 @@ const initialState: CanvasV2State = { isHidden: false, entities: [], }, - loras: [], ipAdapters: { entities: [] }, bbox: { rect: { x: 0, y: 0, width: 512, height: 512 }, @@ -77,8 +75,6 @@ export const canvasV2Slice = createSlice({ ...regionsReducers, ...inpaintMaskReducers, ...bboxReducers, - // move out - ...lorasReducers, entitySelected: (state, action: PayloadAction) => { const { entityIdentifier } = action.payload; state.selectedEntityIdentifier = entityIdentifier; @@ -437,13 +433,6 @@ export const { rgIPAdapterMethodChanged, rgIPAdapterModelChanged, rgIPAdapterCLIPVisionModelChanged, - // LoRAs - loraAdded, - loraRecalled, - loraDeleted, - loraWeightChanged, - loraIsEnabledChanged, - loraAllDeleted, // Inpaint mask inpaintMaskAdded, // inpaintMaskRecalled, diff --git a/invokeai/frontend/web/src/features/controlLayers/store/lorasReducers.ts b/invokeai/frontend/web/src/features/controlLayers/store/lorasReducers.ts deleted file mode 100644 index d43f608346..0000000000 --- a/invokeai/frontend/web/src/features/controlLayers/store/lorasReducers.ts +++ /dev/null @@ -1,50 +0,0 @@ -import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit'; -import type { CanvasV2State, LoRA } from 'features/controlLayers/store/types'; -import { zModelIdentifierField } from 'features/nodes/types/common'; -import type { LoRAModelConfig } from 'services/api/types'; -import { v4 as uuidv4 } from 'uuid'; - -export const defaultLoRAConfig: Pick = { - weight: 0.75, - isEnabled: true, -}; - -const selectLoRA = (state: CanvasV2State, id: string) => state.loras.find((lora) => lora.id === id); - -export const lorasReducers = { - loraAdded: { - reducer: (state, action: PayloadAction<{ model: LoRAModelConfig; id: string }>) => { - const { model, id } = action.payload; - const parsedModel = zModelIdentifierField.parse(model); - state.loras.push({ ...defaultLoRAConfig, model: parsedModel, id }); - }, - prepare: (payload: { model: LoRAModelConfig }) => ({ payload: { ...payload, id: uuidv4() } }), - }, - loraRecalled: (state, action: PayloadAction<{ lora: LoRA }>) => { - const { lora } = action.payload; - state.loras.push(lora); - }, - loraDeleted: (state, action: PayloadAction<{ id: string }>) => { - const { id } = action.payload; - state.loras = state.loras.filter((lora) => lora.id !== id); - }, - loraWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => { - const { id, weight } = action.payload; - const lora = selectLoRA(state, id); - if (!lora) { - return; - } - lora.weight = weight; - }, - loraIsEnabledChanged: (state, action: PayloadAction<{ id: string; isEnabled: boolean }>) => { - const { id, isEnabled } = action.payload; - const lora = selectLoRA(state, id); - if (!lora) { - return; - } - lora.isEnabled = isEnabled; - }, - loraAllDeleted: (state) => { - state.loras = []; - }, -} satisfies SliceCaseReducers; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/lorasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/lorasSlice.ts new file mode 100644 index 0000000000..66d5521113 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/store/lorasSlice.ts @@ -0,0 +1,80 @@ +import { createSlice, type PayloadAction } from '@reduxjs/toolkit'; +import type { PersistConfig, RootState } from 'app/store/store'; +import type { LoRA } from 'features/controlLayers/store/types'; +import { zModelIdentifierField } from 'features/nodes/types/common'; +import type { LoRAModelConfig } from 'services/api/types'; +import { v4 as uuidv4 } from 'uuid'; + +type LoRAsState = { + loras: LoRA[]; +}; + +export const defaultLoRAConfig: Pick = { + weight: 0.75, + isEnabled: true, +}; + +const initialState: LoRAsState = { + loras: [], +}; + +const selectLoRA = (state: LoRAsState, id: string) => state.loras.find((lora) => lora.id === id); + +export const lorasSlice = createSlice({ + name: 'loras', + initialState, + reducers: { + loraAdded: { + reducer: (state, action: PayloadAction<{ model: LoRAModelConfig; id: string }>) => { + const { model, id } = action.payload; + const parsedModel = zModelIdentifierField.parse(model); + state.loras.push({ ...defaultLoRAConfig, model: parsedModel, id }); + }, + prepare: (payload: { model: LoRAModelConfig }) => ({ payload: { ...payload, id: uuidv4() } }), + }, + loraRecalled: (state, action: PayloadAction<{ lora: LoRA }>) => { + const { lora } = action.payload; + state.loras.push(lora); + }, + loraDeleted: (state, action: PayloadAction<{ id: string }>) => { + const { id } = action.payload; + state.loras = state.loras.filter((lora) => lora.id !== id); + }, + loraWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => { + const { id, weight } = action.payload; + const lora = selectLoRA(state, id); + if (!lora) { + return; + } + lora.weight = weight; + }, + loraIsEnabledChanged: (state, action: PayloadAction<{ id: string; isEnabled: boolean }>) => { + const { id, isEnabled } = action.payload; + const lora = selectLoRA(state, id); + if (!lora) { + return; + } + lora.isEnabled = isEnabled; + }, + loraAllDeleted: (state) => { + state.loras = []; + }, + }, +}); + +export const { loraAdded, loraRecalled, loraDeleted, loraWeightChanged, loraIsEnabledChanged, loraAllDeleted } = + lorasSlice.actions; + +export const selectLoRAsSlice = (state: RootState) => state.loras; + +/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ +const migrate = (state: any): any => { + return state; +}; + +export const lorasPersistConfig: PersistConfig = { + name: lorasSlice.name, + initialState, + migrate, + persistDenylist: [], +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 74231cdb9f..cb504d7768 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -714,7 +714,6 @@ export type CanvasV2State = { ipAdapters: { entities: CanvasIPAdapterState[]; }; - loras: LoRA[]; bbox: { rect: { x: number; diff --git a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx index 54ad2cf987..28c0ea8198 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx @@ -11,7 +11,7 @@ import { } from '@invoke-ai/ui-library'; import { useAppDispatch } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; -import { loraDeleted, loraIsEnabledChanged, loraWeightChanged } from 'features/controlLayers/store/canvasV2Slice'; +import { loraDeleted, loraIsEnabledChanged, loraWeightChanged } from 'features/controlLayers/store/lorasSlice'; import type { LoRA } from 'features/controlLayers/store/types'; import { memo, useCallback } from 'react'; import { PiTrashSimpleBold } from 'react-icons/pi'; diff --git a/invokeai/frontend/web/src/features/lora/components/LoRAList.tsx b/invokeai/frontend/web/src/features/lora/components/LoRAList.tsx index e96e38797d..6d05f1ea6f 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRAList.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRAList.tsx @@ -1,11 +1,11 @@ import { Flex } from '@invoke-ai/ui-library'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; -import { selectCanvasV2Slice } from 'features/controlLayers/store/selectors'; +import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice'; import { LoRACard } from 'features/lora/components/LoRACard'; import { memo } from 'react'; -const selectLoRAsArray = createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => canvasV2.loras); +const selectLoRAsArray = createMemoizedSelector(selectLoRAsSlice, (loras) => loras.loras); export const LoRAList = memo(() => { const lorasArray = useAppSelector(selectLoRAsArray); diff --git a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx index 2f552ba863..8786e6c1b3 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx @@ -4,14 +4,13 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; -import { loraAdded } from 'features/controlLayers/store/canvasV2Slice'; -import { selectCanvasV2Slice } from 'features/controlLayers/store/selectors'; +import { loraAdded, selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useLoRAModels } from 'services/api/hooks/modelsByType'; import type { LoRAModelConfig } from 'services/api/types'; -const selectLoRAs = createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => canvasV2.loras); +const selectLoRAs = createMemoizedSelector(selectLoRAsSlice, (loras) => loras.loras); const LoRASelect = () => { const dispatch = useAppDispatch(); diff --git a/invokeai/frontend/web/src/features/metadata/util/parsers.ts b/invokeai/frontend/web/src/features/metadata/util/parsers.ts index 60f9eea833..efcfffd091 100644 --- a/invokeai/frontend/web/src/features/metadata/util/parsers.ts +++ b/invokeai/frontend/web/src/features/metadata/util/parsers.ts @@ -1,5 +1,5 @@ import { getPrefixedId } from 'features/controlLayers/konva/util'; -import { defaultLoRAConfig } from 'features/controlLayers/store/lorasReducers'; +import { defaultLoRAConfig } from 'features/controlLayers/store/lorasSlice'; import type { CanvasControlLayerState, CanvasInpaintMaskState, diff --git a/invokeai/frontend/web/src/features/metadata/util/recallers.ts b/invokeai/frontend/web/src/features/metadata/util/recallers.ts index 6333167768..71278e0d9c 100644 --- a/invokeai/frontend/web/src/features/metadata/util/recallers.ts +++ b/invokeai/frontend/web/src/features/metadata/util/recallers.ts @@ -1,10 +1,6 @@ import { getStore } from 'app/store/nanostores/store'; -import { - bboxHeightChanged, - bboxWidthChanged, - loraAllDeleted, - loraRecalled, -} from 'features/controlLayers/store/canvasV2Slice'; +import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasV2Slice'; +import { loraAllDeleted, loraRecalled } from 'features/controlLayers/store/lorasSlice'; import { negativePrompt2Changed, negativePromptChanged, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts index 92bf0cbeaa..79a8521efb 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts @@ -14,7 +14,7 @@ export const addLoRAs = ( posCond: Invocation<'compel'>, negCond: Invocation<'compel'> ): void => { - const enabledLoRAs = state.canvasV2.loras.filter( + const enabledLoRAs = state.loras.loras.filter( (l) => l.isEnabled && (l.model.base === 'sd-1' || l.model.base === 'sd-2') ); const loraCount = enabledLoRAs.length; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts index ffb5268520..a38c9757ce 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts @@ -13,7 +13,7 @@ export const addSDXLLoRAs = ( posCond: Invocation<'sdxl_compel_prompt'>, negCond: Invocation<'sdxl_compel_prompt'> ): void => { - const enabledLoRAs = state.canvasV2.loras.filter((l) => l.isEnabled && l.model.base === 'sdxl'); + const enabledLoRAs = state.loras.loras.filter((l) => l.isEnabled && l.model.base === 'sdxl'); const loraCount = enabledLoRAs.length; if (loraCount === 0) { diff --git a/invokeai/frontend/web/src/features/prompt/PromptTriggerSelect.tsx b/invokeai/frontend/web/src/features/prompt/PromptTriggerSelect.tsx index ad89f6872d..b4b14f4e52 100644 --- a/invokeai/frontend/web/src/features/prompt/PromptTriggerSelect.tsx +++ b/invokeai/frontend/web/src/features/prompt/PromptTriggerSelect.tsx @@ -18,7 +18,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel const { t } = useTranslation(); const mainModel = useAppSelector((s) => s.params.model); - const addedLoRAs = useAppSelector((s) => s.canvasV2.loras); + const addedLoRAs = useAppSelector((s) => s.loras.loras); const { data: mainModelConfig, isLoading: isLoadingMainModelConfig } = useGetModelConfigQuery( mainModel?.key ?? skipToken ); diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx index 2d69401369..9edab01c37 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx @@ -3,7 +3,7 @@ import { Box, Expander, Flex, FormControlGroup, StandaloneAccordion } from '@inv import { EMPTY_ARRAY } from 'app/store/constants'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; -import { selectCanvasV2Slice } from 'features/controlLayers/store/selectors'; +import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice'; import { LoRAList } from 'features/lora/components/LoRAList'; import LoRASelect from 'features/lora/components/LoRASelect'; import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale'; @@ -29,8 +29,8 @@ export const GenerationSettingsAccordion = memo(() => { const activeTabName = useAppSelector(selectActiveTab); const selectBadges = useMemo( () => - createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => { - const enabledLoRAsCount = canvasV2.loras.filter((l) => l.isEnabled).length; + createMemoizedSelector(selectLoRAsSlice, (loras) => { + const enabledLoRAsCount = loras.loras.filter((l) => l.isEnabled).length; const loraTabBadges = enabledLoRAsCount ? [`${enabledLoRAsCount} ${t('models.concepts')}`] : EMPTY_ARRAY; const accordionBadges = modelConfig ? [modelConfig.name, modelConfig.base] : EMPTY_ARRAY; return { loraTabBadges, accordionBadges };