feat(ui): move model save/close buttons to model header

This commit is contained in:
psychedelicious 2024-03-07 16:59:17 +11:00
parent c008704bc8
commit 6386109fc5
6 changed files with 114 additions and 113 deletions

View File

@ -21,6 +21,7 @@ export const SyncModelsButton = memo((props: Omit<ButtonProps, 'aria-label'>) =>
leftIcon={<PiArrowsClockwiseBold />}
isLoading={isLoading}
onClick={syncModels}
size="sm"
variant="ghost"
{...props}
>

View File

@ -21,8 +21,8 @@ export const ModelManager = () => {
<Flex w="full" gap={4} justifyContent="space-between" alignItems="center">
<Heading fontSize="xl">{t('common.modelManager')}</Heading>
<Spacer />
<SyncModelsButton />
<Button colorScheme="invokeYellow" leftIcon={<PiPlusBold />} onClick={handleClickAddModel}>
<SyncModelsButton size="sm" />
<Button size="sm" colorScheme="invokeYellow" leftIcon={<PiPlusBold />} onClick={handleClickAddModel}>
{t('modelManager.addModels')}
</Button>
</Flex>

View File

@ -34,7 +34,7 @@ export const ModelTypeFilter = () => {
return (
<Menu>
<MenuButton as={Button} leftIcon={<IoFilter />}>
<MenuButton as={Button} size="sm" leftIcon={<IoFilter />}>
{filteredModelType ? MODEL_TYPE_LABELS[filteredModelType] : t('modelManager.allModels')}
</MenuButton>
<MenuList>

View File

@ -1,10 +1,18 @@
import { Flex, Heading, Spacer, Text } from '@invoke-ai/ui-library';
import { Button, Flex, Heading, Spacer, Text } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton';
import { ModelEditButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelEditButton';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import { PiCheckBold, PiXBold } from 'react-icons/pi';
import type { UpdateModelArg } from 'services/api/endpoints/models';
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
import ModelImageUpload from './Fields/ModelImageUpload';
import { ModelEdit } from './ModelEdit';
@ -15,6 +23,57 @@ export const Model = () => {
const selectedModelMode = useAppSelector((s) => s.modelmanagerV2.selectedModelMode);
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
const dispatch = useAppDispatch();
const form = useForm<UpdateModelArg['body']>({
defaultValues: data,
mode: 'onChange',
});
const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
(values) => {
if (!data?.key) {
return;
}
const responseBody: UpdateModelArg = {
key: data.key,
body: values,
};
updateModel(responseBody)
.unwrap()
.then((payload) => {
form.reset(payload, { keepDefaultValues: true });
dispatch(setSelectedModelMode('view'));
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((_) => {
form.reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
},
[dispatch, data?.key, form, t, updateModel]
);
const handleClickCancel = useCallback(() => {
dispatch(setSelectedModelMode('view'));
}, [dispatch]);
if (isLoading) {
return <Text>{t('common.loading')}</Text>;
@ -30,12 +89,29 @@ export const Model = () => {
<ModelImageUpload model_key={selectedModelKey} model_image={data.cover_image} />
<Flex flexDir="column" gap={1} flexGrow={1}>
<Flex gap={2}>
<Heading as="h2" fontSize="lg" w="full">
<Heading as="h2" fontSize="lg">
{data.name}
</Heading>
<Spacer />
<ModelEditButton />
<ModelConvertButton modelKey={selectedModelKey} />
{selectedModelMode === 'view' && <ModelConvertButton modelKey={selectedModelKey} />}
{selectedModelMode === 'view' && <ModelEditButton />}
{selectedModelMode === 'edit' && (
<Button size="sm" onClick={handleClickCancel} leftIcon={<PiXBold />}>
{t('common.cancel')}
</Button>
)}
{selectedModelMode === 'edit' && (
<Button
size="sm"
colorScheme="invokeYellow"
leftIcon={<PiCheckBold />}
onClick={form.handleSubmit(onSubmit)}
isLoading={isSubmitting}
isDisabled={Boolean(Object.keys(form.formState.errors).length)}
>
{t('common.save')}
</Button>
)}
</Flex>
{data.source && (
<Text variant="subtext">
@ -45,7 +121,7 @@ export const Model = () => {
<Text noOfLines={3}>{data.description}</Text>
</Flex>
</Flex>
{selectedModelMode === 'view' ? <ModelView /> : <ModelEdit />}
{selectedModelMode === 'view' ? <ModelView /> : <ModelEdit form={form} onSubmit={onSubmit} />}
</Flex>
);
};

View File

@ -76,7 +76,7 @@ export const ModelConvertButton = (props: ModelConvertProps) => {
isLoading={isLoading}
flexShrink={0}
>
🧨 {t('modelManager.convertToDiffusers')}
🧨 {t('modelManager.convert')}
</Button>
<ConfirmationAlertDialog
title={`${t('modelManager.convert')} ${data?.name}`}

View File

@ -1,5 +1,4 @@
import {
Button,
Checkbox,
Flex,
FormControl,
@ -11,87 +10,30 @@ import {
Textarea,
} from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useAppSelector } from 'app/store/storeHooks';
import type { SubmitHandler, UseFormReturn } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { UpdateModelArg } from 'services/api/endpoints/models';
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import BaseModelSelect from './Fields/BaseModelSelect';
import ModelVariantSelect from './Fields/ModelVariantSelect';
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
export const ModelEdit = () => {
const dispatch = useAppDispatch();
type Props = {
form: UseFormReturn<UpdateModelArg['body']>;
onSubmit: SubmitHandler<UpdateModelArg['body']>;
};
const stringFieldOptions = {
validate: (value?: string | null) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
};
export const ModelEdit = ({ form }: Props) => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
const { t } = useTranslation();
const {
register,
handleSubmit,
control,
formState: { errors },
reset,
} = useForm<UpdateModelArg['body']>({
defaultValues: {
...data,
},
mode: 'onChange',
});
const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
(values) => {
if (!data?.key) {
return;
}
const responseBody: UpdateModelArg = {
key: data.key,
body: values,
};
updateModel(responseBody)
.unwrap()
.then((payload) => {
reset(payload, { keepDefaultValues: true });
dispatch(setSelectedModelMode('view'));
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((_) => {
reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
},
[dispatch, data?.key, reset, t, updateModel]
);
const handleClickCancel = useCallback(() => {
dispatch(setSelectedModelMode('view'));
}, [dispatch]);
if (isLoading) {
return <Text>{t('common.loading')}</Text>;
}
@ -99,40 +41,26 @@ export const ModelEdit = () => {
if (!data) {
return <Text>{t('common.somethingWentWrong')}</Text>;
}
return (
<Flex flexDir="column" h="full">
<form onSubmit={handleSubmit(onSubmit)}>
<form>
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}>
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(form.formState.errors.name)}>
<FormLabel>{t('modelManager.modelName')}</FormLabel>
<Input
{...register('name', {
validate: (value) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
})}
size="lg"
/>
<Input {...form.register('name', stringFieldOptions)} size="lg" />
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
{form.formState.errors.name?.message && (
<FormErrorMessage>{form.formState.errors.name?.message}</FormErrorMessage>
)}
</FormControl>
<Button size="sm" onClick={handleClickCancel}>
{t('common.cancel')}
</Button>
<Button
size="sm"
colorScheme="invokeYellow"
onClick={handleSubmit(onSubmit)}
isLoading={isSubmitting}
isDisabled={Boolean(Object.keys(errors).length)}
>
{t('common.save')}
</Button>
</Flex>
<Flex flexDir="column" gap={3} mt="4">
<Flex gap="4" alignItems="center">
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.description')}</FormLabel>
<Textarea fontSize="md" {...register('description')} />
<Textarea fontSize="md" {...form.register('description')} />
</FormControl>
</Flex>
<Heading as="h3" fontSize="md" mt="4">
@ -141,7 +69,7 @@ export const ModelEdit = () => {
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
<BaseModelSelect control={control} />
<BaseModelSelect control={form.control} />
</FormControl>
</Flex>
{data.type === 'main' && data.format === 'checkpoint' && (
@ -149,25 +77,21 @@ export const ModelEdit = () => {
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
<Input
{...register('config_path', {
validate: (value) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
})}
/>
<Input {...form.register('config_path', stringFieldOptions)} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect control={control} />
<ModelVariantSelect control={form.control} />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect control={control} />
<PredictionTypeSelect control={form.control} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<Checkbox {...register('upcast_attention')} />
<Checkbox {...form.register('upcast_attention')} />
</FormControl>
</Flex>
</>