feat: Restore Update Model functionality

This commit is contained in:
blessedcoolant 2023-07-12 16:13:49 +12:00
parent af239fa122
commit afb46564e8
4 changed files with 108 additions and 33 deletions

View File

@ -342,6 +342,7 @@
"safetensorModels": "SafeTensors", "safetensorModels": "SafeTensors",
"modelAdded": "Model Added", "modelAdded": "Model Added",
"modelUpdated": "Model Updated", "modelUpdated": "Model Updated",
"modelUpdateFailed": "Model Update Failed",
"modelEntryDeleted": "Model Entry Deleted", "modelEntryDeleted": "Model Entry Deleted",
"cannotUseSpaces": "Cannot Use Spaces", "cannotUseSpaces": "Cannot Use Spaces",
"addNew": "Add New", "addNew": "Add New",

View File

@ -11,7 +11,11 @@ import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import IAIInput from 'common/components/IAIInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect'; import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
import { S } from 'services/api/types';
import { makeToast } from 'app/components/Toaster';
import { addToast } from 'features/system/store/systemSlice';
import { useUpdateMainModelsMutation } from 'services/api/endpoints/models';
import { components } from 'services/api/schema';
import ModelConvert from './ModelConvert'; import ModelConvert from './ModelConvert';
const baseModelSelectData = [ const baseModelSelectData = [
@ -25,13 +29,13 @@ const variantSelectData = [
{ value: 'depth', label: 'Depth' }, { value: 'depth', label: 'Depth' },
]; ];
export type CheckpointModel = export type CheckpointModelConfig =
| S<'StableDiffusion1ModelCheckpointConfig'> | components['schemas']['StableDiffusion1ModelCheckpointConfig']
| S<'StableDiffusion2ModelCheckpointConfig'>; | components['schemas']['StableDiffusion2ModelCheckpointConfig'];
type CheckpointModelEditProps = { type CheckpointModelEditProps = {
modelToEdit: string; modelToEdit: string;
retrievedModel: CheckpointModel; retrievedModel: CheckpointModelConfig;
}; };
export default function CheckpointModelEdit(props: CheckpointModelEditProps) { export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
@ -41,25 +45,52 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
const { modelToEdit, retrievedModel } = props; const { modelToEdit, retrievedModel } = props;
const [updateMainModel, { error }] = useUpdateMainModelsMutation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const checkpointEditForm = useForm({ const checkpointEditForm = useForm<CheckpointModelConfig>({
initialValues: { initialValues: {
name: retrievedModel.name, name: retrievedModel.name ? retrievedModel.name : '',
base_model: retrievedModel.base_model, base_model: retrievedModel.base_model,
type: 'main', type: 'main',
path: retrievedModel.path, path: retrievedModel.path ? retrievedModel.path : '',
description: retrievedModel.description, description: retrievedModel.description ? retrievedModel.description : '',
model_format: 'checkpoint', model_format: 'checkpoint',
vae: retrievedModel.vae, vae: retrievedModel.vae ? retrievedModel.vae : '',
config: retrievedModel.config, config: retrievedModel.config ? retrievedModel.config : '',
variant: retrievedModel.variant, variant: retrievedModel.variant,
}, },
}); });
const editModelFormSubmitHandler = (values) => { const editModelFormSubmitHandler = (values: CheckpointModelConfig) => {
console.log(values); const responseBody = {
base_model: retrievedModel.base_model,
model_name: retrievedModel.name,
body: values,
};
updateMainModel(responseBody);
if (error) {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'success',
})
)
);
}
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
}; };
return modelToEdit ? ( return modelToEdit ? (
@ -88,10 +119,6 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
)} )}
> >
<Flex flexDirection="column" overflowY="scroll" gap={4}> <Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIInput
label={t('modelManager.name')}
{...checkpointEditForm.getInputProps('name')}
/>
<IAIInput <IAIInput
label={t('modelManager.description')} label={t('modelManager.description')}
{...checkpointEditForm.getInputProps('description')} {...checkpointEditForm.getInputProps('description')}

View File

@ -6,20 +6,23 @@ import { Divider, Flex, Text } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useForm } from '@mantine/form'; import { useForm } from '@mantine/form';
import { makeToast } from 'app/components/Toaster';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import IAIInput from 'common/components/IAIInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect'; import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
import { S } from 'services/api/types'; import { addToast } from 'features/system/store/systemSlice';
import { useUpdateMainModelsMutation } from 'services/api/endpoints/models';
import { components } from 'services/api/schema';
type DiffusersModel = export type DiffusersModelConfig =
| S<'StableDiffusion1ModelDiffusersConfig'> | components['schemas']['StableDiffusion1ModelDiffusersConfig']
| S<'StableDiffusion2ModelDiffusersConfig'>; | components['schemas']['StableDiffusion2ModelDiffusersConfig'];
type DiffusersModelEditProps = { type DiffusersModelEditProps = {
modelToEdit: string; modelToEdit: string;
retrievedModel: DiffusersModel; retrievedModel: DiffusersModelConfig;
}; };
const baseModelSelectData = [ const baseModelSelectData = [
@ -39,24 +42,51 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
); );
const { retrievedModel, modelToEdit } = props; const { retrievedModel, modelToEdit } = props;
const [updateMainModel, { error }] = useUpdateMainModelsMutation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const diffusersEditForm = useForm({ const diffusersEditForm = useForm<DiffusersModelConfig>({
initialValues: { initialValues: {
name: retrievedModel.name, name: retrievedModel.name ? retrievedModel.name : '',
base_model: retrievedModel.base_model, base_model: retrievedModel.base_model,
type: 'main', type: 'main',
path: retrievedModel.path, path: retrievedModel.path ? retrievedModel.path : '',
description: retrievedModel.description, description: retrievedModel.description ? retrievedModel.description : '',
model_format: 'diffusers', model_format: 'diffusers',
vae: retrievedModel.vae, vae: retrievedModel.vae ? retrievedModel.vae : '',
variant: retrievedModel.variant, variant: retrievedModel.variant,
}, },
}); });
const editModelFormSubmitHandler = (values) => { const editModelFormSubmitHandler = (values: DiffusersModelConfig) => {
console.log(values); const responseBody = {
base_model: retrievedModel.base_model,
model_name: retrievedModel.name,
body: values,
};
updateMainModel(responseBody);
if (error) {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'success',
})
)
);
}
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
}; };
return modelToEdit ? ( return modelToEdit ? (
@ -77,10 +107,6 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
)} )}
> >
<Flex flexDirection="column" overflowY="scroll" gap={4}> <Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIInput
label={t('modelManager.name')}
{...diffusersEditForm.getInputProps('name')}
/>
<IAIInput <IAIInput
label={t('modelManager.description')} label={t('modelManager.description')}
{...diffusersEditForm.getInputProps('description')} {...diffusersEditForm.getInputProps('description')}

View File

@ -2,6 +2,7 @@ import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
import { cloneDeep } from 'lodash-es'; import { cloneDeep } from 'lodash-es';
import { import {
AnyModelConfig, AnyModelConfig,
BaseModelType,
ControlNetModelConfig, ControlNetModelConfig,
LoRAModelConfig, LoRAModelConfig,
MainModelConfig, MainModelConfig,
@ -32,6 +33,12 @@ type AnyModelConfigEntity =
| TextualInversionModelConfigEntity | TextualInversionModelConfigEntity
| VaeModelConfigEntity; | VaeModelConfigEntity;
type UpdateMainModelQuery = {
base_model: BaseModelType;
model_name: string;
body: MainModelConfig;
};
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({ const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.name.localeCompare(b.name), sortComparer: (a, b) => a.name.localeCompare(b.name),
}); });
@ -101,6 +108,19 @@ export const modelsApi = api.injectEndpoints({
); );
}, },
}), }),
updateMainModels: build.mutation<
EntityState<MainModelConfigEntity>,
UpdateMainModelQuery
>({
query: ({ base_model, model_name, body }) => {
return {
url: `models/${base_model}/main/${model_name}`,
method: 'PATCH',
body: body,
};
},
invalidatesTags: ['MainModel'],
}),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({ getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'lora' } }), query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
@ -244,4 +264,5 @@ export const {
useGetLoRAModelsQuery, useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery, useGetTextualInversionModelsQuery,
useGetVaeModelsQuery, useGetVaeModelsQuery,
useUpdateMainModelsMutation,
} = modelsApi; } = modelsApi;