mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
finish model update
This commit is contained in:
parent
20576deae8
commit
86e2b39f0d
@ -1,6 +1,6 @@
|
|||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
import { skipToken } from '@reduxjs/toolkit/query';
|
||||||
import { useAppDispatch, useAppSelector } from '../../../../app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from '../../../../app/store/storeHooks';
|
||||||
import { useGetModelQuery } from '../../../../services/api/endpoints/models';
|
import { useGetModelQuery, useUpdateModelsMutation } from '../../../../services/api/endpoints/models';
|
||||||
import { Flex, Text, Heading, Button, Input, FormControl, FormLabel, Textarea } from '@invoke-ai/ui-library';
|
import { Flex, Text, Heading, Button, Input, FormControl, FormLabel, Textarea } from '@invoke-ai/ui-library';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import {
|
import {
|
||||||
@ -16,19 +16,26 @@ import {
|
|||||||
} from '../../../../services/api/types';
|
} from '../../../../services/api/types';
|
||||||
import { setSelectedModelMode } from '../../store/modelManagerV2Slice';
|
import { setSelectedModelMode } from '../../store/modelManagerV2Slice';
|
||||||
import BaseModelSelect from './Fields/BaseModelSelect';
|
import BaseModelSelect from './Fields/BaseModelSelect';
|
||||||
import { useForm } from 'react-hook-form';
|
import { SubmitHandler, useForm } from 'react-hook-form';
|
||||||
import ModelTypeSelect from './Fields/ModelTypeSelect';
|
import ModelTypeSelect from './Fields/ModelTypeSelect';
|
||||||
import ModelVariantSelect from './Fields/ModelVariantSelect';
|
import ModelVariantSelect from './Fields/ModelVariantSelect';
|
||||||
import RepoVariantSelect from './Fields/RepoVariantSelect';
|
import RepoVariantSelect from './Fields/RepoVariantSelect';
|
||||||
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
|
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
|
||||||
import BooleanSelect from './Fields/BooleanSelect';
|
import BooleanSelect from './Fields/BooleanSelect';
|
||||||
import ModelFormatSelect from './Fields/ModelFormatSelect';
|
import ModelFormatSelect from './Fields/ModelFormatSelect';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { addToast } from '../../../system/store/systemSlice';
|
||||||
|
import { makeToast } from '../../../system/util/makeToast';
|
||||||
|
|
||||||
export const ModelEdit = () => {
|
export const ModelEdit = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||||
const { data, isLoading } = useGetModelQuery(selectedModelKey ?? skipToken);
|
const { data, isLoading } = useGetModelQuery(selectedModelKey ?? skipToken);
|
||||||
|
|
||||||
|
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelsMutation();
|
||||||
|
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const modelData = useMemo(() => {
|
const modelData = useMemo(() => {
|
||||||
if (!data) {
|
if (!data) {
|
||||||
return null;
|
return null;
|
||||||
@ -75,6 +82,46 @@ export const ModelEdit = () => {
|
|||||||
mode: 'onChange',
|
mode: 'onChange',
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
|
||||||
|
(values) => {
|
||||||
|
if (!modelData?.key) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const responseBody = {
|
||||||
|
key: modelData.key,
|
||||||
|
body: values,
|
||||||
|
};
|
||||||
|
|
||||||
|
updateModel(responseBody)
|
||||||
|
.unwrap()
|
||||||
|
.then((payload) => {
|
||||||
|
reset(payload as AnyModelConfig, { keepDefaultValues: true });
|
||||||
|
dispatch(setSelectedModelMode('view'));
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: t('modelManager.modelUpdated'),
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
})
|
||||||
|
.catch((_) => {
|
||||||
|
reset();
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: t('modelManager.modelUpdateFailed'),
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[dispatch, modelData?.key, reset, t, updateModel]
|
||||||
|
);
|
||||||
|
|
||||||
const handleClickCancel = useCallback(() => {
|
const handleClickCancel = useCallback(() => {
|
||||||
dispatch(setSelectedModelMode('view'));
|
dispatch(setSelectedModelMode('view'));
|
||||||
}, [dispatch]);
|
}, [dispatch]);
|
||||||
@ -88,109 +135,111 @@ export const ModelEdit = () => {
|
|||||||
}
|
}
|
||||||
return (
|
return (
|
||||||
<Flex flexDir="column" h="full">
|
<Flex flexDir="column" h="full">
|
||||||
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
|
<form onSubmit={handleSubmit(onSubmit)}>
|
||||||
<Input
|
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
|
||||||
{...register('name', {
|
<Input
|
||||||
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
{...register('name', {
|
||||||
})}
|
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
||||||
size="lg"
|
})}
|
||||||
/>
|
size="lg"
|
||||||
<Flex gap={2}>
|
/>
|
||||||
<Button size="sm" onClick={handleClickCancel}>
|
<Flex gap={2}>
|
||||||
Cancel
|
<Button size="sm" onClick={handleClickCancel}>
|
||||||
</Button>
|
Cancel
|
||||||
<Button size="sm" colorScheme="invokeYellow">
|
</Button>
|
||||||
Save
|
<Button size="sm" colorScheme="invokeYellow" onClick={handleSubmit(onSubmit)}>
|
||||||
</Button>
|
Save
|
||||||
|
</Button>
|
||||||
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
|
||||||
|
|
||||||
<Flex flexDir="column" gap={3} mt="4">
|
<Flex flexDir="column" gap={3} mt="4">
|
||||||
<Flex>
|
<Flex>
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>Description</FormLabel>
|
|
||||||
<Textarea fontSize="md" resize="none" {...register('description')} />
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
<Heading as="h3" fontSize="md" mt="4">
|
|
||||||
Model Settings
|
|
||||||
</Heading>
|
|
||||||
<Flex gap={4}>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>Base Model</FormLabel>
|
|
||||||
<BaseModelSelect<AnyModelConfig> control={control} name="base" />
|
|
||||||
</FormControl>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>Model Type</FormLabel>
|
|
||||||
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
<Flex gap={4}>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>Format</FormLabel>
|
|
||||||
<ModelFormatSelect<AnyModelConfig> control={control} name="format" />
|
|
||||||
</FormControl>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>Path</FormLabel>
|
|
||||||
<Input
|
|
||||||
{...register('path', {
|
|
||||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
|
||||||
})}
|
|
||||||
/>
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
{modelData.type === 'main' && (
|
|
||||||
<>
|
|
||||||
<Flex gap={4}>
|
|
||||||
{modelData.format === 'diffusers' && (
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>Repo Variant</FormLabel>
|
|
||||||
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
|
|
||||||
</FormControl>
|
|
||||||
)}
|
|
||||||
{modelData.format === 'checkpoint' && (
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>Config Path</FormLabel>
|
|
||||||
<Input {...register('config')} />
|
|
||||||
</FormControl>
|
|
||||||
)}
|
|
||||||
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>Variant</FormLabel>
|
|
||||||
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
<Flex gap={4}>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>Prediction Type</FormLabel>
|
|
||||||
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
|
|
||||||
</FormControl>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>Upcast Attention</FormLabel>
|
|
||||||
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
<Flex gap={4}>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>ZTSNR Training</FormLabel>
|
|
||||||
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
|
|
||||||
</FormControl>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>VAE Path</FormLabel>
|
|
||||||
<Input {...register('vae')} />
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
{modelData.type === 'ip_adapter' && (
|
|
||||||
<Flex gap={4}>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
<FormLabel>Image Encoder Model ID</FormLabel>
|
<FormLabel>Description</FormLabel>
|
||||||
<Input {...register('image_encoder_model_id')} />
|
<Textarea fontSize="md" resize="none" {...register('description')} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
</Flex>
|
</Flex>
|
||||||
)}
|
<Heading as="h3" fontSize="md" mt="4">
|
||||||
</Flex>
|
Model Settings
|
||||||
|
</Heading>
|
||||||
|
<Flex gap={4}>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Base Model</FormLabel>
|
||||||
|
<BaseModelSelect<AnyModelConfig> control={control} name="base" />
|
||||||
|
</FormControl>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Model Type</FormLabel>
|
||||||
|
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
<Flex gap={4}>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Format</FormLabel>
|
||||||
|
<ModelFormatSelect<AnyModelConfig> control={control} name="format" />
|
||||||
|
</FormControl>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Path</FormLabel>
|
||||||
|
<Input
|
||||||
|
{...register('path', {
|
||||||
|
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
||||||
|
})}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
{modelData.type === 'main' && (
|
||||||
|
<>
|
||||||
|
<Flex gap={4}>
|
||||||
|
{modelData.format === 'diffusers' && (
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Repo Variant</FormLabel>
|
||||||
|
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
|
||||||
|
</FormControl>
|
||||||
|
)}
|
||||||
|
{modelData.format === 'checkpoint' && (
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Config Path</FormLabel>
|
||||||
|
<Input {...register('config')} />
|
||||||
|
</FormControl>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Variant</FormLabel>
|
||||||
|
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
<Flex gap={4}>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Prediction Type</FormLabel>
|
||||||
|
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
|
||||||
|
</FormControl>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Upcast Attention</FormLabel>
|
||||||
|
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
<Flex gap={4}>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>ZTSNR Training</FormLabel>
|
||||||
|
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
|
||||||
|
</FormControl>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>VAE Path</FormLabel>
|
||||||
|
<Input {...register('vae')} />
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
{modelData.type === 'ip_adapter' && (
|
||||||
|
<Flex gap={4}>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Image Encoder Model ID</FormLabel>
|
||||||
|
<Input {...register('image_encoder_model_id')} />
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
</form>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -91,39 +91,41 @@ export const ModelView = () => {
|
|||||||
Model Settings
|
Model Settings
|
||||||
</Heading>
|
</Heading>
|
||||||
<Box layerStyle="second" borderRadius="base" p={3}>
|
<Box layerStyle="second" borderRadius="base" p={3}>
|
||||||
<Flex gap={2}>
|
<Flex flexDir="column" gap={3}>
|
||||||
<ModelAttrView label="Base Model" value={modelData.base} />
|
|
||||||
<ModelAttrView label="Model Type" value={modelData.type} />
|
|
||||||
</Flex>
|
|
||||||
<Flex gap={2}>
|
|
||||||
<ModelAttrView label="Format" value={modelData.format} />
|
|
||||||
<ModelAttrView label="Path" value={modelData.path} />
|
|
||||||
</Flex>
|
|
||||||
{modelData.type === 'main' && (
|
|
||||||
<>
|
|
||||||
<Flex gap={2}>
|
|
||||||
{modelData.format === 'diffusers' && (
|
|
||||||
<ModelAttrView label="Repo Variant" value={modelData.repo_variant} />
|
|
||||||
)}
|
|
||||||
{modelData.format === 'checkpoint' && <ModelAttrView label="Config Path" value={modelData.config} />}
|
|
||||||
|
|
||||||
<ModelAttrView label="Variant" value={modelData.variant} />
|
|
||||||
</Flex>
|
|
||||||
<Flex gap={2}>
|
|
||||||
<ModelAttrView label="Prediction Type" value={modelData.prediction_type} />
|
|
||||||
<ModelAttrView label="Upcast Attention" value={`${modelData.upcast_attention}`} />
|
|
||||||
</Flex>
|
|
||||||
<Flex gap={2}>
|
|
||||||
<ModelAttrView label="ZTSNR Training" value={`${modelData.ztsnr_training}`} />
|
|
||||||
<ModelAttrView label="VAE" value={modelData.vae} />
|
|
||||||
</Flex>
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
{modelData.type === 'ip_adapter' && (
|
|
||||||
<Flex gap={2}>
|
<Flex gap={2}>
|
||||||
<ModelAttrView label="Image Encoder Model ID" value={modelData.image_encoder_model_id} />
|
<ModelAttrView label="Base Model" value={modelData.base} />
|
||||||
|
<ModelAttrView label="Model Type" value={modelData.type} />
|
||||||
</Flex>
|
</Flex>
|
||||||
)}
|
<Flex gap={2}>
|
||||||
|
<ModelAttrView label="Format" value={modelData.format} />
|
||||||
|
<ModelAttrView label="Path" value={modelData.path} />
|
||||||
|
</Flex>
|
||||||
|
{modelData.type === 'main' && (
|
||||||
|
<>
|
||||||
|
<Flex gap={2}>
|
||||||
|
{modelData.format === 'diffusers' && (
|
||||||
|
<ModelAttrView label="Repo Variant" value={modelData.repo_variant} />
|
||||||
|
)}
|
||||||
|
{modelData.format === 'checkpoint' && <ModelAttrView label="Config Path" value={modelData.config} />}
|
||||||
|
|
||||||
|
<ModelAttrView label="Variant" value={modelData.variant} />
|
||||||
|
</Flex>
|
||||||
|
<Flex gap={2}>
|
||||||
|
<ModelAttrView label="Prediction Type" value={modelData.prediction_type} />
|
||||||
|
<ModelAttrView label="Upcast Attention" value={`${modelData.upcast_attention}`} />
|
||||||
|
</Flex>
|
||||||
|
<Flex gap={2}>
|
||||||
|
<ModelAttrView label="ZTSNR Training" value={`${modelData.ztsnr_training}`} />
|
||||||
|
<ModelAttrView label="VAE" value={modelData.vae} />
|
||||||
|
</Flex>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
{modelData.type === 'ip_adapter' && (
|
||||||
|
<Flex gap={2}>
|
||||||
|
<ModelAttrView label="Image Encoder Model ID" value={modelData.image_encoder_model_id} />
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
</Box>
|
</Box>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user