mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): move model save/close buttons to model header
This commit is contained in:
parent
c008704bc8
commit
6386109fc5
@ -21,6 +21,7 @@ export const SyncModelsButton = memo((props: Omit<ButtonProps, 'aria-label'>) =>
|
|||||||
leftIcon={<PiArrowsClockwiseBold />}
|
leftIcon={<PiArrowsClockwiseBold />}
|
||||||
isLoading={isLoading}
|
isLoading={isLoading}
|
||||||
onClick={syncModels}
|
onClick={syncModels}
|
||||||
|
size="sm"
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
{...props}
|
{...props}
|
||||||
>
|
>
|
||||||
|
@ -21,8 +21,8 @@ export const ModelManager = () => {
|
|||||||
<Flex w="full" gap={4} justifyContent="space-between" alignItems="center">
|
<Flex w="full" gap={4} justifyContent="space-between" alignItems="center">
|
||||||
<Heading fontSize="xl">{t('common.modelManager')}</Heading>
|
<Heading fontSize="xl">{t('common.modelManager')}</Heading>
|
||||||
<Spacer />
|
<Spacer />
|
||||||
<SyncModelsButton />
|
<SyncModelsButton size="sm" />
|
||||||
<Button colorScheme="invokeYellow" leftIcon={<PiPlusBold />} onClick={handleClickAddModel}>
|
<Button size="sm" colorScheme="invokeYellow" leftIcon={<PiPlusBold />} onClick={handleClickAddModel}>
|
||||||
{t('modelManager.addModels')}
|
{t('modelManager.addModels')}
|
||||||
</Button>
|
</Button>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -34,7 +34,7 @@ export const ModelTypeFilter = () => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<Menu>
|
<Menu>
|
||||||
<MenuButton as={Button} leftIcon={<IoFilter />}>
|
<MenuButton as={Button} size="sm" leftIcon={<IoFilter />}>
|
||||||
{filteredModelType ? MODEL_TYPE_LABELS[filteredModelType] : t('modelManager.allModels')}
|
{filteredModelType ? MODEL_TYPE_LABELS[filteredModelType] : t('modelManager.allModels')}
|
||||||
</MenuButton>
|
</MenuButton>
|
||||||
<MenuList>
|
<MenuList>
|
||||||
|
@ -1,10 +1,18 @@
|
|||||||
import { Flex, Heading, Spacer, Text } from '@invoke-ai/ui-library';
|
import { Button, Flex, Heading, Spacer, Text } from '@invoke-ai/ui-library';
|
||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
import { skipToken } from '@reduxjs/toolkit/query';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton';
|
import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton';
|
||||||
import { ModelEditButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelEditButton';
|
import { ModelEditButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelEditButton';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
import type { SubmitHandler } from 'react-hook-form';
|
||||||
|
import { useForm } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
import { PiCheckBold, PiXBold } from 'react-icons/pi';
|
||||||
|
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||||
|
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
import ModelImageUpload from './Fields/ModelImageUpload';
|
import ModelImageUpload from './Fields/ModelImageUpload';
|
||||||
import { ModelEdit } from './ModelEdit';
|
import { ModelEdit } from './ModelEdit';
|
||||||
@ -15,6 +23,57 @@ export const Model = () => {
|
|||||||
const selectedModelMode = useAppSelector((s) => s.modelmanagerV2.selectedModelMode);
|
const selectedModelMode = useAppSelector((s) => s.modelmanagerV2.selectedModelMode);
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||||
|
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const form = useForm<UpdateModelArg['body']>({
|
||||||
|
defaultValues: data,
|
||||||
|
mode: 'onChange',
|
||||||
|
});
|
||||||
|
|
||||||
|
const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
|
||||||
|
(values) => {
|
||||||
|
if (!data?.key) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const responseBody: UpdateModelArg = {
|
||||||
|
key: data.key,
|
||||||
|
body: values,
|
||||||
|
};
|
||||||
|
|
||||||
|
updateModel(responseBody)
|
||||||
|
.unwrap()
|
||||||
|
.then((payload) => {
|
||||||
|
form.reset(payload, { keepDefaultValues: true });
|
||||||
|
dispatch(setSelectedModelMode('view'));
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: t('modelManager.modelUpdated'),
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
})
|
||||||
|
.catch((_) => {
|
||||||
|
form.reset();
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: t('modelManager.modelUpdateFailed'),
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[dispatch, data?.key, form, t, updateModel]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleClickCancel = useCallback(() => {
|
||||||
|
dispatch(setSelectedModelMode('view'));
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
if (isLoading) {
|
if (isLoading) {
|
||||||
return <Text>{t('common.loading')}</Text>;
|
return <Text>{t('common.loading')}</Text>;
|
||||||
@ -30,12 +89,29 @@ export const Model = () => {
|
|||||||
<ModelImageUpload model_key={selectedModelKey} model_image={data.cover_image} />
|
<ModelImageUpload model_key={selectedModelKey} model_image={data.cover_image} />
|
||||||
<Flex flexDir="column" gap={1} flexGrow={1}>
|
<Flex flexDir="column" gap={1} flexGrow={1}>
|
||||||
<Flex gap={2}>
|
<Flex gap={2}>
|
||||||
<Heading as="h2" fontSize="lg" w="full">
|
<Heading as="h2" fontSize="lg">
|
||||||
{data.name}
|
{data.name}
|
||||||
</Heading>
|
</Heading>
|
||||||
<Spacer />
|
<Spacer />
|
||||||
<ModelEditButton />
|
{selectedModelMode === 'view' && <ModelConvertButton modelKey={selectedModelKey} />}
|
||||||
<ModelConvertButton modelKey={selectedModelKey} />
|
{selectedModelMode === 'view' && <ModelEditButton />}
|
||||||
|
{selectedModelMode === 'edit' && (
|
||||||
|
<Button size="sm" onClick={handleClickCancel} leftIcon={<PiXBold />}>
|
||||||
|
{t('common.cancel')}
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
{selectedModelMode === 'edit' && (
|
||||||
|
<Button
|
||||||
|
size="sm"
|
||||||
|
colorScheme="invokeYellow"
|
||||||
|
leftIcon={<PiCheckBold />}
|
||||||
|
onClick={form.handleSubmit(onSubmit)}
|
||||||
|
isLoading={isSubmitting}
|
||||||
|
isDisabled={Boolean(Object.keys(form.formState.errors).length)}
|
||||||
|
>
|
||||||
|
{t('common.save')}
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
{data.source && (
|
{data.source && (
|
||||||
<Text variant="subtext">
|
<Text variant="subtext">
|
||||||
@ -45,7 +121,7 @@ export const Model = () => {
|
|||||||
<Text noOfLines={3}>{data.description}</Text>
|
<Text noOfLines={3}>{data.description}</Text>
|
||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
{selectedModelMode === 'view' ? <ModelView /> : <ModelEdit />}
|
{selectedModelMode === 'view' ? <ModelView /> : <ModelEdit form={form} onSubmit={onSubmit} />}
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -76,7 +76,7 @@ export const ModelConvertButton = (props: ModelConvertProps) => {
|
|||||||
isLoading={isLoading}
|
isLoading={isLoading}
|
||||||
flexShrink={0}
|
flexShrink={0}
|
||||||
>
|
>
|
||||||
🧨 {t('modelManager.convertToDiffusers')}
|
🧨 {t('modelManager.convert')}
|
||||||
</Button>
|
</Button>
|
||||||
<ConfirmationAlertDialog
|
<ConfirmationAlertDialog
|
||||||
title={`${t('modelManager.convert')} ${data?.name}`}
|
title={`${t('modelManager.convert')} ${data?.name}`}
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import {
|
import {
|
||||||
Button,
|
|
||||||
Checkbox,
|
Checkbox,
|
||||||
Flex,
|
Flex,
|
||||||
FormControl,
|
FormControl,
|
||||||
@ -11,87 +10,30 @@ import {
|
|||||||
Textarea,
|
Textarea,
|
||||||
} from '@invoke-ai/ui-library';
|
} from '@invoke-ai/ui-library';
|
||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
import { skipToken } from '@reduxjs/toolkit/query';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
import type { SubmitHandler, UseFormReturn } from 'react-hook-form';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import { useCallback } from 'react';
|
|
||||||
import type { SubmitHandler } from 'react-hook-form';
|
|
||||||
import { useForm } from 'react-hook-form';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||||
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
|
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
import BaseModelSelect from './Fields/BaseModelSelect';
|
import BaseModelSelect from './Fields/BaseModelSelect';
|
||||||
import ModelVariantSelect from './Fields/ModelVariantSelect';
|
import ModelVariantSelect from './Fields/ModelVariantSelect';
|
||||||
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
|
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
|
||||||
|
|
||||||
export const ModelEdit = () => {
|
type Props = {
|
||||||
const dispatch = useAppDispatch();
|
form: UseFormReturn<UpdateModelArg['body']>;
|
||||||
|
onSubmit: SubmitHandler<UpdateModelArg['body']>;
|
||||||
|
};
|
||||||
|
|
||||||
|
const stringFieldOptions = {
|
||||||
|
validate: (value?: string | null) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
|
||||||
|
};
|
||||||
|
|
||||||
|
export const ModelEdit = ({ form }: Props) => {
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||||
|
|
||||||
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
|
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const {
|
|
||||||
register,
|
|
||||||
handleSubmit,
|
|
||||||
control,
|
|
||||||
formState: { errors },
|
|
||||||
reset,
|
|
||||||
} = useForm<UpdateModelArg['body']>({
|
|
||||||
defaultValues: {
|
|
||||||
...data,
|
|
||||||
},
|
|
||||||
mode: 'onChange',
|
|
||||||
});
|
|
||||||
|
|
||||||
const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
|
|
||||||
(values) => {
|
|
||||||
if (!data?.key) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const responseBody: UpdateModelArg = {
|
|
||||||
key: data.key,
|
|
||||||
body: values,
|
|
||||||
};
|
|
||||||
|
|
||||||
updateModel(responseBody)
|
|
||||||
.unwrap()
|
|
||||||
.then((payload) => {
|
|
||||||
reset(payload, { keepDefaultValues: true });
|
|
||||||
dispatch(setSelectedModelMode('view'));
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('modelManager.modelUpdated'),
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
})
|
|
||||||
.catch((_) => {
|
|
||||||
reset();
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('modelManager.modelUpdateFailed'),
|
|
||||||
status: 'error',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
});
|
|
||||||
},
|
|
||||||
[dispatch, data?.key, reset, t, updateModel]
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleClickCancel = useCallback(() => {
|
|
||||||
dispatch(setSelectedModelMode('view'));
|
|
||||||
}, [dispatch]);
|
|
||||||
|
|
||||||
if (isLoading) {
|
if (isLoading) {
|
||||||
return <Text>{t('common.loading')}</Text>;
|
return <Text>{t('common.loading')}</Text>;
|
||||||
}
|
}
|
||||||
@ -99,40 +41,26 @@ export const ModelEdit = () => {
|
|||||||
if (!data) {
|
if (!data) {
|
||||||
return <Text>{t('common.somethingWentWrong')}</Text>;
|
return <Text>{t('common.somethingWentWrong')}</Text>;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex flexDir="column" h="full">
|
<Flex flexDir="column" h="full">
|
||||||
<form onSubmit={handleSubmit(onSubmit)}>
|
<form>
|
||||||
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
|
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(form.formState.errors.name)}>
|
||||||
<FormLabel>{t('modelManager.modelName')}</FormLabel>
|
<FormLabel>{t('modelManager.modelName')}</FormLabel>
|
||||||
<Input
|
<Input {...form.register('name', stringFieldOptions)} size="lg" />
|
||||||
{...register('name', {
|
|
||||||
validate: (value) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
|
|
||||||
})}
|
|
||||||
size="lg"
|
|
||||||
/>
|
|
||||||
|
|
||||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
{form.formState.errors.name?.message && (
|
||||||
|
<FormErrorMessage>{form.formState.errors.name?.message}</FormErrorMessage>
|
||||||
|
)}
|
||||||
</FormControl>
|
</FormControl>
|
||||||
<Button size="sm" onClick={handleClickCancel}>
|
|
||||||
{t('common.cancel')}
|
|
||||||
</Button>
|
|
||||||
<Button
|
|
||||||
size="sm"
|
|
||||||
colorScheme="invokeYellow"
|
|
||||||
onClick={handleSubmit(onSubmit)}
|
|
||||||
isLoading={isSubmitting}
|
|
||||||
isDisabled={Boolean(Object.keys(errors).length)}
|
|
||||||
>
|
|
||||||
{t('common.save')}
|
|
||||||
</Button>
|
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
<Flex flexDir="column" gap={3} mt="4">
|
<Flex flexDir="column" gap={3} mt="4">
|
||||||
<Flex gap="4" alignItems="center">
|
<Flex gap="4" alignItems="center">
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||||
<Textarea fontSize="md" {...register('description')} />
|
<Textarea fontSize="md" {...form.register('description')} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
</Flex>
|
</Flex>
|
||||||
<Heading as="h3" fontSize="md" mt="4">
|
<Heading as="h3" fontSize="md" mt="4">
|
||||||
@ -141,7 +69,7 @@ export const ModelEdit = () => {
|
|||||||
<Flex gap={4}>
|
<Flex gap={4}>
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
||||||
<BaseModelSelect control={control} />
|
<BaseModelSelect control={form.control} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
</Flex>
|
</Flex>
|
||||||
{data.type === 'main' && data.format === 'checkpoint' && (
|
{data.type === 'main' && data.format === 'checkpoint' && (
|
||||||
@ -149,25 +77,21 @@ export const ModelEdit = () => {
|
|||||||
<Flex gap={4}>
|
<Flex gap={4}>
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
|
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
|
||||||
<Input
|
<Input {...form.register('config_path', stringFieldOptions)} />
|
||||||
{...register('config_path', {
|
|
||||||
validate: (value) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
|
|
||||||
})}
|
|
||||||
/>
|
|
||||||
</FormControl>
|
</FormControl>
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||||
<ModelVariantSelect control={control} />
|
<ModelVariantSelect control={form.control} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
</Flex>
|
</Flex>
|
||||||
<Flex gap={4}>
|
<Flex gap={4}>
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
||||||
<PredictionTypeSelect control={control} />
|
<PredictionTypeSelect control={form.control} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
||||||
<Checkbox {...register('upcast_attention')} />
|
<Checkbox {...form.register('upcast_attention')} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
</Flex>
|
</Flex>
|
||||||
</>
|
</>
|
||||||
|
Loading…
x
Reference in New Issue
Block a user