From aa2c94be9ec15324c57426195a9c754f91494b67 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 25 Jul 2023 23:20:14 -0400 Subject: [PATCH] make LoRAs editable --- .../ModelManagerPanel/LoRAModelEdit.tsx | 89 +++++++++++++++---- .../web/src/services/api/endpoints/models.ts | 23 +++++ 2 files changed, 95 insertions(+), 17 deletions(-) diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx index b1c6900f74..c87550c7d1 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx @@ -1,12 +1,21 @@ import { Divider, Flex, Text } from '@chakra-ui/react'; import { useForm } from '@mantine/form'; +import { makeToast } from 'features/system/util/makeToast'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import IAIButton from 'common/components/IAIButton'; import IAIMantineTextInput from 'common/components/IAIMantineInput'; +import { selectIsBusy } from 'features/system/store/systemSelectors'; +import { addToast } from 'features/system/store/systemSlice'; +import { useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; import { LORA_MODEL_FORMAT_MAP, MODEL_TYPE_MAP, } from 'features/parameters/types/constants'; -import { useTranslation } from 'react-i18next'; -import { LoRAModelConfigEntity } from 'services/api/endpoints/models'; +import { + LoRAModelConfigEntity, + useUpdateLoRAModelsMutation, +} from 'services/api/endpoints/models'; import { LoRAModelConfig } from 'services/api/types'; import BaseModelSelect from '../shared/BaseModelSelect'; @@ -15,8 +24,13 @@ type LoRAModelEditProps = { }; export default function LoRAModelEdit(props: LoRAModelEditProps) { + const isBusy = useAppSelector(selectIsBusy); + const { model } = props; + const [updateLoRAModel, { isLoading }] = useUpdateLoRAModelsMutation(); + + const dispatch = useAppDispatch(); const { t } = useTranslation(); const loraEditForm = useForm({ @@ -34,6 +48,49 @@ export default function LoRAModelEdit(props: LoRAModelEditProps) { }, }); + const editModelFormSubmitHandler = useCallback( + (values: LoRAModelConfig) => { + const responseBody = { + base_model: model.base_model, + model_name: model.model_name, + body: values, + }; + + updateLoRAModel(responseBody) + .unwrap() + .then((payload) => { + loraEditForm.setValues(payload as LoRAModelConfig); + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelUpdated'), + status: 'success', + }) + ) + ); + }) + .catch((_) => { + loraEditForm.reset(); + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelUpdateFailed'), + status: 'error', + }) + ) + ); + }); + }, + [ + dispatch, + loraEditForm, + model.base_model, + model.model_name, + t, + updateLoRAModel, + ] + ); + return ( @@ -47,34 +104,32 @@ export default function LoRAModelEdit(props: LoRAModelEditProps) { -
+ + editModelFormSubmitHandler(values) + )} + > - + - - {t('Editing LoRA model metadata is not yet supported.')} - + + {t('modelManager.updateModel')} +
diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index aa93be62b5..6fa7f60d08 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -52,9 +52,18 @@ type UpdateMainModelArg = { body: MainModelConfig; }; +type UpdateLoRAModelArg = { + base_model: BaseModelType; + model_name: string; + body: LoRAModelConfig; +}; + type UpdateMainModelResponse = paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json']; +type UpdateLoRAModelResponse = + paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json']; + type DeleteMainModelArg = { base_model: BaseModelType; model_name: string; @@ -324,6 +333,19 @@ export const modelsApi = api.injectEndpoints({ ); }, }), + updateLoRAModels: build.mutation< + UpdateLoRAModelResponse, + UpdateLoRAModelArg + >({ + query: ({ base_model, model_name, body }) => { + return { + url: `models/${base_model}/lora/${model_name}`, + method: 'PATCH', + body: body, + }; + }, + invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }], + }), deleteLoRAModels: build.mutation< DeleteLoRAModelResponse, DeleteLoRAModelArg @@ -484,6 +506,7 @@ export const { useConvertMainModelsMutation, useMergeMainModelsMutation, useDeleteLoRAModelsMutation, + useUpdateLoRAModelsMutation, useSyncModelsMutation, useGetModelsInFolderQuery, useGetCheckpointConfigsQuery,