mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: ✨ disable/enable LoRas with a switch (#5591)
* feat: ✨ disable/enable LorRas with a switch * feat: ✨ visually display previous weight when disabled * style: 🚨 linting * feat: ✨ lora badge count reflects active loras * style: 🚨 linting * feat: ✨ track disabled lora on state instead of weight * style: 🚨 linting * feat: ✨ it all works now tracking isEnabled on lora state, disabled slider when disabled, removed disabled loras from graph, updated badge counting and renamed lora add function * style: 🚨 linting * fix: 🐛 enabledLoRAs filter nullish coalescing * refactor: 🎨 minor changes renamed lora toggle action, removed errent comment, removed extraneous type annotation * style: 🚨 linting
This commit is contained in:
parent
0d4de4cc63
commit
f70c0936ca
@ -4,12 +4,14 @@ import {
|
||||
CardHeader,
|
||||
CompositeNumberInput,
|
||||
CompositeSlider,
|
||||
Flex,
|
||||
IconButton,
|
||||
Switch,
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import { loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice';
|
||||
import { loraIsEnabledChanged, loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||
|
||||
@ -28,6 +30,10 @@ export const LoRACard = memo((props: LoRACardProps) => {
|
||||
[dispatch, lora.id]
|
||||
);
|
||||
|
||||
const handleSetLoraToggle = useCallback(() => {
|
||||
dispatch(loraIsEnabledChanged({ id: lora.id, isEnabled: !lora.isEnabled }));
|
||||
}, [dispatch, lora.id, lora.isEnabled]);
|
||||
|
||||
const handleRemoveLora = useCallback(() => {
|
||||
dispatch(loraRemoved(lora.id));
|
||||
}, [dispatch, lora.id]);
|
||||
@ -35,16 +41,21 @@ export const LoRACard = memo((props: LoRACardProps) => {
|
||||
return (
|
||||
<Card variant="lora">
|
||||
<CardHeader>
|
||||
<Text noOfLines={1} wordBreak="break-all" color="base.200">
|
||||
{lora.model_name}
|
||||
</Text>
|
||||
<IconButton
|
||||
aria-label="Remove LoRA"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={handleRemoveLora}
|
||||
icon={<PiTrashSimpleBold />}
|
||||
/>
|
||||
<Flex alignItems="center" justifyContent="space-between" width="100%" gap={2}>
|
||||
<Text noOfLines={1} wordBreak="break-all" color={lora.isEnabled ? 'base.200' : 'base.500'}>
|
||||
{lora.model_name}
|
||||
</Text>
|
||||
<Flex alignItems="center" gap={2}>
|
||||
<Switch size="sm" onChange={handleSetLoraToggle} isChecked={lora.isEnabled} />
|
||||
<IconButton
|
||||
aria-label="Remove LoRA"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={handleRemoveLora}
|
||||
icon={<PiTrashSimpleBold />}
|
||||
/>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</CardHeader>
|
||||
<CardBody>
|
||||
<CompositeSlider
|
||||
@ -55,6 +66,7 @@ export const LoRACard = memo((props: LoRACardProps) => {
|
||||
step={0.01}
|
||||
marks={marks}
|
||||
defaultValue={0.75}
|
||||
isDisabled={!lora.isEnabled}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={lora.weight}
|
||||
@ -65,6 +77,7 @@ export const LoRACard = memo((props: LoRACardProps) => {
|
||||
w={20}
|
||||
flexShrink={0}
|
||||
defaultValue={0.75}
|
||||
isDisabled={!lora.isEnabled}
|
||||
/>
|
||||
</CardBody>
|
||||
</Card>
|
||||
|
@ -7,10 +7,12 @@ import type { LoRAModelConfigEntity } from 'services/api/endpoints/models';
|
||||
export type LoRA = ParameterLoRAModel & {
|
||||
id: string;
|
||||
weight: number;
|
||||
isEnabled?: boolean;
|
||||
};
|
||||
|
||||
export const defaultLoRAConfig = {
|
||||
export const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
|
||||
weight: 0.75,
|
||||
isEnabled: true,
|
||||
};
|
||||
|
||||
export type LoraState = {
|
||||
@ -58,11 +60,26 @@ export const loraSlice = createSlice({
|
||||
}
|
||||
lora.weight = defaultLoRAConfig.weight;
|
||||
},
|
||||
loraIsEnabledChanged: (state, action: PayloadAction<Pick<LoRA, 'id' | 'isEnabled'>>) => {
|
||||
const { id, isEnabled } = action.payload;
|
||||
const lora = state.loras[id];
|
||||
if (!lora) {
|
||||
return;
|
||||
}
|
||||
lora.isEnabled = isEnabled;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const { loraAdded, loraRemoved, loraWeightChanged, loraWeightReset, lorasCleared, loraRecalled } =
|
||||
loraSlice.actions;
|
||||
export const {
|
||||
loraAdded,
|
||||
loraRemoved,
|
||||
loraWeightChanged,
|
||||
loraWeightReset,
|
||||
loraIsEnabledChanged,
|
||||
lorasCleared,
|
||||
loraRecalled,
|
||||
} = loraSlice.actions;
|
||||
|
||||
export default loraSlice.reducer;
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { forEach, size } from 'lodash-es';
|
||||
import { filter, size } from 'lodash-es';
|
||||
import type { CoreMetadataInvocation, LoraLoaderInvocation, NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import {
|
||||
@ -28,8 +28,8 @@ export const addLoRAsToGraph = (
|
||||
* So we need to inject a LoRA chain into the graph.
|
||||
*/
|
||||
|
||||
const { loras } = state.lora;
|
||||
const loraCount = size(loras);
|
||||
const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false);
|
||||
const loraCount = size(enabledLoRAs);
|
||||
|
||||
if (loraCount === 0) {
|
||||
return;
|
||||
@ -47,7 +47,7 @@ export const addLoRAsToGraph = (
|
||||
let currentLoraIndex = 0;
|
||||
const loraMetadata: CoreMetadataInvocation['loras'] = [];
|
||||
|
||||
forEach(loras, (lora) => {
|
||||
enabledLoRAs.forEach((lora) => {
|
||||
const { model_name, base_model, weight } = lora;
|
||||
const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;
|
||||
|
||||
|
@ -23,7 +23,7 @@ import ParamMainModelSelect from 'features/parameters/components/MainModel/Param
|
||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
|
||||
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
||||
import { size } from 'lodash-es';
|
||||
import { filter, size } from 'lodash-es';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@ -32,7 +32,8 @@ const formLabelProps: FormLabelProps = {
|
||||
};
|
||||
|
||||
const badgesSelector = createMemoizedSelector(selectLoraSlice, selectGenerationSlice, (lora, generation) => {
|
||||
const loraTabBadges = size(lora.loras) ? [size(lora.loras)] : [];
|
||||
const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length;
|
||||
const loraTabBadges = size(lora.loras) ? [enabledLoRAsCount] : [];
|
||||
const accordionBadges: (string | number)[] = [];
|
||||
if (generation.model) {
|
||||
accordionBadges.push(generation.model.model_name);
|
||||
|
Loading…
Reference in New Issue
Block a user