make LoRAs editable

This commit is contained in:
Lincoln Stein 2023-07-25 23:20:14 -04:00 committed by psychedelicious
parent 4c79350300
commit aa2c94be9e
2 changed files with 95 additions and 17 deletions

View File

@ -1,12 +1,21 @@
import { Divider, Flex, Text } from '@chakra-ui/react'; import { Divider, Flex, Text } from '@chakra-ui/react';
import { useForm } from '@mantine/form'; import { useForm } from '@mantine/form';
import { makeToast } from 'features/system/util/makeToast';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput'; import IAIMantineTextInput from 'common/components/IAIMantineInput';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { import {
LORA_MODEL_FORMAT_MAP, LORA_MODEL_FORMAT_MAP,
MODEL_TYPE_MAP, MODEL_TYPE_MAP,
} from 'features/parameters/types/constants'; } from 'features/parameters/types/constants';
import { useTranslation } from 'react-i18next'; import {
import { LoRAModelConfigEntity } from 'services/api/endpoints/models'; LoRAModelConfigEntity,
useUpdateLoRAModelsMutation,
} from 'services/api/endpoints/models';
import { LoRAModelConfig } from 'services/api/types'; import { LoRAModelConfig } from 'services/api/types';
import BaseModelSelect from '../shared/BaseModelSelect'; import BaseModelSelect from '../shared/BaseModelSelect';
@ -15,8 +24,13 @@ type LoRAModelEditProps = {
}; };
export default function LoRAModelEdit(props: LoRAModelEditProps) { export default function LoRAModelEdit(props: LoRAModelEditProps) {
const isBusy = useAppSelector(selectIsBusy);
const { model } = props; const { model } = props;
const [updateLoRAModel, { isLoading }] = useUpdateLoRAModelsMutation();
const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const loraEditForm = useForm<LoRAModelConfig>({ const loraEditForm = useForm<LoRAModelConfig>({
@ -34,6 +48,49 @@ export default function LoRAModelEdit(props: LoRAModelEditProps) {
}, },
}); });
const editModelFormSubmitHandler = useCallback(
(values: LoRAModelConfig) => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
body: values,
};
updateLoRAModel(responseBody)
.unwrap()
.then((payload) => {
loraEditForm.setValues(payload as LoRAModelConfig);
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((_) => {
loraEditForm.reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
},
[
dispatch,
loraEditForm,
model.base_model,
model.model_name,
t,
updateLoRAModel,
]
);
return ( return (
<Flex flexDirection="column" rowGap={4} width="100%"> <Flex flexDirection="column" rowGap={4} width="100%">
<Flex flexDirection="column"> <Flex flexDirection="column">
@ -47,34 +104,32 @@ export default function LoRAModelEdit(props: LoRAModelEditProps) {
</Flex> </Flex>
<Divider /> <Divider />
<form> <form
onSubmit={loraEditForm.onSubmit((values) =>
editModelFormSubmitHandler(values)
)}
>
<Flex flexDirection="column" overflowY="scroll" gap={4}> <Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIMantineTextInput <IAIMantineTextInput
label={t('modelManager.name')} label={t('modelManager.name')}
readOnly={true}
disabled={true}
{...loraEditForm.getInputProps('model_name')} {...loraEditForm.getInputProps('model_name')}
/> />
<IAIMantineTextInput <IAIMantineTextInput
label={t('modelManager.description')} label={t('modelManager.description')}
readOnly={true}
disabled={true}
{...loraEditForm.getInputProps('description')} {...loraEditForm.getInputProps('description')}
/> />
<BaseModelSelect <BaseModelSelect {...loraEditForm.getInputProps('base_model')} />
readOnly={true}
disabled={true}
{...loraEditForm.getInputProps('base_model')}
/>
<IAIMantineTextInput <IAIMantineTextInput
readOnly={true}
disabled={true}
label={t('modelManager.modelLocation')} label={t('modelManager.modelLocation')}
{...loraEditForm.getInputProps('path')} {...loraEditForm.getInputProps('path')}
/> />
<Text color="base.400"> <IAIButton
{t('Editing LoRA model metadata is not yet supported.')} type="submit"
</Text> isDisabled={isBusy || isLoading}
isLoading={isLoading}
>
{t('modelManager.updateModel')}
</IAIButton>
</Flex> </Flex>
</form> </form>
</Flex> </Flex>

View File

@ -52,9 +52,18 @@ type UpdateMainModelArg = {
body: MainModelConfig; body: MainModelConfig;
}; };
type UpdateLoRAModelArg = {
base_model: BaseModelType;
model_name: string;
body: LoRAModelConfig;
};
type UpdateMainModelResponse = type UpdateMainModelResponse =
paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json']; paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json'];
type UpdateLoRAModelResponse =
paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json'];
type DeleteMainModelArg = { type DeleteMainModelArg = {
base_model: BaseModelType; base_model: BaseModelType;
model_name: string; model_name: string;
@ -324,6 +333,19 @@ export const modelsApi = api.injectEndpoints({
); );
}, },
}), }),
updateLoRAModels: build.mutation<
UpdateLoRAModelResponse,
UpdateLoRAModelArg
>({
query: ({ base_model, model_name, body }) => {
return {
url: `models/${base_model}/lora/${model_name}`,
method: 'PATCH',
body: body,
};
},
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
}),
deleteLoRAModels: build.mutation< deleteLoRAModels: build.mutation<
DeleteLoRAModelResponse, DeleteLoRAModelResponse,
DeleteLoRAModelArg DeleteLoRAModelArg
@ -484,6 +506,7 @@ export const {
useConvertMainModelsMutation, useConvertMainModelsMutation,
useMergeMainModelsMutation, useMergeMainModelsMutation,
useDeleteLoRAModelsMutation, useDeleteLoRAModelsMutation,
useUpdateLoRAModelsMutation,
useSyncModelsMutation, useSyncModelsMutation,
useGetModelsInFolderQuery, useGetModelsInFolderQuery,
useGetCheckpointConfigsQuery, useGetCheckpointConfigsQuery,