feat: disable/enable LoRas with a switch (#5591)

* feat:  disable/enable LorRas with a switch

* feat:  visually display previous weight when disabled

* style: 🚨 linting

* feat:  lora badge count reflects active loras

* style: 🚨 linting

* feat:  track disabled lora on state instead of weight

* style: 🚨 linting

* feat:  it all works now

tracking isEnabled on lora state, disabled slider when disabled, removed disabled loras from graph, updated badge counting and renamed lora add function

* style: 🚨 linting

* fix: 🐛 enabledLoRAs filter nullish coalescing

* refactor: 🎨 minor changes

renamed lora toggle action, removed errent comment, removed extraneous type annotation

* style: 🚨 linting
This commit is contained in:
Josh Corbett 2024-01-30 22:50:03 -07:00 committed by GitHub
parent 0d4de4cc63
commit f70c0936ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 51 additions and 20 deletions

View File

@ -4,12 +4,14 @@ import {
CardHeader,
CompositeNumberInput,
CompositeSlider,
Flex,
IconButton,
Switch,
Text,
} from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import type { LoRA } from 'features/lora/store/loraSlice';
import { loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice';
import { loraIsEnabledChanged, loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice';
import { memo, useCallback } from 'react';
import { PiTrashSimpleBold } from 'react-icons/pi';
@ -28,6 +30,10 @@ export const LoRACard = memo((props: LoRACardProps) => {
[dispatch, lora.id]
);
const handleSetLoraToggle = useCallback(() => {
dispatch(loraIsEnabledChanged({ id: lora.id, isEnabled: !lora.isEnabled }));
}, [dispatch, lora.id, lora.isEnabled]);
const handleRemoveLora = useCallback(() => {
dispatch(loraRemoved(lora.id));
}, [dispatch, lora.id]);
@ -35,16 +41,21 @@ export const LoRACard = memo((props: LoRACardProps) => {
return (
<Card variant="lora">
<CardHeader>
<Text noOfLines={1} wordBreak="break-all" color="base.200">
{lora.model_name}
</Text>
<IconButton
aria-label="Remove LoRA"
variant="ghost"
size="sm"
onClick={handleRemoveLora}
icon={<PiTrashSimpleBold />}
/>
<Flex alignItems="center" justifyContent="space-between" width="100%" gap={2}>
<Text noOfLines={1} wordBreak="break-all" color={lora.isEnabled ? 'base.200' : 'base.500'}>
{lora.model_name}
</Text>
<Flex alignItems="center" gap={2}>
<Switch size="sm" onChange={handleSetLoraToggle} isChecked={lora.isEnabled} />
<IconButton
aria-label="Remove LoRA"
variant="ghost"
size="sm"
onClick={handleRemoveLora}
icon={<PiTrashSimpleBold />}
/>
</Flex>
</Flex>
</CardHeader>
<CardBody>
<CompositeSlider
@ -55,6 +66,7 @@ export const LoRACard = memo((props: LoRACardProps) => {
step={0.01}
marks={marks}
defaultValue={0.75}
isDisabled={!lora.isEnabled}
/>
<CompositeNumberInput
value={lora.weight}
@ -65,6 +77,7 @@ export const LoRACard = memo((props: LoRACardProps) => {
w={20}
flexShrink={0}
defaultValue={0.75}
isDisabled={!lora.isEnabled}
/>
</CardBody>
</Card>

View File

@ -7,10 +7,12 @@ import type { LoRAModelConfigEntity } from 'services/api/endpoints/models';
export type LoRA = ParameterLoRAModel & {
id: string;
weight: number;
isEnabled?: boolean;
};
export const defaultLoRAConfig = {
export const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
weight: 0.75,
isEnabled: true,
};
export type LoraState = {
@ -58,11 +60,26 @@ export const loraSlice = createSlice({
}
lora.weight = defaultLoRAConfig.weight;
},
loraIsEnabledChanged: (state, action: PayloadAction<Pick<LoRA, 'id' | 'isEnabled'>>) => {
const { id, isEnabled } = action.payload;
const lora = state.loras[id];
if (!lora) {
return;
}
lora.isEnabled = isEnabled;
},
},
});
export const { loraAdded, loraRemoved, loraWeightChanged, loraWeightReset, lorasCleared, loraRecalled } =
loraSlice.actions;
export const {
loraAdded,
loraRemoved,
loraWeightChanged,
loraWeightReset,
loraIsEnabledChanged,
lorasCleared,
loraRecalled,
} = loraSlice.actions;
export default loraSlice.reducer;

View File

@ -1,5 +1,5 @@
import type { RootState } from 'app/store/store';
import { forEach, size } from 'lodash-es';
import { filter, size } from 'lodash-es';
import type { CoreMetadataInvocation, LoraLoaderInvocation, NonNullableGraph } from 'services/api/types';
import {
@ -28,8 +28,8 @@ export const addLoRAsToGraph = (
* So we need to inject a LoRA chain into the graph.
*/
const { loras } = state.lora;
const loraCount = size(loras);
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
const loraCount = size(enabledLoRAs);
if (loraCount === 0) {
return;
@ -47,7 +47,7 @@ export const addLoRAsToGraph = (
let currentLoraIndex = 0;
const loraMetadata: CoreMetadataInvocation['loras'] = [];
forEach(loras, (lora) => {
enabledLoRAs.forEach((lora) => {
const { model_name, base_model, weight } = lora;
const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;

View File

@ -23,7 +23,7 @@ import ParamMainModelSelect from 'features/parameters/components/MainModel/Param
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
import { size } from 'lodash-es';
import { filter, size } from 'lodash-es';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
@ -32,7 +32,8 @@ const formLabelProps: FormLabelProps = {
};
const badgesSelector = createMemoizedSelector(selectLoraSlice, selectGenerationSlice, (lora, generation) => {
const loraTabBadges = size(lora.loras) ? [size(lora.loras)] : [];
const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length;
const loraTabBadges = size(lora.loras) ? [enabledLoRAsCount] : [];
const accordionBadges: (string | number)[] = [];
if (generation.model) {
accordionBadges.push(generation.model.model_name);