diff --git a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx index f0c8e3fcd3..caedde875a 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx @@ -4,12 +4,14 @@ import { CardHeader, CompositeNumberInput, CompositeSlider, + Flex, IconButton, + Switch, Text, } from '@invoke-ai/ui-library'; import { useAppDispatch } from 'app/store/storeHooks'; import type { LoRA } from 'features/lora/store/loraSlice'; -import { loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice'; +import { loraIsEnabledChanged, loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice'; import { memo, useCallback } from 'react'; import { PiTrashSimpleBold } from 'react-icons/pi'; @@ -28,6 +30,10 @@ export const LoRACard = memo((props: LoRACardProps) => { [dispatch, lora.id] ); + const handleSetLoraToggle = useCallback(() => { + dispatch(loraIsEnabledChanged({ id: lora.id, isEnabled: !lora.isEnabled })); + }, [dispatch, lora.id, lora.isEnabled]); + const handleRemoveLora = useCallback(() => { dispatch(loraRemoved(lora.id)); }, [dispatch, lora.id]); @@ -35,16 +41,21 @@ export const LoRACard = memo((props: LoRACardProps) => { return ( - - {lora.model_name} - - } - /> + + + {lora.model_name} + + + + } + /> + + { step={0.01} marks={marks} defaultValue={0.75} + isDisabled={!lora.isEnabled} /> { w={20} flexShrink={0} defaultValue={0.75} + isDisabled={!lora.isEnabled} /> diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts index 8562090fbc..7906443d7f 100644 --- a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts +++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts @@ -7,10 +7,12 @@ import type { LoRAModelConfigEntity } from 'services/api/endpoints/models'; export type LoRA = ParameterLoRAModel & { id: string; weight: number; + isEnabled?: boolean; }; -export const defaultLoRAConfig = { +export const defaultLoRAConfig: Pick = { weight: 0.75, + isEnabled: true, }; export type LoraState = { @@ -58,11 +60,26 @@ export const loraSlice = createSlice({ } lora.weight = defaultLoRAConfig.weight; }, + loraIsEnabledChanged: (state, action: PayloadAction>) => { + const { id, isEnabled } = action.payload; + const lora = state.loras[id]; + if (!lora) { + return; + } + lora.isEnabled = isEnabled; + }, }, }); -export const { loraAdded, loraRemoved, loraWeightChanged, loraWeightReset, lorasCleared, loraRecalled } = - loraSlice.actions; +export const { + loraAdded, + loraRemoved, + loraWeightChanged, + loraWeightReset, + loraIsEnabledChanged, + lorasCleared, + loraRecalled, +} = loraSlice.actions; export default loraSlice.reducer; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts index d1b5ddde84..3ed71b7529 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts @@ -1,5 +1,5 @@ import type { RootState } from 'app/store/store'; -import { forEach, size } from 'lodash-es'; +import { filter, size } from 'lodash-es'; import type { CoreMetadataInvocation, LoraLoaderInvocation, NonNullableGraph } from 'services/api/types'; import { @@ -28,8 +28,8 @@ export const addLoRAsToGraph = ( * So we need to inject a LoRA chain into the graph. */ - const { loras } = state.lora; - const loraCount = size(loras); + const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false); + const loraCount = size(enabledLoRAs); if (loraCount === 0) { return; @@ -47,7 +47,7 @@ export const addLoRAsToGraph = ( let currentLoraIndex = 0; const loraMetadata: CoreMetadataInvocation['loras'] = []; - forEach(loras, (lora) => { + enabledLoRAs.forEach((lora) => { const { model_name, base_model, weight } = lora; const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`; diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx index d641ea9929..ea6fd3563d 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx @@ -23,7 +23,7 @@ import ParamMainModelSelect from 'features/parameters/components/MainModel/Param import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle'; import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle'; -import { size } from 'lodash-es'; +import { filter, size } from 'lodash-es'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -32,7 +32,8 @@ const formLabelProps: FormLabelProps = { }; const badgesSelector = createMemoizedSelector(selectLoraSlice, selectGenerationSlice, (lora, generation) => { - const loraTabBadges = size(lora.loras) ? [size(lora.loras)] : []; + const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length; + const loraTabBadges = size(lora.loras) ? [enabledLoRAsCount] : []; const accordionBadges: (string | number)[] = []; if (generation.model) { accordionBadges.push(generation.model.model_name);