feat: Restore Update Model functionality

This commit is contained in:
blessedcoolant 2023-07-12 16:13:49 +12:00
parent af239fa122
commit afb46564e8
4 changed files with 108 additions and 33 deletions

View File

@ -342,6 +342,7 @@
"safetensorModels": "SafeTensors",
"modelAdded": "Model Added",
"modelUpdated": "Model Updated",
"modelUpdateFailed": "Model Update Failed",
"modelEntryDeleted": "Model Entry Deleted",
"cannotUseSpaces": "Cannot Use Spaces",
"addNew": "Add New",

View File

@ -11,7 +11,11 @@ import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
import { S } from 'services/api/types';
import { makeToast } from 'app/components/Toaster';
import { addToast } from 'features/system/store/systemSlice';
import { useUpdateMainModelsMutation } from 'services/api/endpoints/models';
import { components } from 'services/api/schema';
import ModelConvert from './ModelConvert';
const baseModelSelectData = [
@ -25,13 +29,13 @@ const variantSelectData = [
{ value: 'depth', label: 'Depth' },
];
export type CheckpointModel =
| S<'StableDiffusion1ModelCheckpointConfig'>
| S<'StableDiffusion2ModelCheckpointConfig'>;
export type CheckpointModelConfig =
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
| components['schemas']['StableDiffusion2ModelCheckpointConfig'];
type CheckpointModelEditProps = {
modelToEdit: string;
retrievedModel: CheckpointModel;
retrievedModel: CheckpointModelConfig;
};
export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
@ -41,25 +45,52 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
const { modelToEdit, retrievedModel } = props;
const [updateMainModel, { error }] = useUpdateMainModelsMutation();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const checkpointEditForm = useForm({
const checkpointEditForm = useForm<CheckpointModelConfig>({
initialValues: {
name: retrievedModel.name,
name: retrievedModel.name ? retrievedModel.name : '',
base_model: retrievedModel.base_model,
type: 'main',
path: retrievedModel.path,
description: retrievedModel.description,
path: retrievedModel.path ? retrievedModel.path : '',
description: retrievedModel.description ? retrievedModel.description : '',
model_format: 'checkpoint',
vae: retrievedModel.vae,
config: retrievedModel.config,
vae: retrievedModel.vae ? retrievedModel.vae : '',
config: retrievedModel.config ? retrievedModel.config : '',
variant: retrievedModel.variant,
},
});
const editModelFormSubmitHandler = (values) => {
console.log(values);
const editModelFormSubmitHandler = (values: CheckpointModelConfig) => {
const responseBody = {
base_model: retrievedModel.base_model,
model_name: retrievedModel.name,
body: values,
};
updateMainModel(responseBody);
if (error) {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'success',
})
)
);
}
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
};
return modelToEdit ? (
@ -88,10 +119,6 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
)}
>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIInput
label={t('modelManager.name')}
{...checkpointEditForm.getInputProps('name')}
/>
<IAIInput
label={t('modelManager.description')}
{...checkpointEditForm.getInputProps('description')}

View File

@ -6,20 +6,23 @@ import { Divider, Flex, Text } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next';
import { useForm } from '@mantine/form';
import { makeToast } from 'app/components/Toaster';
import type { RootState } from 'app/store/store';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
import { S } from 'services/api/types';
import { addToast } from 'features/system/store/systemSlice';
import { useUpdateMainModelsMutation } from 'services/api/endpoints/models';
import { components } from 'services/api/schema';
type DiffusersModel =
| S<'StableDiffusion1ModelDiffusersConfig'>
| S<'StableDiffusion2ModelDiffusersConfig'>;
export type DiffusersModelConfig =
| components['schemas']['StableDiffusion1ModelDiffusersConfig']
| components['schemas']['StableDiffusion2ModelDiffusersConfig'];
type DiffusersModelEditProps = {
modelToEdit: string;
retrievedModel: DiffusersModel;
retrievedModel: DiffusersModelConfig;
};
const baseModelSelectData = [
@ -39,24 +42,51 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
);
const { retrievedModel, modelToEdit } = props;
const [updateMainModel, { error }] = useUpdateMainModelsMutation();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const diffusersEditForm = useForm({
const diffusersEditForm = useForm<DiffusersModelConfig>({
initialValues: {
name: retrievedModel.name,
name: retrievedModel.name ? retrievedModel.name : '',
base_model: retrievedModel.base_model,
type: 'main',
path: retrievedModel.path,
description: retrievedModel.description,
path: retrievedModel.path ? retrievedModel.path : '',
description: retrievedModel.description ? retrievedModel.description : '',
model_format: 'diffusers',
vae: retrievedModel.vae,
vae: retrievedModel.vae ? retrievedModel.vae : '',
variant: retrievedModel.variant,
},
});
const editModelFormSubmitHandler = (values) => {
console.log(values);
const editModelFormSubmitHandler = (values: DiffusersModelConfig) => {
const responseBody = {
base_model: retrievedModel.base_model,
model_name: retrievedModel.name,
body: values,
};
updateMainModel(responseBody);
if (error) {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'success',
})
)
);
}
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
};
return modelToEdit ? (
@ -77,10 +107,6 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
)}
>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIInput
label={t('modelManager.name')}
{...diffusersEditForm.getInputProps('name')}
/>
<IAIInput
label={t('modelManager.description')}
{...diffusersEditForm.getInputProps('description')}

View File

@ -2,6 +2,7 @@ import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
import { cloneDeep } from 'lodash-es';
import {
AnyModelConfig,
BaseModelType,
ControlNetModelConfig,
LoRAModelConfig,
MainModelConfig,
@ -32,6 +33,12 @@ type AnyModelConfigEntity =
| TextualInversionModelConfigEntity
| VaeModelConfigEntity;
type UpdateMainModelQuery = {
base_model: BaseModelType;
model_name: string;
body: MainModelConfig;
};
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
@ -101,6 +108,19 @@ export const modelsApi = api.injectEndpoints({
);
},
}),
updateMainModels: build.mutation<
EntityState<MainModelConfigEntity>,
UpdateMainModelQuery
>({
query: ({ base_model, model_name, body }) => {
return {
url: `models/${base_model}/main/${model_name}`,
method: 'PATCH',
body: body,
};
},
invalidatesTags: ['MainModel'],
}),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
providesTags: (result, error, arg) => {
@ -244,4 +264,5 @@ export const {
useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,
useUpdateMainModelsMutation,
} = modelsApi;