From 4b848798e742dffa28f36741e701275a5fa44921 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 17 Jun 2024 12:23:32 +1000 Subject: [PATCH] refactor(ui): move loras to canvas slice --- .../listeners/modelSelected.ts | 10 +-- .../listeners/modelsLoaded.ts | 9 +- invokeai/frontend/web/src/app/store/store.ts | 7 -- .../controlLayers/store/canvasV2Slice.ts | 10 +++ .../controlLayers/store/lorasReducers.ts | 56 ++++++++++++ .../src/features/controlLayers/store/types.ts | 9 ++ .../src/features/lora/components/LoRACard.tsx | 20 ++--- .../src/features/lora/components/LoRAList.tsx | 7 +- .../features/lora/components/LoRASelect.tsx | 18 ++-- .../web/src/features/lora/store/loraSlice.ts | 87 ------------------- .../metadata/components/MetadataLoRAs.tsx | 2 +- .../nodes/util/graph/generation/addLoRAs.ts | 7 +- .../util/graph/generation/addSDXLLoRAs.ts | 5 +- .../GenerationSettingsAccordion.tsx | 7 +- 14 files changed, 114 insertions(+), 140 deletions(-) create mode 100644 invokeai/frontend/web/src/features/controlLayers/store/lorasReducers.ts delete mode 100644 invokeai/frontend/web/src/features/lora/store/loraSlice.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts index e3ebd277ff..38260b23a2 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts @@ -1,12 +1,10 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; -import { caIsEnabledToggled, modelChanged, vaeSelected } from 'features/controlLayers/store/canvasV2Slice'; -import { loraRemoved } from 'features/lora/store/loraSlice'; +import { caIsEnabledToggled, loraDeleted, modelChanged, vaeSelected } from 'features/controlLayers/store/canvasV2Slice'; import { modelSelected } from 'features/parameters/store/actions'; import { zParameterModel } from 'features/parameters/types/parameterSchemas'; import { toast } from 'features/toast/toast'; import { t } from 'i18next'; -import { forEach } from 'lodash-es'; export const addModelSelectedListener = (startAppListening: AppStartListening) => { startAppListening({ @@ -32,9 +30,9 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) = let modelsCleared = 0; // handle incompatible loras - forEach(state.lora.loras, (lora, id) => { + state.canvasV2.loras.forEach((lora) => { if (lora.model.base !== newBaseModel) { - dispatch(loraRemoved(id)); + dispatch(loraDeleted({ id: lora.id })); modelsCleared += 1; } }); @@ -68,7 +66,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) = } } - dispatch(modelChanged(newModel, state.canvasV2.params.model)); + dispatch(modelChanged({ model: newModel, previousModel: state.canvasV2.params.model })); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts index b3b4dacaa5..47555a0ef1 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts @@ -6,18 +6,17 @@ import { caModelChanged, heightChanged, ipaModelChanged, + loraDeleted, modelChanged, refinerModelChanged, rgIPAdapterModelChanged, vaeSelected, widthChanged, } from 'features/controlLayers/store/canvasV2Slice'; -import { loraRemoved } from 'features/lora/store/loraSlice'; import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize'; import { postProcessingModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice'; import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas'; import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension'; -import { forEach } from 'lodash-es'; import type { Logger } from 'roarr'; import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models'; import type { AnyModelConfig } from 'services/api/types'; @@ -161,15 +160,13 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => { }; const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => { - const loras = state.lora.loras; const loraModels = models.filter(isLoRAModelConfig); - - forEach(loras, (lora, id) => { + state.canvasV2.loras.forEach((lora) => { const isLoRAAvailable = loraModels.some((m) => m.key === lora.model.key); if (isLoRAAvailable) { return; } - dispatch(loraRemoved(id)); + dispatch(loraDeleted({ id: lora.id })); }); }; diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index 95de810cdd..f41d6273e9 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -10,7 +10,6 @@ import { deleteImageModalSlice } from 'features/deleteImageModal/store/slice'; import { dynamicPromptsPersistConfig, dynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/gallerySlice'; import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice'; -import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice'; import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice'; import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/nodes/store/nodesSlice'; import { workflowSettingsPersistConfig, workflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice'; @@ -40,7 +39,6 @@ import { listenerMiddleware } from './middleware/listenerMiddleware'; const allReducers = { [api.reducerPath]: api.reducer, [gallerySlice.name]: gallerySlice.reducer, - // [generationSlice.name]: generationSlice.reducer, [nodesSlice.name]: undoable(nodesSlice.reducer, nodesUndoableConfig), [systemSlice.name]: systemSlice.reducer, [configSlice.name]: configSlice.reducer, @@ -48,9 +46,7 @@ const allReducers = { [dynamicPromptsSlice.name]: dynamicPromptsSlice.reducer, [deleteImageModalSlice.name]: deleteImageModalSlice.reducer, [changeBoardModalSlice.name]: changeBoardModalSlice.reducer, - [loraSlice.name]: loraSlice.reducer, [modelManagerV2Slice.name]: modelManagerV2Slice.reducer, - // [sdxlSlice.name]: sdxlSlice.reducer, [queueSlice.name]: queueSlice.reducer, [workflowSlice.name]: workflowSlice.reducer, [hrfSlice.name]: hrfSlice.reducer, @@ -88,14 +84,11 @@ export type PersistConfig = { const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = { [galleryPersistConfig.name]: galleryPersistConfig, - // [generationPersistConfig.name]: generationPersistConfig, [nodesPersistConfig.name]: nodesPersistConfig, [systemPersistConfig.name]: systemPersistConfig, [workflowPersistConfig.name]: workflowPersistConfig, [uiPersistConfig.name]: uiPersistConfig, [dynamicPromptsPersistConfig.name]: dynamicPromptsPersistConfig, - // [sdxlPersistConfig.name]: sdxlPersistConfig, - [loraPersistConfig.name]: loraPersistConfig, [modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig, [hrfPersistConfig.name]: hrfPersistConfig, [canvasV2PersistConfig.name]: canvasV2PersistConfig, diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts index 631cab64b3..c6f6a2bd7a 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasV2Slice.ts @@ -8,6 +8,7 @@ import { compositingReducers } from 'features/controlLayers/store/compositingRed import { controlAdaptersReducers } from 'features/controlLayers/store/controlAdaptersReducers'; import { ipAdaptersReducers } from 'features/controlLayers/store/ipAdaptersReducers'; import { layersReducers } from 'features/controlLayers/store/layersReducers'; +import { lorasReducers } from 'features/controlLayers/store/lorasReducers'; import { paramsReducers } from 'features/controlLayers/store/paramsReducers'; import { regionsReducers } from 'features/controlLayers/store/regionsReducers'; import { settingsReducers } from 'features/controlLayers/store/settingsReducers'; @@ -27,6 +28,7 @@ const initialState: CanvasV2State = { controlAdapters: [], ipAdapters: [], regions: [], + loras: [], tool: { selected: 'bbox', selectedBuffer: null, @@ -113,6 +115,7 @@ export const canvasV2Slice = createSlice({ ...ipAdaptersReducers, ...controlAdaptersReducers, ...regionsReducers, + ...lorasReducers, ...paramsReducers, ...compositingReducers, ...settingsReducers, @@ -287,6 +290,13 @@ export const { setRefinerNegativeAestheticScore, setRefinerStart, modelChanged, + // LoRAs + loraAdded, + loraRecalled, + loraDeleted, + loraWeightChanged, + loraIsEnabledChanged, + loraAllDeleted, } = canvasV2Slice.actions; export const selectCanvasV2Slice = (state: RootState) => state.canvasV2; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/lorasReducers.ts b/invokeai/frontend/web/src/features/controlLayers/store/lorasReducers.ts new file mode 100644 index 0000000000..bca3ccdd9e --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/store/lorasReducers.ts @@ -0,0 +1,56 @@ +import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit'; +import type { CanvasV2State, LoRA } from 'features/controlLayers/store/types'; +import { zModelIdentifierField } from 'features/nodes/types/common'; +import type { LoRAModelConfig } from 'services/api/types'; +import { assert } from 'tsafe'; +import { v4 as uuidv4 } from 'uuid'; + +export const defaultLoRAConfig: Pick = { + weight: 0.75, + isEnabled: true, +}; + +export const selectLoRA = (state: CanvasV2State, id: string) => state.loras.find((lora) => lora.id === id); +export const selectLoRAOrThrow = (state: CanvasV2State, id: string) => { + const lora = selectLoRA(state, id); + assert(lora, `LoRA with id ${id} not found`); + return lora; +}; + +export const lorasReducers = { + loraAdded: { + reducer: (state, action: PayloadAction<{ model: LoRAModelConfig; id: string }>) => { + const { model, id } = action.payload; + const parsedModel = zModelIdentifierField.parse(model); + state.loras.push({ ...defaultLoRAConfig, model: parsedModel, id }); + }, + prepare: (payload: { model: LoRAModelConfig }) => ({ payload: { ...payload, id: uuidv4() } }), + }, + loraRecalled: (state, action: PayloadAction<{ lora: LoRA }>) => { + const { lora } = action.payload; + state.loras.push(lora); + }, + loraDeleted: (state, action: PayloadAction<{ id: string }>) => { + const { id } = action.payload; + state.loras = state.loras.filter((lora) => lora.id !== id); + }, + loraWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => { + const { id, weight } = action.payload; + const lora = selectLoRA(state, id); + if (!lora) { + return; + } + lora.weight = weight; + }, + loraIsEnabledChanged: (state, action: PayloadAction<{ id: string; isEnabled: boolean }>) => { + const { id, isEnabled } = action.payload; + const lora = selectLoRA(state, id); + if (!lora) { + return; + } + lora.isEnabled = isEnabled; + }, + loraAllDeleted: (state) => { + state.loras = []; + }, +} satisfies SliceCaseReducers; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 14d4725fb6..9038682c92 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -6,6 +6,7 @@ import type { ParameterCFGRescaleMultiplier, ParameterCFGScale, ParameterHeight, + ParameterLoRAModel, ParameterMaskBlurMethod, ParameterModel, ParameterNegativePrompt, @@ -799,6 +800,13 @@ export type Dimensions = { height: number; }; +export type LoRA = { + id: string; + isEnabled: boolean; + model: ParameterLoRAModel; + weight: number; +}; + export type CanvasV2State = { _version: 3; selectedEntityIdentifier: CanvasEntityIdentifier | null; @@ -806,6 +814,7 @@ export type CanvasV2State = { controlAdapters: ControlAdapterData[]; ipAdapters: IPAdapterData[]; regions: RegionalGuidanceData[]; + loras: LoRA[]; tool: { selected: Tool; selectedBuffer: Tool | null; diff --git a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx index f7261b4608..54ad2cf987 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx @@ -11,8 +11,8 @@ import { } from '@invoke-ai/ui-library'; import { useAppDispatch } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; -import type { LoRA } from 'features/lora/store/loraSlice'; -import { loraIsEnabledChanged, loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice'; +import { loraDeleted, loraIsEnabledChanged, loraWeightChanged } from 'features/controlLayers/store/canvasV2Slice'; +import type { LoRA } from 'features/controlLayers/store/types'; import { memo, useCallback } from 'react'; import { PiTrashSimpleBold } from 'react-icons/pi'; import { useGetModelConfigQuery } from 'services/api/endpoints/models'; @@ -21,6 +21,8 @@ type LoRACardProps = { lora: LoRA; }; +const marks = [-1, 0, 1, 2]; + export const LoRACard = memo((props: LoRACardProps) => { const { lora } = props; const dispatch = useAppDispatch(); @@ -28,18 +30,18 @@ export const LoRACard = memo((props: LoRACardProps) => { const handleChange = useCallback( (v: number) => { - dispatch(loraWeightChanged({ key: lora.model.key, weight: v })); + dispatch(loraWeightChanged({ id: lora.id, weight: v })); }, - [dispatch, lora.model.key] + [dispatch, lora.id] ); const handleSetLoraToggle = useCallback(() => { - dispatch(loraIsEnabledChanged({ key: lora.model.key, isEnabled: !lora.isEnabled })); - }, [dispatch, lora.model.key, lora.isEnabled]); + dispatch(loraIsEnabledChanged({ id: lora.id, isEnabled: !lora.isEnabled })); + }, [dispatch, lora.id, lora.isEnabled]); const handleRemoveLora = useCallback(() => { - dispatch(loraRemoved(lora.model.key)); - }, [dispatch, lora.model.key]); + dispatch(loraDeleted({ id: lora.id })); + }, [dispatch, lora.id]); return ( @@ -90,5 +92,3 @@ export const LoRACard = memo((props: LoRACardProps) => { }); LoRACard.displayName = 'LoRACard'; - -const marks = [-1, 0, 1, 2]; diff --git a/invokeai/frontend/web/src/features/lora/components/LoRAList.tsx b/invokeai/frontend/web/src/features/lora/components/LoRAList.tsx index 68d259a852..b5a2b1bba9 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRAList.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRAList.tsx @@ -1,12 +1,11 @@ import { Flex } from '@invoke-ai/ui-library'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; +import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice'; import { LoRACard } from 'features/lora/components/LoRACard'; -import { selectLoraSlice } from 'features/lora/store/loraSlice'; -import { map } from 'lodash-es'; import { memo } from 'react'; -const selectLoRAsArray = createMemoizedSelector(selectLoraSlice, (lora) => map(lora.loras)); +const selectLoRAsArray = createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => canvasV2.loras); export const LoRAList = memo(() => { const lorasArray = useAppSelector(selectLoRAsArray); @@ -18,7 +17,7 @@ export const LoRAList = memo(() => { return ( {lorasArray.map((lora) => ( - + ))} ); diff --git a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx index 3c3e0375e2..8296031418 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx @@ -4,34 +4,34 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; -import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice'; +import { loraAdded, selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useLoRAModels } from 'services/api/hooks/modelsByType'; import type { LoRAModelConfig } from 'services/api/types'; -const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras); +const selectLoRAs = createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => canvasV2.loras); const LoRASelect = () => { const dispatch = useAppDispatch(); const [modelConfigs, { isLoading }] = useLoRAModels(); const { t } = useTranslation(); - const addedLoRAs = useAppSelector(selectAddedLoRAs); + const addedLoRAs = useAppSelector(selectLoRAs); const currentBaseModel = useAppSelector((s) => s.canvasV2.params.model?.base); - const getIsDisabled = (lora: LoRAModelConfig): boolean => { - const isCompatible = currentBaseModel === lora.base; - const isAdded = Boolean(addedLoRAs[lora.key]); + const getIsDisabled = (model: LoRAModelConfig): boolean => { + const isCompatible = currentBaseModel === model.base; + const isAdded = Boolean(addedLoRAs.find((lora) => lora.model.key === model.key)); const hasMainModel = Boolean(currentBaseModel); return !hasMainModel || !isCompatible || isAdded; }; const _onChange = useCallback( - (lora: LoRAModelConfig | null) => { - if (!lora) { + (model: LoRAModelConfig | null) => { + if (!model) { return; } - dispatch(loraAdded(lora)); + dispatch(loraAdded({ model })); }, [dispatch] ); diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts deleted file mode 100644 index 2382e9ffe4..0000000000 --- a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts +++ /dev/null @@ -1,87 +0,0 @@ -import type { PayloadAction } from '@reduxjs/toolkit'; -import { createSlice } from '@reduxjs/toolkit'; -import type { PersistConfig, RootState } from 'app/store/store'; -import { deepClone } from 'common/util/deepClone'; -import { zModelIdentifierField } from 'features/nodes/types/common'; -import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas'; -import type { LoRAModelConfig } from 'services/api/types'; - -export type LoRA = { - model: ParameterLoRAModel; - weight: number; - isEnabled?: boolean; -}; - -export const defaultLoRAConfig: Pick = { - weight: 0.75, - isEnabled: true, -}; - -type LoraState = { - _version: 2; - loras: Record; -}; - -const initialLoraState: LoraState = { - _version: 2, - loras: {}, -}; - -export const loraSlice = createSlice({ - name: 'lora', - initialState: initialLoraState, - reducers: { - loraAdded: (state, action: PayloadAction) => { - const model = zModelIdentifierField.parse(action.payload); - state.loras[model.key] = { ...defaultLoRAConfig, model }; - }, - loraRecalled: (state, action: PayloadAction) => { - state.loras[action.payload.model.key] = action.payload; - }, - loraRemoved: (state, action: PayloadAction) => { - const key = action.payload; - delete state.loras[key]; - }, - loraWeightChanged: (state, action: PayloadAction<{ key: string; weight: number }>) => { - const { key, weight } = action.payload; - const lora = state.loras[key]; - if (!lora) { - return; - } - lora.weight = weight; - }, - loraIsEnabledChanged: (state, action: PayloadAction<{ key: string; isEnabled: boolean }>) => { - const { key, isEnabled } = action.payload; - const lora = state.loras[key]; - if (!lora) { - return; - } - lora.isEnabled = isEnabled; - }, - lorasReset: () => deepClone(initialLoraState), - }, -}); - -export const { loraAdded, loraRemoved, loraWeightChanged, loraIsEnabledChanged, loraRecalled, lorasReset } = - loraSlice.actions; - -export const selectLoraSlice = (state: RootState) => state.lora; - -/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ -const migrateLoRAState = (state: any): any => { - if (!('_version' in state)) { - state._version = 1; - } - if (state._version === 1) { - // Model type has changed, so we need to reset the state - too risky to migrate - state = deepClone(initialLoraState); - } - return state; -}; - -export const loraPersistConfig: PersistConfig = { - name: loraSlice.name, - initialState: initialLoraState, - migrate: migrateLoRAState, - persistDenylist: [], -}; diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataLoRAs.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataLoRAs.tsx index 7e78985c49..921306fc93 100644 --- a/invokeai/frontend/web/src/features/metadata/components/MetadataLoRAs.tsx +++ b/invokeai/frontend/web/src/features/metadata/components/MetadataLoRAs.tsx @@ -1,4 +1,4 @@ -import type { LoRA } from 'features/lora/store/loraSlice'; +import type { LoRA } from 'features/controlLayers/store/types'; import { MetadataItemView } from 'features/metadata/components/MetadataItemView'; import type { MetadataHandlers } from 'features/metadata/types'; import { handlers } from 'features/metadata/util/handlers'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts index 3335e0f80d..b078dfcdfc 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addLoRAs.ts @@ -2,7 +2,6 @@ import type { RootState } from 'app/store/store'; import { zModelIdentifierField } from 'features/nodes/types/common'; import { LORA_LOADER } from 'features/nodes/util/graph/constants'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; -import { filter, size } from 'lodash-es'; import type { Invocation, S } from 'services/api/types'; export const addLoRAs = ( @@ -15,8 +14,10 @@ export const addLoRAs = ( posCond: Invocation<'compel'>, negCond: Invocation<'compel'> ): void => { - const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false); - const loraCount = size(enabledLoRAs); + const enabledLoRAs = state.canvasV2.loras.filter( + (l) => l.isEnabled && (l.model.base === 'sd-1' || l.model.base === 'sd-2') + ); + const loraCount = enabledLoRAs.length; if (loraCount === 0) { return; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts index 3125ab5ac3..d7377da4b0 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addSDXLLoRAs.ts @@ -2,7 +2,6 @@ import type { RootState } from 'app/store/store'; import { zModelIdentifierField } from 'features/nodes/types/common'; import { LORA_LOADER } from 'features/nodes/util/graph/constants'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; -import { filter, size } from 'lodash-es'; import type { Invocation, S } from 'services/api/types'; export const addSDXLLoRas = ( @@ -14,8 +13,8 @@ export const addSDXLLoRas = ( posCond: Invocation<'sdxl_compel_prompt'>, negCond: Invocation<'sdxl_compel_prompt'> ): void => { - const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false); - const loraCount = size(enabledLoRAs); + const enabledLoRAs = state.canvasV2.loras.filter((l) => l.isEnabled && l.model.base === 'sdxl'); + const loraCount = enabledLoRAs.length; if (loraCount === 0) { return; diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx index a809a35587..634b8a6bd3 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx @@ -3,9 +3,9 @@ import { Box, Expander, Flex, FormControlGroup, StandaloneAccordion } from '@inv import { EMPTY_ARRAY } from 'app/store/constants'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; +import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice'; import { LoRAList } from 'features/lora/components/LoRAList'; import LoRASelect from 'features/lora/components/LoRASelect'; -import { selectLoraSlice } from 'features/lora/store/loraSlice'; import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale'; import ParamScheduler from 'features/parameters/components/Core/ParamScheduler'; import ParamSteps from 'features/parameters/components/Core/ParamSteps'; @@ -15,7 +15,6 @@ import { UseDefaultSettingsButton } from 'features/parameters/components/MainMod import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle'; import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; -import { filter } from 'lodash-es'; import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useSelectedModelConfig } from 'services/api/hooks/useSelectedModelConfig'; @@ -30,8 +29,8 @@ export const GenerationSettingsAccordion = memo(() => { const activeTabName = useAppSelector(activeTabNameSelector); const selectBadges = useMemo( () => - createMemoizedSelector(selectLoraSlice, (lora) => { - const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length; + createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => { + const enabledLoRAsCount = canvasV2.loras.filter((l) => l.isEnabled).length; const loraTabBadges = enabledLoRAsCount ? [`${enabledLoRAsCount} ${t('models.concepts')}`] : EMPTY_ARRAY; const accordionBadges = modelConfig ? [modelConfig.name, modelConfig.base] : EMPTY_ARRAY; return { loraTabBadges, accordionBadges };