feat: disable/enable LoRas with a switch (#5591)

* feat:  disable/enable LorRas with a switch

* feat:  visually display previous weight when disabled

* style: 🚨 linting

* feat:  lora badge count reflects active loras

* style: 🚨 linting

* feat:  track disabled lora on state instead of weight

* style: 🚨 linting

* feat:  it all works now

tracking isEnabled on lora state, disabled slider when disabled, removed disabled loras from graph, updated badge counting and renamed lora add function

* style: 🚨 linting

* fix: 🐛 enabledLoRAs filter nullish coalescing

* refactor: 🎨 minor changes

renamed lora toggle action, removed errent comment, removed extraneous type annotation

* style: 🚨 linting
This commit is contained in:
Josh Corbett 2024-01-30 22:50:03 -07:00 committed by GitHub
parent 0d4de4cc63
commit f70c0936ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 51 additions and 20 deletions

View File

@ -4,12 +4,14 @@ import {
CardHeader, CardHeader,
CompositeNumberInput, CompositeNumberInput,
CompositeSlider, CompositeSlider,
Flex,
IconButton, IconButton,
Switch,
Text, Text,
} from '@invoke-ai/ui-library'; } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import type { LoRA } from 'features/lora/store/loraSlice'; import type { LoRA } from 'features/lora/store/loraSlice';
import { loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice'; import { loraIsEnabledChanged, loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { PiTrashSimpleBold } from 'react-icons/pi'; import { PiTrashSimpleBold } from 'react-icons/pi';
@ -28,6 +30,10 @@ export const LoRACard = memo((props: LoRACardProps) => {
[dispatch, lora.id] [dispatch, lora.id]
); );
const handleSetLoraToggle = useCallback(() => {
dispatch(loraIsEnabledChanged({ id: lora.id, isEnabled: !lora.isEnabled }));
}, [dispatch, lora.id, lora.isEnabled]);
const handleRemoveLora = useCallback(() => { const handleRemoveLora = useCallback(() => {
dispatch(loraRemoved(lora.id)); dispatch(loraRemoved(lora.id));
}, [dispatch, lora.id]); }, [dispatch, lora.id]);
@ -35,9 +41,12 @@ export const LoRACard = memo((props: LoRACardProps) => {
return ( return (
<Card variant="lora"> <Card variant="lora">
<CardHeader> <CardHeader>
<Text noOfLines={1} wordBreak="break-all" color="base.200"> <Flex alignItems="center" justifyContent="space-between" width="100%" gap={2}>
<Text noOfLines={1} wordBreak="break-all" color={lora.isEnabled ? 'base.200' : 'base.500'}>
{lora.model_name} {lora.model_name}
</Text> </Text>
<Flex alignItems="center" gap={2}>
<Switch size="sm" onChange={handleSetLoraToggle} isChecked={lora.isEnabled} />
<IconButton <IconButton
aria-label="Remove LoRA" aria-label="Remove LoRA"
variant="ghost" variant="ghost"
@ -45,6 +54,8 @@ export const LoRACard = memo((props: LoRACardProps) => {
onClick={handleRemoveLora} onClick={handleRemoveLora}
icon={<PiTrashSimpleBold />} icon={<PiTrashSimpleBold />}
/> />
</Flex>
</Flex>
</CardHeader> </CardHeader>
<CardBody> <CardBody>
<CompositeSlider <CompositeSlider
@ -55,6 +66,7 @@ export const LoRACard = memo((props: LoRACardProps) => {
step={0.01} step={0.01}
marks={marks} marks={marks}
defaultValue={0.75} defaultValue={0.75}
isDisabled={!lora.isEnabled}
/> />
<CompositeNumberInput <CompositeNumberInput
value={lora.weight} value={lora.weight}
@ -65,6 +77,7 @@ export const LoRACard = memo((props: LoRACardProps) => {
w={20} w={20}
flexShrink={0} flexShrink={0}
defaultValue={0.75} defaultValue={0.75}
isDisabled={!lora.isEnabled}
/> />
</CardBody> </CardBody>
</Card> </Card>

View File

@ -7,10 +7,12 @@ import type { LoRAModelConfigEntity } from 'services/api/endpoints/models';
export type LoRA = ParameterLoRAModel & { export type LoRA = ParameterLoRAModel & {
id: string; id: string;
weight: number; weight: number;
isEnabled?: boolean;
}; };
export const defaultLoRAConfig = { export const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
weight: 0.75, weight: 0.75,
isEnabled: true,
}; };
export type LoraState = { export type LoraState = {
@ -58,11 +60,26 @@ export const loraSlice = createSlice({
} }
lora.weight = defaultLoRAConfig.weight; lora.weight = defaultLoRAConfig.weight;
}, },
loraIsEnabledChanged: (state, action: PayloadAction<Pick<LoRA, 'id' | 'isEnabled'>>) => {
const { id, isEnabled } = action.payload;
const lora = state.loras[id];
if (!lora) {
return;
}
lora.isEnabled = isEnabled;
},
}, },
}); });
export const { loraAdded, loraRemoved, loraWeightChanged, loraWeightReset, lorasCleared, loraRecalled } = export const {
loraSlice.actions; loraAdded,
loraRemoved,
loraWeightChanged,
loraWeightReset,
loraIsEnabledChanged,
lorasCleared,
loraRecalled,
} = loraSlice.actions;
export default loraSlice.reducer; export default loraSlice.reducer;

View File

@ -1,5 +1,5 @@
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { forEach, size } from 'lodash-es'; import { filter, size } from 'lodash-es';
import type { CoreMetadataInvocation, LoraLoaderInvocation, NonNullableGraph } from 'services/api/types'; import type { CoreMetadataInvocation, LoraLoaderInvocation, NonNullableGraph } from 'services/api/types';
import { import {
@ -28,8 +28,8 @@ export const addLoRAsToGraph = (
* So we need to inject a LoRA chain into the graph. * So we need to inject a LoRA chain into the graph.
*/ */
const { loras } = state.lora; const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
const loraCount = size(loras); const loraCount = size(enabledLoRAs);
if (loraCount === 0) { if (loraCount === 0) {
return; return;
@ -47,7 +47,7 @@ export const addLoRAsToGraph = (
let currentLoraIndex = 0; let currentLoraIndex = 0;
const loraMetadata: CoreMetadataInvocation['loras'] = []; const loraMetadata: CoreMetadataInvocation['loras'] = [];
forEach(loras, (lora) => { enabledLoRAs.forEach((lora) => {
const { model_name, base_model, weight } = lora; const { model_name, base_model, weight } = lora;
const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`; const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;

View File

@ -23,7 +23,7 @@ import ParamMainModelSelect from 'features/parameters/components/MainModel/Param
import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle'; import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle'; import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
import { size } from 'lodash-es'; import { filter, size } from 'lodash-es';
import { memo } from 'react'; import { memo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -32,7 +32,8 @@ const formLabelProps: FormLabelProps = {
}; };
const badgesSelector = createMemoizedSelector(selectLoraSlice, selectGenerationSlice, (lora, generation) => { const badgesSelector = createMemoizedSelector(selectLoraSlice, selectGenerationSlice, (lora, generation) => {
const loraTabBadges = size(lora.loras) ? [size(lora.loras)] : []; const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length;
const loraTabBadges = size(lora.loras) ? [enabledLoRAsCount] : [];
const accordionBadges: (string | number)[] = []; const accordionBadges: (string | number)[] = [];
if (generation.model) { if (generation.model) {
accordionBadges.push(generation.model.model_name); accordionBadges.push(generation.model.model_name);