From afb46564e88c242fec9fcf52fb0c2ee5534122d7 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Wed, 12 Jul 2023 16:13:49 +1200 Subject: [PATCH] feat: Restore Update Model functionality --- invokeai/frontend/web/public/locales/en.json | 1 + .../ModelManagerPanel/CheckpointModelEdit.tsx | 61 +++++++++++++------ .../ModelManagerPanel/DiffusersModelEdit.tsx | 58 +++++++++++++----- .../web/src/services/api/endpoints/models.ts | 21 +++++++ 4 files changed, 108 insertions(+), 33 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 1a902a88b7..fc56f5a703 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -342,6 +342,7 @@ "safetensorModels": "SafeTensors", "modelAdded": "Model Added", "modelUpdated": "Model Updated", + "modelUpdateFailed": "Model Update Failed", "modelEntryDeleted": "Model Entry Deleted", "cannotUseSpaces": "Cannot Use Spaces", "addNew": "Add New", diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx index 0d5d21175a..5dbb64ca7d 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx @@ -11,7 +11,11 @@ import IAIButton from 'common/components/IAIButton'; import IAIInput from 'common/components/IAIInput'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect'; -import { S } from 'services/api/types'; + +import { makeToast } from 'app/components/Toaster'; +import { addToast } from 'features/system/store/systemSlice'; +import { useUpdateMainModelsMutation } from 'services/api/endpoints/models'; +import { components } from 'services/api/schema'; import ModelConvert from './ModelConvert'; const baseModelSelectData = [ @@ -25,13 +29,13 @@ const variantSelectData = [ { value: 'depth', label: 'Depth' }, ]; -export type CheckpointModel = - | S<'StableDiffusion1ModelCheckpointConfig'> - | S<'StableDiffusion2ModelCheckpointConfig'>; +export type CheckpointModelConfig = + | components['schemas']['StableDiffusion1ModelCheckpointConfig'] + | components['schemas']['StableDiffusion2ModelCheckpointConfig']; type CheckpointModelEditProps = { modelToEdit: string; - retrievedModel: CheckpointModel; + retrievedModel: CheckpointModelConfig; }; export default function CheckpointModelEdit(props: CheckpointModelEditProps) { @@ -41,25 +45,52 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) { const { modelToEdit, retrievedModel } = props; + const [updateMainModel, { error }] = useUpdateMainModelsMutation(); + const dispatch = useAppDispatch(); const { t } = useTranslation(); - const checkpointEditForm = useForm({ + const checkpointEditForm = useForm({ initialValues: { - name: retrievedModel.name, + name: retrievedModel.name ? retrievedModel.name : '', base_model: retrievedModel.base_model, type: 'main', - path: retrievedModel.path, - description: retrievedModel.description, + path: retrievedModel.path ? retrievedModel.path : '', + description: retrievedModel.description ? retrievedModel.description : '', model_format: 'checkpoint', - vae: retrievedModel.vae, - config: retrievedModel.config, + vae: retrievedModel.vae ? retrievedModel.vae : '', + config: retrievedModel.config ? retrievedModel.config : '', variant: retrievedModel.variant, }, }); - const editModelFormSubmitHandler = (values) => { - console.log(values); + const editModelFormSubmitHandler = (values: CheckpointModelConfig) => { + const responseBody = { + base_model: retrievedModel.base_model, + model_name: retrievedModel.name, + body: values, + }; + updateMainModel(responseBody); + + if (error) { + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelUpdateFailed'), + status: 'success', + }) + ) + ); + } + + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelUpdated'), + status: 'success', + }) + ) + ); }; return modelToEdit ? ( @@ -88,10 +119,6 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) { )} > - - | S<'StableDiffusion2ModelDiffusersConfig'>; +export type DiffusersModelConfig = + | components['schemas']['StableDiffusion1ModelDiffusersConfig'] + | components['schemas']['StableDiffusion2ModelDiffusersConfig']; type DiffusersModelEditProps = { modelToEdit: string; - retrievedModel: DiffusersModel; + retrievedModel: DiffusersModelConfig; }; const baseModelSelectData = [ @@ -39,24 +42,51 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) { ); const { retrievedModel, modelToEdit } = props; + const [updateMainModel, { error }] = useUpdateMainModelsMutation(); + const dispatch = useAppDispatch(); const { t } = useTranslation(); - const diffusersEditForm = useForm({ + const diffusersEditForm = useForm({ initialValues: { - name: retrievedModel.name, + name: retrievedModel.name ? retrievedModel.name : '', base_model: retrievedModel.base_model, type: 'main', - path: retrievedModel.path, - description: retrievedModel.description, + path: retrievedModel.path ? retrievedModel.path : '', + description: retrievedModel.description ? retrievedModel.description : '', model_format: 'diffusers', - vae: retrievedModel.vae, + vae: retrievedModel.vae ? retrievedModel.vae : '', variant: retrievedModel.variant, }, }); - const editModelFormSubmitHandler = (values) => { - console.log(values); + const editModelFormSubmitHandler = (values: DiffusersModelConfig) => { + const responseBody = { + base_model: retrievedModel.base_model, + model_name: retrievedModel.name, + body: values, + }; + updateMainModel(responseBody); + + if (error) { + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelUpdateFailed'), + status: 'success', + }) + ) + ); + } + + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelUpdated'), + status: 'success', + }) + ) + ); }; return modelToEdit ? ( @@ -77,10 +107,6 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) { )} > - ({ sortComparer: (a, b) => a.name.localeCompare(b.name), }); @@ -101,6 +108,19 @@ export const modelsApi = api.injectEndpoints({ ); }, }), + updateMainModels: build.mutation< + EntityState, + UpdateMainModelQuery + >({ + query: ({ base_model, model_name, body }) => { + return { + url: `models/${base_model}/main/${model_name}`, + method: 'PATCH', + body: body, + }; + }, + invalidatesTags: ['MainModel'], + }), getLoRAModels: build.query, void>({ query: () => ({ url: 'models/', params: { model_type: 'lora' } }), providesTags: (result, error, arg) => { @@ -244,4 +264,5 @@ export const { useGetLoRAModelsQuery, useGetTextualInversionModelsQuery, useGetVaeModelsQuery, + useUpdateMainModelsMutation, } = modelsApi;