finish model update

This commit is contained in:
Mary Hipp 2024-02-21 14:16:00 -05:00 committed by psychedelicious
parent 20576deae8
commit 86e2b39f0d
2 changed files with 182 additions and 131 deletions

View File

@ -1,6 +1,6 @@
import { skipToken } from '@reduxjs/toolkit/query'; import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from '../../../../app/store/storeHooks'; import { useAppDispatch, useAppSelector } from '../../../../app/store/storeHooks';
import { useGetModelQuery } from '../../../../services/api/endpoints/models'; import { useGetModelQuery, useUpdateModelsMutation } from '../../../../services/api/endpoints/models';
import { Flex, Text, Heading, Button, Input, FormControl, FormLabel, Textarea } from '@invoke-ai/ui-library'; import { Flex, Text, Heading, Button, Input, FormControl, FormLabel, Textarea } from '@invoke-ai/ui-library';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { import {
@ -16,19 +16,26 @@ import {
} from '../../../../services/api/types'; } from '../../../../services/api/types';
import { setSelectedModelMode } from '../../store/modelManagerV2Slice'; import { setSelectedModelMode } from '../../store/modelManagerV2Slice';
import BaseModelSelect from './Fields/BaseModelSelect'; import BaseModelSelect from './Fields/BaseModelSelect';
import { useForm } from 'react-hook-form'; import { SubmitHandler, useForm } from 'react-hook-form';
import ModelTypeSelect from './Fields/ModelTypeSelect'; import ModelTypeSelect from './Fields/ModelTypeSelect';
import ModelVariantSelect from './Fields/ModelVariantSelect'; import ModelVariantSelect from './Fields/ModelVariantSelect';
import RepoVariantSelect from './Fields/RepoVariantSelect'; import RepoVariantSelect from './Fields/RepoVariantSelect';
import PredictionTypeSelect from './Fields/PredictionTypeSelect'; import PredictionTypeSelect from './Fields/PredictionTypeSelect';
import BooleanSelect from './Fields/BooleanSelect'; import BooleanSelect from './Fields/BooleanSelect';
import ModelFormatSelect from './Fields/ModelFormatSelect'; import ModelFormatSelect from './Fields/ModelFormatSelect';
import { useTranslation } from 'react-i18next';
import { addToast } from '../../../system/store/systemSlice';
import { makeToast } from '../../../system/util/makeToast';
export const ModelEdit = () => { export const ModelEdit = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelQuery(selectedModelKey ?? skipToken); const { data, isLoading } = useGetModelQuery(selectedModelKey ?? skipToken);
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelsMutation();
const { t } = useTranslation();
const modelData = useMemo(() => { const modelData = useMemo(() => {
if (!data) { if (!data) {
return null; return null;
@ -75,6 +82,46 @@ export const ModelEdit = () => {
mode: 'onChange', mode: 'onChange',
}); });
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
(values) => {
if (!modelData?.key) {
return;
}
const responseBody = {
key: modelData.key,
body: values,
};
updateModel(responseBody)
.unwrap()
.then((payload) => {
reset(payload as AnyModelConfig, { keepDefaultValues: true });
dispatch(setSelectedModelMode('view'));
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((_) => {
reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
},
[dispatch, modelData?.key, reset, t, updateModel]
);
const handleClickCancel = useCallback(() => { const handleClickCancel = useCallback(() => {
dispatch(setSelectedModelMode('view')); dispatch(setSelectedModelMode('view'));
}, [dispatch]); }, [dispatch]);
@ -88,109 +135,111 @@ export const ModelEdit = () => {
} }
return ( return (
<Flex flexDir="column" h="full"> <Flex flexDir="column" h="full">
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center"> <form onSubmit={handleSubmit(onSubmit)}>
<Input <Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
{...register('name', { <Input
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters', {...register('name', {
})} validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
size="lg" })}
/> size="lg"
<Flex gap={2}> />
<Button size="sm" onClick={handleClickCancel}> <Flex gap={2}>
Cancel <Button size="sm" onClick={handleClickCancel}>
</Button> Cancel
<Button size="sm" colorScheme="invokeYellow"> </Button>
Save <Button size="sm" colorScheme="invokeYellow" onClick={handleSubmit(onSubmit)}>
</Button> Save
</Button>
</Flex>
</Flex> </Flex>
</Flex>
<Flex flexDir="column" gap={3} mt="4"> <Flex flexDir="column" gap={3} mt="4">
<Flex> <Flex>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Description</FormLabel>
<Textarea fontSize="md" resize="none" {...register('description')} />
</FormControl>
</Flex>
<Heading as="h3" fontSize="md" mt="4">
Model Settings
</Heading>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Base Model</FormLabel>
<BaseModelSelect<AnyModelConfig> control={control} name="base" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Model Type</FormLabel>
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Format</FormLabel>
<ModelFormatSelect<AnyModelConfig> control={control} name="format" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Path</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
})}
/>
</FormControl>
</Flex>
{modelData.type === 'main' && (
<>
<Flex gap={4}>
{modelData.format === 'diffusers' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Repo Variant</FormLabel>
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
</FormControl>
)}
{modelData.format === 'checkpoint' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Config Path</FormLabel>
<Input {...register('config')} />
</FormControl>
)}
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Variant</FormLabel>
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Prediction Type</FormLabel>
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Upcast Attention</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>ZTSNR Training</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>VAE Path</FormLabel>
<Input {...register('vae')} />
</FormControl>
</Flex>
</>
)}
{modelData.type === 'ip_adapter' && (
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}> <FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Image Encoder Model ID</FormLabel> <FormLabel>Description</FormLabel>
<Input {...register('image_encoder_model_id')} /> <Textarea fontSize="md" resize="none" {...register('description')} />
</FormControl> </FormControl>
</Flex> </Flex>
)} <Heading as="h3" fontSize="md" mt="4">
</Flex> Model Settings
</Heading>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Base Model</FormLabel>
<BaseModelSelect<AnyModelConfig> control={control} name="base" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Model Type</FormLabel>
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Format</FormLabel>
<ModelFormatSelect<AnyModelConfig> control={control} name="format" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Path</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
})}
/>
</FormControl>
</Flex>
{modelData.type === 'main' && (
<>
<Flex gap={4}>
{modelData.format === 'diffusers' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Repo Variant</FormLabel>
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
</FormControl>
)}
{modelData.format === 'checkpoint' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Config Path</FormLabel>
<Input {...register('config')} />
</FormControl>
)}
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Variant</FormLabel>
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Prediction Type</FormLabel>
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Upcast Attention</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>ZTSNR Training</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>VAE Path</FormLabel>
<Input {...register('vae')} />
</FormControl>
</Flex>
</>
)}
{modelData.type === 'ip_adapter' && (
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Image Encoder Model ID</FormLabel>
<Input {...register('image_encoder_model_id')} />
</FormControl>
</Flex>
)}
</Flex>
</form>
</Flex> </Flex>
); );
}; };

View File

@ -91,39 +91,41 @@ export const ModelView = () => {
Model Settings Model Settings
</Heading> </Heading>
<Box layerStyle="second" borderRadius="base" p={3}> <Box layerStyle="second" borderRadius="base" p={3}>
<Flex gap={2}> <Flex flexDir="column" gap={3}>
<ModelAttrView label="Base Model" value={modelData.base} />
<ModelAttrView label="Model Type" value={modelData.type} />
</Flex>
<Flex gap={2}>
<ModelAttrView label="Format" value={modelData.format} />
<ModelAttrView label="Path" value={modelData.path} />
</Flex>
{modelData.type === 'main' && (
<>
<Flex gap={2}>
{modelData.format === 'diffusers' && (
<ModelAttrView label="Repo Variant" value={modelData.repo_variant} />
)}
{modelData.format === 'checkpoint' && <ModelAttrView label="Config Path" value={modelData.config} />}
<ModelAttrView label="Variant" value={modelData.variant} />
</Flex>
<Flex gap={2}>
<ModelAttrView label="Prediction Type" value={modelData.prediction_type} />
<ModelAttrView label="Upcast Attention" value={`${modelData.upcast_attention}`} />
</Flex>
<Flex gap={2}>
<ModelAttrView label="ZTSNR Training" value={`${modelData.ztsnr_training}`} />
<ModelAttrView label="VAE" value={modelData.vae} />
</Flex>
</>
)}
{modelData.type === 'ip_adapter' && (
<Flex gap={2}> <Flex gap={2}>
<ModelAttrView label="Image Encoder Model ID" value={modelData.image_encoder_model_id} /> <ModelAttrView label="Base Model" value={modelData.base} />
<ModelAttrView label="Model Type" value={modelData.type} />
</Flex> </Flex>
)} <Flex gap={2}>
<ModelAttrView label="Format" value={modelData.format} />
<ModelAttrView label="Path" value={modelData.path} />
</Flex>
{modelData.type === 'main' && (
<>
<Flex gap={2}>
{modelData.format === 'diffusers' && (
<ModelAttrView label="Repo Variant" value={modelData.repo_variant} />
)}
{modelData.format === 'checkpoint' && <ModelAttrView label="Config Path" value={modelData.config} />}
<ModelAttrView label="Variant" value={modelData.variant} />
</Flex>
<Flex gap={2}>
<ModelAttrView label="Prediction Type" value={modelData.prediction_type} />
<ModelAttrView label="Upcast Attention" value={`${modelData.upcast_attention}`} />
</Flex>
<Flex gap={2}>
<ModelAttrView label="ZTSNR Training" value={`${modelData.ztsnr_training}`} />
<ModelAttrView label="VAE" value={modelData.vae} />
</Flex>
</>
)}
{modelData.type === 'ip_adapter' && (
<Flex gap={2}>
<ModelAttrView label="Image Encoder Model ID" value={modelData.image_encoder_model_id} />
</Flex>
)}
</Flex>
</Box> </Box>
</Flex> </Flex>