make LoRAs editable

This commit is contained in:
Lincoln Stein 2023-07-25 23:20:14 -04:00 committed by psychedelicious
parent 4c79350300
commit aa2c94be9e
2 changed files with 95 additions and 17 deletions

View File

@ -1,12 +1,21 @@
import { Divider, Flex, Text } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { makeToast } from 'features/system/util/makeToast';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
LORA_MODEL_FORMAT_MAP,
MODEL_TYPE_MAP,
} from 'features/parameters/types/constants';
import { useTranslation } from 'react-i18next';
import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
import {
LoRAModelConfigEntity,
useUpdateLoRAModelsMutation,
} from 'services/api/endpoints/models';
import { LoRAModelConfig } from 'services/api/types';
import BaseModelSelect from '../shared/BaseModelSelect';
@ -15,8 +24,13 @@ type LoRAModelEditProps = {
};
export default function LoRAModelEdit(props: LoRAModelEditProps) {
const isBusy = useAppSelector(selectIsBusy);
const { model } = props;
const [updateLoRAModel, { isLoading }] = useUpdateLoRAModelsMutation();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const loraEditForm = useForm<LoRAModelConfig>({
@ -34,6 +48,49 @@ export default function LoRAModelEdit(props: LoRAModelEditProps) {
},
});
const editModelFormSubmitHandler = useCallback(
(values: LoRAModelConfig) => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
body: values,
};
updateLoRAModel(responseBody)
.unwrap()
.then((payload) => {
loraEditForm.setValues(payload as LoRAModelConfig);
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((_) => {
loraEditForm.reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
},
[
dispatch,
loraEditForm,
model.base_model,
model.model_name,
t,
updateLoRAModel,
]
);
return (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex flexDirection="column">
@ -47,34 +104,32 @@ export default function LoRAModelEdit(props: LoRAModelEditProps) {
</Flex>
<Divider />
<form>
<form
onSubmit={loraEditForm.onSubmit((values) =>
editModelFormSubmitHandler(values)
)}
>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIMantineTextInput
label={t('modelManager.name')}
readOnly={true}
disabled={true}
{...loraEditForm.getInputProps('model_name')}
/>
<IAIMantineTextInput
label={t('modelManager.description')}
readOnly={true}
disabled={true}
{...loraEditForm.getInputProps('description')}
/>
<BaseModelSelect
readOnly={true}
disabled={true}
{...loraEditForm.getInputProps('base_model')}
/>
<BaseModelSelect {...loraEditForm.getInputProps('base_model')} />
<IAIMantineTextInput
readOnly={true}
disabled={true}
label={t('modelManager.modelLocation')}
{...loraEditForm.getInputProps('path')}
/>
<Text color="base.400">
{t('Editing LoRA model metadata is not yet supported.')}
</Text>
<IAIButton
type="submit"
isDisabled={isBusy || isLoading}
isLoading={isLoading}
>
{t('modelManager.updateModel')}
</IAIButton>
</Flex>
</form>
</Flex>

View File

@ -52,9 +52,18 @@ type UpdateMainModelArg = {
body: MainModelConfig;
};
type UpdateLoRAModelArg = {
base_model: BaseModelType;
model_name: string;
body: LoRAModelConfig;
};
type UpdateMainModelResponse =
paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json'];
type UpdateLoRAModelResponse =
paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json'];
type DeleteMainModelArg = {
base_model: BaseModelType;
model_name: string;
@ -324,6 +333,19 @@ export const modelsApi = api.injectEndpoints({
);
},
}),
updateLoRAModels: build.mutation<
UpdateLoRAModelResponse,
UpdateLoRAModelArg
>({
query: ({ base_model, model_name, body }) => {
return {
url: `models/${base_model}/lora/${model_name}`,
method: 'PATCH',
body: body,
};
},
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
}),
deleteLoRAModels: build.mutation<
DeleteLoRAModelResponse,
DeleteLoRAModelArg
@ -484,6 +506,7 @@ export const {
useConvertMainModelsMutation,
useMergeMainModelsMutation,
useDeleteLoRAModelsMutation,
useUpdateLoRAModelsMutation,
useSyncModelsMutation,
useGetModelsInFolderQuery,
useGetCheckpointConfigsQuery,