mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: ✨ disable/enable LoRas with a switch (#5591)
* feat: ✨ disable/enable LorRas with a switch * feat: ✨ visually display previous weight when disabled * style: 🚨 linting * feat: ✨ lora badge count reflects active loras * style: 🚨 linting * feat: ✨ track disabled lora on state instead of weight * style: 🚨 linting * feat: ✨ it all works now tracking isEnabled on lora state, disabled slider when disabled, removed disabled loras from graph, updated badge counting and renamed lora add function * style: 🚨 linting * fix: 🐛 enabledLoRAs filter nullish coalescing * refactor: 🎨 minor changes renamed lora toggle action, removed errent comment, removed extraneous type annotation * style: 🚨 linting
This commit is contained in:
parent
0d4de4cc63
commit
f70c0936ca
@ -4,12 +4,14 @@ import {
|
|||||||
CardHeader,
|
CardHeader,
|
||||||
CompositeNumberInput,
|
CompositeNumberInput,
|
||||||
CompositeSlider,
|
CompositeSlider,
|
||||||
|
Flex,
|
||||||
IconButton,
|
IconButton,
|
||||||
|
Switch,
|
||||||
Text,
|
Text,
|
||||||
} from '@invoke-ai/ui-library';
|
} from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||||
import { loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice';
|
import { loraIsEnabledChanged, loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { PiTrashSimpleBold } from 'react-icons/pi';
|
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||||
|
|
||||||
@ -28,6 +30,10 @@ export const LoRACard = memo((props: LoRACardProps) => {
|
|||||||
[dispatch, lora.id]
|
[dispatch, lora.id]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const handleSetLoraToggle = useCallback(() => {
|
||||||
|
dispatch(loraIsEnabledChanged({ id: lora.id, isEnabled: !lora.isEnabled }));
|
||||||
|
}, [dispatch, lora.id, lora.isEnabled]);
|
||||||
|
|
||||||
const handleRemoveLora = useCallback(() => {
|
const handleRemoveLora = useCallback(() => {
|
||||||
dispatch(loraRemoved(lora.id));
|
dispatch(loraRemoved(lora.id));
|
||||||
}, [dispatch, lora.id]);
|
}, [dispatch, lora.id]);
|
||||||
@ -35,16 +41,21 @@ export const LoRACard = memo((props: LoRACardProps) => {
|
|||||||
return (
|
return (
|
||||||
<Card variant="lora">
|
<Card variant="lora">
|
||||||
<CardHeader>
|
<CardHeader>
|
||||||
<Text noOfLines={1} wordBreak="break-all" color="base.200">
|
<Flex alignItems="center" justifyContent="space-between" width="100%" gap={2}>
|
||||||
{lora.model_name}
|
<Text noOfLines={1} wordBreak="break-all" color={lora.isEnabled ? 'base.200' : 'base.500'}>
|
||||||
</Text>
|
{lora.model_name}
|
||||||
<IconButton
|
</Text>
|
||||||
aria-label="Remove LoRA"
|
<Flex alignItems="center" gap={2}>
|
||||||
variant="ghost"
|
<Switch size="sm" onChange={handleSetLoraToggle} isChecked={lora.isEnabled} />
|
||||||
size="sm"
|
<IconButton
|
||||||
onClick={handleRemoveLora}
|
aria-label="Remove LoRA"
|
||||||
icon={<PiTrashSimpleBold />}
|
variant="ghost"
|
||||||
/>
|
size="sm"
|
||||||
|
onClick={handleRemoveLora}
|
||||||
|
icon={<PiTrashSimpleBold />}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
</CardHeader>
|
</CardHeader>
|
||||||
<CardBody>
|
<CardBody>
|
||||||
<CompositeSlider
|
<CompositeSlider
|
||||||
@ -55,6 +66,7 @@ export const LoRACard = memo((props: LoRACardProps) => {
|
|||||||
step={0.01}
|
step={0.01}
|
||||||
marks={marks}
|
marks={marks}
|
||||||
defaultValue={0.75}
|
defaultValue={0.75}
|
||||||
|
isDisabled={!lora.isEnabled}
|
||||||
/>
|
/>
|
||||||
<CompositeNumberInput
|
<CompositeNumberInput
|
||||||
value={lora.weight}
|
value={lora.weight}
|
||||||
@ -65,6 +77,7 @@ export const LoRACard = memo((props: LoRACardProps) => {
|
|||||||
w={20}
|
w={20}
|
||||||
flexShrink={0}
|
flexShrink={0}
|
||||||
defaultValue={0.75}
|
defaultValue={0.75}
|
||||||
|
isDisabled={!lora.isEnabled}
|
||||||
/>
|
/>
|
||||||
</CardBody>
|
</CardBody>
|
||||||
</Card>
|
</Card>
|
||||||
|
@ -7,10 +7,12 @@ import type { LoRAModelConfigEntity } from 'services/api/endpoints/models';
|
|||||||
export type LoRA = ParameterLoRAModel & {
|
export type LoRA = ParameterLoRAModel & {
|
||||||
id: string;
|
id: string;
|
||||||
weight: number;
|
weight: number;
|
||||||
|
isEnabled?: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const defaultLoRAConfig = {
|
export const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
|
||||||
weight: 0.75,
|
weight: 0.75,
|
||||||
|
isEnabled: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
export type LoraState = {
|
export type LoraState = {
|
||||||
@ -58,11 +60,26 @@ export const loraSlice = createSlice({
|
|||||||
}
|
}
|
||||||
lora.weight = defaultLoRAConfig.weight;
|
lora.weight = defaultLoRAConfig.weight;
|
||||||
},
|
},
|
||||||
|
loraIsEnabledChanged: (state, action: PayloadAction<Pick<LoRA, 'id' | 'isEnabled'>>) => {
|
||||||
|
const { id, isEnabled } = action.payload;
|
||||||
|
const lora = state.loras[id];
|
||||||
|
if (!lora) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
lora.isEnabled = isEnabled;
|
||||||
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
export const { loraAdded, loraRemoved, loraWeightChanged, loraWeightReset, lorasCleared, loraRecalled } =
|
export const {
|
||||||
loraSlice.actions;
|
loraAdded,
|
||||||
|
loraRemoved,
|
||||||
|
loraWeightChanged,
|
||||||
|
loraWeightReset,
|
||||||
|
loraIsEnabledChanged,
|
||||||
|
lorasCleared,
|
||||||
|
loraRecalled,
|
||||||
|
} = loraSlice.actions;
|
||||||
|
|
||||||
export default loraSlice.reducer;
|
export default loraSlice.reducer;
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { forEach, size } from 'lodash-es';
|
import { filter, size } from 'lodash-es';
|
||||||
import type { CoreMetadataInvocation, LoraLoaderInvocation, NonNullableGraph } from 'services/api/types';
|
import type { CoreMetadataInvocation, LoraLoaderInvocation, NonNullableGraph } from 'services/api/types';
|
||||||
|
|
||||||
import {
|
import {
|
||||||
@ -28,8 +28,8 @@ export const addLoRAsToGraph = (
|
|||||||
* So we need to inject a LoRA chain into the graph.
|
* So we need to inject a LoRA chain into the graph.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
const { loras } = state.lora;
|
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
|
||||||
const loraCount = size(loras);
|
const loraCount = size(enabledLoRAs);
|
||||||
|
|
||||||
if (loraCount === 0) {
|
if (loraCount === 0) {
|
||||||
return;
|
return;
|
||||||
@ -47,7 +47,7 @@ export const addLoRAsToGraph = (
|
|||||||
let currentLoraIndex = 0;
|
let currentLoraIndex = 0;
|
||||||
const loraMetadata: CoreMetadataInvocation['loras'] = [];
|
const loraMetadata: CoreMetadataInvocation['loras'] = [];
|
||||||
|
|
||||||
forEach(loras, (lora) => {
|
enabledLoRAs.forEach((lora) => {
|
||||||
const { model_name, base_model, weight } = lora;
|
const { model_name, base_model, weight } = lora;
|
||||||
const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;
|
const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ import ParamMainModelSelect from 'features/parameters/components/MainModel/Param
|
|||||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||||
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
|
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
|
||||||
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
||||||
import { size } from 'lodash-es';
|
import { filter, size } from 'lodash-es';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
@ -32,7 +32,8 @@ const formLabelProps: FormLabelProps = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const badgesSelector = createMemoizedSelector(selectLoraSlice, selectGenerationSlice, (lora, generation) => {
|
const badgesSelector = createMemoizedSelector(selectLoraSlice, selectGenerationSlice, (lora, generation) => {
|
||||||
const loraTabBadges = size(lora.loras) ? [size(lora.loras)] : [];
|
const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length;
|
||||||
|
const loraTabBadges = size(lora.loras) ? [enabledLoRAsCount] : [];
|
||||||
const accordionBadges: (string | number)[] = [];
|
const accordionBadges: (string | number)[] = [];
|
||||||
if (generation.model) {
|
if (generation.model) {
|
||||||
accordionBadges.push(generation.model.model_name);
|
accordionBadges.push(generation.model.model_name);
|
||||||
|
Loading…
Reference in New Issue
Block a user