From de7b059e670e02f49743ecea439b58e8c40b3452 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Fri, 30 Jun 2023 08:51:27 +1200 Subject: [PATCH] feat: Port Checkpoint Edit to Mantine Form --- .../subpanels/ModelManagerPanel.tsx | 1 + .../ModelManagerPanel/CheckpointModelEdit.tsx | 363 +++++------------- .../ModelManagerPanel/ModelConvert.tsx | 29 +- .../ModelManagerPanel/ModelListItem.tsx | 25 +- 4 files changed, 102 insertions(+), 316 deletions(-) diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx index cdf39579ed..228fb79c2e 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx @@ -32,6 +32,7 @@ export default function ModelManagerPanel() { ); } diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx index 187268abbb..34a6d6885e 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx @@ -1,37 +1,37 @@ -import IAIButton from 'common/components/IAIButton'; -import IAIInput from 'common/components/IAIInput'; -import IAINumberInput from 'common/components/IAINumberInput'; -import { useEffect, useState } from 'react'; - import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { - Flex, - FormControl, - FormLabel, - HStack, - Text, - VStack, -} from '@chakra-ui/react'; +import { Divider, Flex, Text } from '@chakra-ui/react'; // import { addNewModel } from 'app/socketio/actions'; -import { Field, Formik } from 'formik'; +import { useForm } from '@mantine/form'; import { useTranslation } from 'react-i18next'; import type { RootState } from 'app/store/store'; -import type { InvokeModelConfigProps } from 'app/types/invokeai'; -import IAIForm from 'common/components/IAIForm'; -import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage'; -import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText'; -import type { FieldInputProps, FormikProps } from 'formik'; +import IAIButton from 'common/components/IAIButton'; +import IAIInput from 'common/components/IAIInput'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; +import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect'; +import { S } from 'services/api/types'; import ModelConvert from './ModelConvert'; -const MIN_MODEL_SIZE = 64; -const MAX_MODEL_SIZE = 2048; +const baseModelSelectData = [ + { value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] }, + { value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] }, +]; + +const variantSelectData = [ + { value: 'normal', label: 'Normal' }, + { value: 'inpaint', label: 'Inpaint' }, + { value: 'depth', label: 'Depth' }, +]; + +export type CheckpointModel = + | S<'StableDiffusion1ModelCheckpointConfig'> + | S<'StableDiffusion2ModelCheckpointConfig'>; type CheckpointModelEditProps = { modelToEdit: string; - retrievedModel: any; + retrievedModel: CheckpointModel; }; export default function CheckpointModelEdit(props: CheckpointModelEditProps) { @@ -42,268 +42,93 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) { const { modelToEdit, retrievedModel } = props; const dispatch = useAppDispatch(); - const { t } = useTranslation(); - const [editModelFormValues, setEditModelFormValues] = - useState({ - name: '', - description: '', - config: 'configs/stable-diffusion/v1-inference.yaml', - weights: '', - vae: '', - width: 512, - height: 512, - default: false, - model_format: 'ckpt', - }); + const checkpointEditForm = useForm({ + initialValues: { + name: retrievedModel.name, + base_model: retrievedModel.base_model, + type: 'main', + path: retrievedModel.path, + description: retrievedModel.description, + model_format: 'checkpoint', + vae: retrievedModel.vae, + config: retrievedModel.config, + variant: retrievedModel.variant, + }, + }); - useEffect(() => { - if (modelToEdit) { - setEditModelFormValues({ - name: modelToEdit, - description: retrievedModel?.description, - config: retrievedModel?.config, - weights: retrievedModel?.weights, - vae: retrievedModel?.vae, - width: retrievedModel?.width, - height: retrievedModel?.height, - default: retrievedModel?.default, - model_format: 'ckpt', - }); - } - }, [retrievedModel, modelToEdit]); - - const editModelFormSubmitHandler = (values: InvokeModelConfigProps) => { - dispatch( - addNewModel({ - ...values, - width: Number(values.width), - height: Number(values.height), - }) - ); + const editModelFormSubmitHandler = (values) => { + console.log(values); }; return modelToEdit ? ( - - - {modelToEdit} - - + + + + {retrievedModel.name} + + + {MODEL_TYPE_MAP[retrievedModel.base_model]} Model + + + + + - - {({ handleSubmit, errors, touched }) => ( - - - {/* Description */} - - - {t('modelManager.description')} - - - - {!!errors.description && touched.description ? ( - - {errors.description} - - ) : ( - - {t('modelManager.descriptionValidationMsg')} - - )} - - - - {/* Config */} - - - {t('modelManager.config')} - - - - {!!errors.config && touched.config ? ( - {errors.config} - ) : ( - - {t('modelManager.configValidationMsg')} - - )} - - - - {/* Weights */} - - - {t('modelManager.modelLocation')} - - - - {!!errors.weights && touched.weights ? ( - - {errors.weights} - - ) : ( - - {t('modelManager.modelLocationValidationMsg')} - - )} - - - - {/* VAE */} - - - {t('modelManager.vaeLocation')} - - - - {!!errors.vae && touched.vae ? ( - {errors.vae} - ) : ( - - {t('modelManager.vaeLocationValidationMsg')} - - )} - - - - - {/* Width */} - - - {t('modelManager.width')} - - - - {({ - field, - form, - }: { - field: FieldInputProps; - form: FormikProps; - }) => ( - - form.setFieldValue(field.name, Number(value)) - } - /> - )} - - - {!!errors.width && touched.width ? ( - - {errors.width} - - ) : ( - - {t('modelManager.widthValidationMsg')} - - )} - - - - {/* Height */} - - - {t('modelManager.height')} - - - - {({ - field, - form, - }: { - field: FieldInputProps; - form: FormikProps; - }) => ( - - form.setFieldValue(field.name, Number(value)) - } - /> - )} - - - {!!errors.height && touched.height ? ( - - {errors.height} - - ) : ( - - {t('modelManager.heightValidationMsg')} - - )} - - - - - - {t('modelManager.updateModel')} - - - +
+ editModelFormSubmitHandler(values) )} - + > + + + + + + + + + + {t('modelManager.updateModel')} + + +
) : ( diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx index 820ad546b3..56a668d0e2 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx @@ -4,42 +4,28 @@ import { Radio, RadioGroup, Text, - UnorderedList, Tooltip, + UnorderedList, } from '@chakra-ui/react'; // import { convertToDiffusers } from 'app/socketio/actions'; -import { RootState } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useAppDispatch } from 'app/store/storeHooks'; import IAIAlertDialog from 'common/components/IAIAlertDialog'; import IAIButton from 'common/components/IAIButton'; import IAIInput from 'common/components/IAIInput'; -import { useState, useEffect } from 'react'; +import { useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; +import { CheckpointModel } from './CheckpointModelEdit'; interface ModelConvertProps { - model: string; + model: CheckpointModel; } export default function ModelConvert(props: ModelConvertProps) { const { model } = props; - const model_list = useAppSelector( - (state: RootState) => state.system.model_list - ); - - const retrievedModel = model_list[model]; - const dispatch = useAppDispatch(); const { t } = useTranslation(); - const isProcessing = useAppSelector( - (state: RootState) => state.system.isProcessing - ); - - const isConnected = useAppSelector( - (state: RootState) => state.system.isConnected - ); - const [saveLocation, setSaveLocation] = useState('same'); const [customSaveLocation, setCustomSaveLocation] = useState(''); @@ -65,7 +51,7 @@ export default function ModelConvert(props: ModelConvertProps) { return ( diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx index e1b3bbab1e..ab5fddd5ea 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx @@ -1,5 +1,5 @@ import { DeleteIcon, EditIcon } from '@chakra-ui/icons'; -import { Box, Button, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react'; +import { Box, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react'; // import { deleteModel, requestModelChange } from 'app/socketio/actions'; import { RootState } from 'app/store/store'; @@ -30,10 +30,6 @@ export default function ModelListItem(props: ModelListItemProps) { const { modelKey, name, description } = props; - const handleChangeModel = () => { - dispatch(requestModelChange(modelKey)); - }; - const openModelHandler = () => { dispatch(setOpenModel(modelKey)); }; @@ -43,17 +39,6 @@ export default function ModelListItem(props: ModelListItemProps) { dispatch(setOpenModel(null)); }; - const statusTextColor = () => { - switch (status) { - case 'active': - return 'ok.500'; - case 'cached': - return 'warning.500'; - case 'not loaded': - return 'inherit'; - } - }; - return ( - - } size="sm"