finish model update

This commit is contained in:
Mary Hipp 2024-02-21 14:16:00 -05:00 committed by psychedelicious
parent 20576deae8
commit 86e2b39f0d
2 changed files with 182 additions and 131 deletions

View File

@ -1,6 +1,6 @@
import { skipToken } from '@reduxjs/toolkit/query'; import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from '../../../../app/store/storeHooks'; import { useAppDispatch, useAppSelector } from '../../../../app/store/storeHooks';
import { useGetModelQuery } from '../../../../services/api/endpoints/models'; import { useGetModelQuery, useUpdateModelsMutation } from '../../../../services/api/endpoints/models';
import { Flex, Text, Heading, Button, Input, FormControl, FormLabel, Textarea } from '@invoke-ai/ui-library'; import { Flex, Text, Heading, Button, Input, FormControl, FormLabel, Textarea } from '@invoke-ai/ui-library';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { import {
@ -16,19 +16,26 @@ import {
} from '../../../../services/api/types'; } from '../../../../services/api/types';
import { setSelectedModelMode } from '../../store/modelManagerV2Slice'; import { setSelectedModelMode } from '../../store/modelManagerV2Slice';
import BaseModelSelect from './Fields/BaseModelSelect'; import BaseModelSelect from './Fields/BaseModelSelect';
import { useForm } from 'react-hook-form'; import { SubmitHandler, useForm } from 'react-hook-form';
import ModelTypeSelect from './Fields/ModelTypeSelect'; import ModelTypeSelect from './Fields/ModelTypeSelect';
import ModelVariantSelect from './Fields/ModelVariantSelect'; import ModelVariantSelect from './Fields/ModelVariantSelect';
import RepoVariantSelect from './Fields/RepoVariantSelect'; import RepoVariantSelect from './Fields/RepoVariantSelect';
import PredictionTypeSelect from './Fields/PredictionTypeSelect'; import PredictionTypeSelect from './Fields/PredictionTypeSelect';
import BooleanSelect from './Fields/BooleanSelect'; import BooleanSelect from './Fields/BooleanSelect';
import ModelFormatSelect from './Fields/ModelFormatSelect'; import ModelFormatSelect from './Fields/ModelFormatSelect';
import { useTranslation } from 'react-i18next';
import { addToast } from '../../../system/store/systemSlice';
import { makeToast } from '../../../system/util/makeToast';
export const ModelEdit = () => { export const ModelEdit = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelQuery(selectedModelKey ?? skipToken); const { data, isLoading } = useGetModelQuery(selectedModelKey ?? skipToken);
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelsMutation();
const { t } = useTranslation();
const modelData = useMemo(() => { const modelData = useMemo(() => {
if (!data) { if (!data) {
return null; return null;
@ -75,6 +82,46 @@ export const ModelEdit = () => {
mode: 'onChange', mode: 'onChange',
}); });
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
(values) => {
if (!modelData?.key) {
return;
}
const responseBody = {
key: modelData.key,
body: values,
};
updateModel(responseBody)
.unwrap()
.then((payload) => {
reset(payload as AnyModelConfig, { keepDefaultValues: true });
dispatch(setSelectedModelMode('view'));
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((_) => {
reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
},
[dispatch, modelData?.key, reset, t, updateModel]
);
const handleClickCancel = useCallback(() => { const handleClickCancel = useCallback(() => {
dispatch(setSelectedModelMode('view')); dispatch(setSelectedModelMode('view'));
}, [dispatch]); }, [dispatch]);
@ -88,6 +135,7 @@ export const ModelEdit = () => {
} }
return ( return (
<Flex flexDir="column" h="full"> <Flex flexDir="column" h="full">
<form onSubmit={handleSubmit(onSubmit)}>
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center"> <Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
<Input <Input
{...register('name', { {...register('name', {
@ -99,7 +147,7 @@ export const ModelEdit = () => {
<Button size="sm" onClick={handleClickCancel}> <Button size="sm" onClick={handleClickCancel}>
Cancel Cancel
</Button> </Button>
<Button size="sm" colorScheme="invokeYellow"> <Button size="sm" colorScheme="invokeYellow" onClick={handleSubmit(onSubmit)}>
Save Save
</Button> </Button>
</Flex> </Flex>
@ -191,6 +239,7 @@ export const ModelEdit = () => {
</Flex> </Flex>
)} )}
</Flex> </Flex>
</form>
</Flex> </Flex>
); );
}; };

View File

@ -91,6 +91,7 @@ export const ModelView = () => {
Model Settings Model Settings
</Heading> </Heading>
<Box layerStyle="second" borderRadius="base" p={3}> <Box layerStyle="second" borderRadius="base" p={3}>
<Flex flexDir="column" gap={3}>
<Flex gap={2}> <Flex gap={2}>
<ModelAttrView label="Base Model" value={modelData.base} /> <ModelAttrView label="Base Model" value={modelData.base} />
<ModelAttrView label="Model Type" value={modelData.type} /> <ModelAttrView label="Model Type" value={modelData.type} />
@ -124,6 +125,7 @@ export const ModelView = () => {
<ModelAttrView label="Image Encoder Model ID" value={modelData.image_encoder_model_id} /> <ModelAttrView label="Image Encoder Model ID" value={modelData.image_encoder_model_id} />
</Flex> </Flex>
)} )}
</Flex>
</Box> </Box>
</Flex> </Flex>