diff --git a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx
index f0c8e3fcd3..caedde875a 100644
--- a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx
+++ b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx
@@ -4,12 +4,14 @@ import {
CardHeader,
CompositeNumberInput,
CompositeSlider,
+ Flex,
IconButton,
+ Switch,
Text,
} from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import type { LoRA } from 'features/lora/store/loraSlice';
-import { loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice';
+import { loraIsEnabledChanged, loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice';
import { memo, useCallback } from 'react';
import { PiTrashSimpleBold } from 'react-icons/pi';
@@ -28,6 +30,10 @@ export const LoRACard = memo((props: LoRACardProps) => {
[dispatch, lora.id]
);
+ const handleSetLoraToggle = useCallback(() => {
+ dispatch(loraIsEnabledChanged({ id: lora.id, isEnabled: !lora.isEnabled }));
+ }, [dispatch, lora.id, lora.isEnabled]);
+
const handleRemoveLora = useCallback(() => {
dispatch(loraRemoved(lora.id));
}, [dispatch, lora.id]);
@@ -35,16 +41,21 @@ export const LoRACard = memo((props: LoRACardProps) => {
return (
-
- {lora.model_name}
-
- }
- />
+
+
+ {lora.model_name}
+
+
+
+ }
+ />
+
+
{
step={0.01}
marks={marks}
defaultValue={0.75}
+ isDisabled={!lora.isEnabled}
/>
{
w={20}
flexShrink={0}
defaultValue={0.75}
+ isDisabled={!lora.isEnabled}
/>
diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts
index 8562090fbc..7906443d7f 100644
--- a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts
+++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts
@@ -7,10 +7,12 @@ import type { LoRAModelConfigEntity } from 'services/api/endpoints/models';
export type LoRA = ParameterLoRAModel & {
id: string;
weight: number;
+ isEnabled?: boolean;
};
-export const defaultLoRAConfig = {
+export const defaultLoRAConfig: Pick = {
weight: 0.75,
+ isEnabled: true,
};
export type LoraState = {
@@ -58,11 +60,26 @@ export const loraSlice = createSlice({
}
lora.weight = defaultLoRAConfig.weight;
},
+ loraIsEnabledChanged: (state, action: PayloadAction>) => {
+ const { id, isEnabled } = action.payload;
+ const lora = state.loras[id];
+ if (!lora) {
+ return;
+ }
+ lora.isEnabled = isEnabled;
+ },
},
});
-export const { loraAdded, loraRemoved, loraWeightChanged, loraWeightReset, lorasCleared, loraRecalled } =
- loraSlice.actions;
+export const {
+ loraAdded,
+ loraRemoved,
+ loraWeightChanged,
+ loraWeightReset,
+ loraIsEnabledChanged,
+ lorasCleared,
+ loraRecalled,
+} = loraSlice.actions;
export default loraSlice.reducer;
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts
index d1b5ddde84..3ed71b7529 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts
@@ -1,5 +1,5 @@
import type { RootState } from 'app/store/store';
-import { forEach, size } from 'lodash-es';
+import { filter, size } from 'lodash-es';
import type { CoreMetadataInvocation, LoraLoaderInvocation, NonNullableGraph } from 'services/api/types';
import {
@@ -28,8 +28,8 @@ export const addLoRAsToGraph = (
* So we need to inject a LoRA chain into the graph.
*/
- const { loras } = state.lora;
- const loraCount = size(loras);
+ const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
+ const loraCount = size(enabledLoRAs);
if (loraCount === 0) {
return;
@@ -47,7 +47,7 @@ export const addLoRAsToGraph = (
let currentLoraIndex = 0;
const loraMetadata: CoreMetadataInvocation['loras'] = [];
- forEach(loras, (lora) => {
+ enabledLoRAs.forEach((lora) => {
const { model_name, base_model, weight } = lora;
const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;
diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx
index d641ea9929..ea6fd3563d 100644
--- a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx
+++ b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx
@@ -23,7 +23,7 @@ import ParamMainModelSelect from 'features/parameters/components/MainModel/Param
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
-import { size } from 'lodash-es';
+import { filter, size } from 'lodash-es';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -32,7 +32,8 @@ const formLabelProps: FormLabelProps = {
};
const badgesSelector = createMemoizedSelector(selectLoraSlice, selectGenerationSlice, (lora, generation) => {
- const loraTabBadges = size(lora.loras) ? [size(lora.loras)] : [];
+ const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length;
+ const loraTabBadges = size(lora.loras) ? [enabledLoRAsCount] : [];
const accordionBadges: (string | number)[] = [];
if (generation.model) {
accordionBadges.push(generation.model.model_name);