feat: Port Checkpoint Edit to Mantine Form

This commit is contained in:
blessedcoolant 2023-06-30 08:51:27 +12:00 committed by psychedelicious
parent 33db4e27a0
commit de7b059e67
4 changed files with 102 additions and 316 deletions

View File

@ -32,6 +32,7 @@ export default function ModelManagerPanel() {
<CheckpointModelEdit <CheckpointModelEdit
modelToEdit={openModel} modelToEdit={openModel}
retrievedModel={mainModels['entities'][openModel]} retrievedModel={mainModels['entities'][openModel]}
key={openModel}
/> />
); );
} }

View File

@ -1,37 +1,37 @@
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAINumberInput from 'common/components/IAINumberInput';
import { useEffect, useState } from 'react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { import { Divider, Flex, Text } from '@chakra-ui/react';
Flex,
FormControl,
FormLabel,
HStack,
Text,
VStack,
} from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions'; // import { addNewModel } from 'app/socketio/actions';
import { Field, Formik } from 'formik'; import { useForm } from '@mantine/form';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import type { InvokeModelConfigProps } from 'app/types/invokeai'; import IAIButton from 'common/components/IAIButton';
import IAIForm from 'common/components/IAIForm'; import IAIInput from 'common/components/IAIInput';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText'; import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
import type { FieldInputProps, FormikProps } from 'formik'; import { S } from 'services/api/types';
import ModelConvert from './ModelConvert'; import ModelConvert from './ModelConvert';
const MIN_MODEL_SIZE = 64; const baseModelSelectData = [
const MAX_MODEL_SIZE = 2048; { value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
];
const variantSelectData = [
{ value: 'normal', label: 'Normal' },
{ value: 'inpaint', label: 'Inpaint' },
{ value: 'depth', label: 'Depth' },
];
export type CheckpointModel =
| S<'StableDiffusion1ModelCheckpointConfig'>
| S<'StableDiffusion2ModelCheckpointConfig'>;
type CheckpointModelEditProps = { type CheckpointModelEditProps = {
modelToEdit: string; modelToEdit: string;
retrievedModel: any; retrievedModel: CheckpointModel;
}; };
export default function CheckpointModelEdit(props: CheckpointModelEditProps) { export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
@ -42,268 +42,93 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
const { modelToEdit, retrievedModel } = props; const { modelToEdit, retrievedModel } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const [editModelFormValues, setEditModelFormValues] = const checkpointEditForm = useForm({
useState<InvokeModelConfigProps>({ initialValues: {
name: '', name: retrievedModel.name,
description: '', base_model: retrievedModel.base_model,
config: 'configs/stable-diffusion/v1-inference.yaml', type: 'main',
weights: '', path: retrievedModel.path,
vae: '', description: retrievedModel.description,
width: 512, model_format: 'checkpoint',
height: 512, vae: retrievedModel.vae,
default: false, config: retrievedModel.config,
model_format: 'ckpt', variant: retrievedModel.variant,
},
}); });
useEffect(() => { const editModelFormSubmitHandler = (values) => {
if (modelToEdit) { console.log(values);
setEditModelFormValues({
name: modelToEdit,
description: retrievedModel?.description,
config: retrievedModel?.config,
weights: retrievedModel?.weights,
vae: retrievedModel?.vae,
width: retrievedModel?.width,
height: retrievedModel?.height,
default: retrievedModel?.default,
model_format: 'ckpt',
});
}
}, [retrievedModel, modelToEdit]);
const editModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
dispatch(
addNewModel({
...values,
width: Number(values.width),
height: Number(values.height),
})
);
}; };
return modelToEdit ? ( return modelToEdit ? (
<Flex flexDirection="column" rowGap={4} width="100%"> <Flex flexDirection="column" rowGap={4} width="100%">
<Flex alignItems="center" gap={4} justifyContent="space-between"> <Flex justifyContent="space-between" alignItems="center">
<Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold"> <Text fontSize="lg" fontWeight="bold">
{modelToEdit} {retrievedModel.name}
</Text>
<Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[retrievedModel.base_model]} Model
</Text> </Text>
<ModelConvert model={modelToEdit} />
</Flex> </Flex>
<ModelConvert model={retrievedModel} />
</Flex>
<Divider />
<Flex <Flex
flexDirection="column" flexDirection="column"
maxHeight={window.innerHeight - 270} maxHeight={window.innerHeight - 270}
overflowY="scroll" overflowY="scroll"
paddingInlineEnd={8} paddingInlineEnd={8}
> >
<Formik <form
enableReinitialize={true} onSubmit={checkpointEditForm.onSubmit((values) =>
initialValues={editModelFormValues} editModelFormSubmitHandler(values)
onSubmit={editModelFormSubmitHandler} )}
> >
{({ handleSubmit, errors, touched }) => ( <Flex
<IAIForm onSubmit={handleSubmit}> flexDirection="column"
<VStack rowGap={2} alignItems="start"> overflowY="scroll"
{/* Description */} gap={4}
<FormControl paddingInlineEnd={8}
isInvalid={!!errors.description && touched.description}
isRequired
> >
<FormLabel htmlFor="description" fontSize="sm"> <IAIInput
{t('modelManager.description')} label={t('modelManager.name')}
</FormLabel> {...checkpointEditForm.getInputProps('name')}
<VStack alignItems="start">
<Field
as={IAIInput}
id="description"
name="description"
type="text"
width="full"
/> />
{!!errors.description && touched.description ? ( <IAIInput
<IAIFormErrorMessage> label={t('modelManager.description')}
{errors.description} {...checkpointEditForm.getInputProps('description')}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.descriptionValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* Config */}
<FormControl
isInvalid={!!errors.config && touched.config}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.config')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="config"
name="config"
type="text"
width="full"
/> />
{!!errors.config && touched.config ? ( <IAIMantineSelect
<IAIFormErrorMessage>{errors.config}</IAIFormErrorMessage> label={t('modelManager.baseModel')}
) : ( data={baseModelSelectData}
<IAIFormHelperText> {...checkpointEditForm.getInputProps('base_model')}
{t('modelManager.configValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* Weights */}
<FormControl
isInvalid={!!errors.weights && touched.weights}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="weights"
name="weights"
type="text"
width="full"
/> />
{!!errors.weights && touched.weights ? ( <IAIMantineSelect
<IAIFormErrorMessage> label={t('modelManager.variant')}
{errors.weights} data={variantSelectData}
</IAIFormErrorMessage> {...checkpointEditForm.getInputProps('variant')}
) : (
<IAIFormHelperText>
{t('modelManager.modelLocationValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* VAE */}
<FormControl isInvalid={!!errors.vae && touched.vae}>
<FormLabel htmlFor="vae" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae"
name="vae"
type="text"
width="full"
/> />
{!!errors.vae && touched.vae ? ( <IAIInput
<IAIFormErrorMessage>{errors.vae}</IAIFormErrorMessage> label={t('modelManager.modelLocation')}
) : ( {...checkpointEditForm.getInputProps('path')}
<IAIFormHelperText>
{t('modelManager.vaeLocationValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
<HStack width="100%">
{/* Width */}
<FormControl isInvalid={!!errors.width && touched.width}>
<FormLabel htmlFor="width" fontSize="sm">
{t('modelManager.width')}
</FormLabel>
<VStack alignItems="start">
<Field id="width" name="width">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="width"
name="width"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.width}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/> />
)} <IAIInput
</Field> label={t('modelManager.vaeLocation')}
{...checkpointEditForm.getInputProps('vae')}
{!!errors.width && touched.width ? (
<IAIFormErrorMessage>
{errors.width}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.widthValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
{/* Height */}
<FormControl isInvalid={!!errors.height && touched.height}>
<FormLabel htmlFor="height" fontSize="sm">
{t('modelManager.height')}
</FormLabel>
<VStack alignItems="start">
<Field id="height" name="height">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="height"
name="height"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.height}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/> />
)} <IAIInput
</Field> label={t('modelManager.config')}
{...checkpointEditForm.getInputProps('config')}
{!!errors.height && touched.height ? ( />
<IAIFormErrorMessage> <IAIButton disabled={isProcessing} type="submit">
{errors.height}
</IAIFormErrorMessage>
) : (
<IAIFormHelperText>
{t('modelManager.heightValidationMsg')}
</IAIFormHelperText>
)}
</VStack>
</FormControl>
</HStack>
<IAIButton
type="submit"
className="modal-close-btn"
isLoading={isProcessing}
>
{t('modelManager.updateModel')} {t('modelManager.updateModel')}
</IAIButton> </IAIButton>
</VStack> </Flex>
</IAIForm> </form>
)}
</Formik>
</Flex> </Flex>
</Flex> </Flex>
) : ( ) : (

View File

@ -4,42 +4,28 @@ import {
Radio, Radio,
RadioGroup, RadioGroup,
Text, Text,
UnorderedList,
Tooltip, Tooltip,
UnorderedList,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
// import { convertToDiffusers } from 'app/socketio/actions'; // import { convertToDiffusers } from 'app/socketio/actions';
import { RootState } from 'app/store/store'; import { useAppDispatch } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog'; import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import IAIInput from 'common/components/IAIInput';
import { useState, useEffect } from 'react'; import { useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { CheckpointModel } from './CheckpointModelEdit';
interface ModelConvertProps { interface ModelConvertProps {
model: string; model: CheckpointModel;
} }
export default function ModelConvert(props: ModelConvertProps) { export default function ModelConvert(props: ModelConvertProps) {
const { model } = props; const { model } = props;
const model_list = useAppSelector(
(state: RootState) => state.system.model_list
);
const retrievedModel = model_list[model];
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const isConnected = useAppSelector(
(state: RootState) => state.system.isConnected
);
const [saveLocation, setSaveLocation] = useState<string>('same'); const [saveLocation, setSaveLocation] = useState<string>('same');
const [customSaveLocation, setCustomSaveLocation] = useState<string>(''); const [customSaveLocation, setCustomSaveLocation] = useState<string>('');
@ -65,7 +51,7 @@ export default function ModelConvert(props: ModelConvertProps) {
return ( return (
<IAIAlertDialog <IAIAlertDialog
title={`${t('modelManager.convert')} ${model}`} title={`${t('modelManager.convert')} ${model.name}`}
acceptCallback={modelConvertHandler} acceptCallback={modelConvertHandler}
cancelCallback={modelConvertCancelHandler} cancelCallback={modelConvertCancelHandler}
acceptButtonText={`${t('modelManager.convert')}`} acceptButtonText={`${t('modelManager.convert')}`}
@ -73,9 +59,6 @@ export default function ModelConvert(props: ModelConvertProps) {
<IAIButton <IAIButton
size={'sm'} size={'sm'}
aria-label={t('modelManager.convertToDiffusers')} aria-label={t('modelManager.convertToDiffusers')}
isDisabled={
retrievedModel.status === 'active' || isProcessing || !isConnected
}
className=" modal-close-btn" className=" modal-close-btn"
marginInlineEnd={8} marginInlineEnd={8}
> >

View File

@ -1,5 +1,5 @@
import { DeleteIcon, EditIcon } from '@chakra-ui/icons'; import { DeleteIcon, EditIcon } from '@chakra-ui/icons';
import { Box, Button, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react'; import { Box, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react';
// import { deleteModel, requestModelChange } from 'app/socketio/actions'; // import { deleteModel, requestModelChange } from 'app/socketio/actions';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
@ -30,10 +30,6 @@ export default function ModelListItem(props: ModelListItemProps) {
const { modelKey, name, description } = props; const { modelKey, name, description } = props;
const handleChangeModel = () => {
dispatch(requestModelChange(modelKey));
};
const openModelHandler = () => { const openModelHandler = () => {
dispatch(setOpenModel(modelKey)); dispatch(setOpenModel(modelKey));
}; };
@ -43,17 +39,6 @@ export default function ModelListItem(props: ModelListItemProps) {
dispatch(setOpenModel(null)); dispatch(setOpenModel(null));
}; };
const statusTextColor = () => {
switch (status) {
case 'active':
return 'ok.500';
case 'cached':
return 'warning.500';
case 'not loaded':
return 'inherit';
}
};
return ( return (
<Flex <Flex
alignItems="center" alignItems="center"
@ -81,14 +66,6 @@ export default function ModelListItem(props: ModelListItemProps) {
</Box> </Box>
<Spacer onClick={openModelHandler} cursor="pointer" /> <Spacer onClick={openModelHandler} cursor="pointer" />
<Flex gap={2} alignItems="center"> <Flex gap={2} alignItems="center">
<Button
size="sm"
onClick={handleChangeModel}
isDisabled={status === 'active' || isProcessing || !isConnected}
>
{t('modelManager.load')}
</Button>
<IAIIconButton <IAIIconButton
icon={<EditIcon />} icon={<EditIcon />}
size="sm" size="sm"