From de7b059e670e02f49743ecea439b58e8c40b3452 Mon Sep 17 00:00:00 2001
From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com>
Date: Fri, 30 Jun 2023 08:51:27 +1200
Subject: [PATCH] feat: Port Checkpoint Edit to Mantine Form
---
.../subpanels/ModelManagerPanel.tsx | 1 +
.../ModelManagerPanel/CheckpointModelEdit.tsx | 363 +++++-------------
.../ModelManagerPanel/ModelConvert.tsx | 29 +-
.../ModelManagerPanel/ModelListItem.tsx | 25 +-
4 files changed, 102 insertions(+), 316 deletions(-)
diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx
index cdf39579ed..228fb79c2e 100644
--- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx
+++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx
@@ -32,6 +32,7 @@ export default function ModelManagerPanel() {
);
}
diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx
index 187268abbb..34a6d6885e 100644
--- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx
+++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx
@@ -1,37 +1,37 @@
-import IAIButton from 'common/components/IAIButton';
-import IAIInput from 'common/components/IAIInput';
-import IAINumberInput from 'common/components/IAINumberInput';
-import { useEffect, useState } from 'react';
-
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
-import {
- Flex,
- FormControl,
- FormLabel,
- HStack,
- Text,
- VStack,
-} from '@chakra-ui/react';
+import { Divider, Flex, Text } from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
-import { Field, Formik } from 'formik';
+import { useForm } from '@mantine/form';
import { useTranslation } from 'react-i18next';
import type { RootState } from 'app/store/store';
-import type { InvokeModelConfigProps } from 'app/types/invokeai';
-import IAIForm from 'common/components/IAIForm';
-import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
-import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
-import type { FieldInputProps, FormikProps } from 'formik';
+import IAIButton from 'common/components/IAIButton';
+import IAIInput from 'common/components/IAIInput';
+import IAIMantineSelect from 'common/components/IAIMantineSelect';
+import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
+import { S } from 'services/api/types';
import ModelConvert from './ModelConvert';
-const MIN_MODEL_SIZE = 64;
-const MAX_MODEL_SIZE = 2048;
+const baseModelSelectData = [
+ { value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
+ { value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
+];
+
+const variantSelectData = [
+ { value: 'normal', label: 'Normal' },
+ { value: 'inpaint', label: 'Inpaint' },
+ { value: 'depth', label: 'Depth' },
+];
+
+export type CheckpointModel =
+ | S<'StableDiffusion1ModelCheckpointConfig'>
+ | S<'StableDiffusion2ModelCheckpointConfig'>;
type CheckpointModelEditProps = {
modelToEdit: string;
- retrievedModel: any;
+ retrievedModel: CheckpointModel;
};
export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
@@ -42,268 +42,93 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
const { modelToEdit, retrievedModel } = props;
const dispatch = useAppDispatch();
-
const { t } = useTranslation();
- const [editModelFormValues, setEditModelFormValues] =
- useState({
- name: '',
- description: '',
- config: 'configs/stable-diffusion/v1-inference.yaml',
- weights: '',
- vae: '',
- width: 512,
- height: 512,
- default: false,
- model_format: 'ckpt',
- });
+ const checkpointEditForm = useForm({
+ initialValues: {
+ name: retrievedModel.name,
+ base_model: retrievedModel.base_model,
+ type: 'main',
+ path: retrievedModel.path,
+ description: retrievedModel.description,
+ model_format: 'checkpoint',
+ vae: retrievedModel.vae,
+ config: retrievedModel.config,
+ variant: retrievedModel.variant,
+ },
+ });
- useEffect(() => {
- if (modelToEdit) {
- setEditModelFormValues({
- name: modelToEdit,
- description: retrievedModel?.description,
- config: retrievedModel?.config,
- weights: retrievedModel?.weights,
- vae: retrievedModel?.vae,
- width: retrievedModel?.width,
- height: retrievedModel?.height,
- default: retrievedModel?.default,
- model_format: 'ckpt',
- });
- }
- }, [retrievedModel, modelToEdit]);
-
- const editModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
- dispatch(
- addNewModel({
- ...values,
- width: Number(values.width),
- height: Number(values.height),
- })
- );
+ const editModelFormSubmitHandler = (values) => {
+ console.log(values);
};
return modelToEdit ? (
-
-
- {modelToEdit}
-
-
+
+
+
+ {retrievedModel.name}
+
+
+ {MODEL_TYPE_MAP[retrievedModel.base_model]} Model
+
+
+
+
+
-
- {({ handleSubmit, errors, touched }) => (
-
-
- {/* Description */}
-
-
- {t('modelManager.description')}
-
-
-
- {!!errors.description && touched.description ? (
-
- {errors.description}
-
- ) : (
-
- {t('modelManager.descriptionValidationMsg')}
-
- )}
-
-
-
- {/* Config */}
-
-
- {t('modelManager.config')}
-
-
-
- {!!errors.config && touched.config ? (
- {errors.config}
- ) : (
-
- {t('modelManager.configValidationMsg')}
-
- )}
-
-
-
- {/* Weights */}
-
-
- {t('modelManager.modelLocation')}
-
-
-
- {!!errors.weights && touched.weights ? (
-
- {errors.weights}
-
- ) : (
-
- {t('modelManager.modelLocationValidationMsg')}
-
- )}
-
-
-
- {/* VAE */}
-
-
- {t('modelManager.vaeLocation')}
-
-
-
- {!!errors.vae && touched.vae ? (
- {errors.vae}
- ) : (
-
- {t('modelManager.vaeLocationValidationMsg')}
-
- )}
-
-
-
-
- {/* Width */}
-
-
- {t('modelManager.width')}
-
-
-
- {({
- field,
- form,
- }: {
- field: FieldInputProps;
- form: FormikProps;
- }) => (
-
- form.setFieldValue(field.name, Number(value))
- }
- />
- )}
-
-
- {!!errors.width && touched.width ? (
-
- {errors.width}
-
- ) : (
-
- {t('modelManager.widthValidationMsg')}
-
- )}
-
-
-
- {/* Height */}
-
-
- {t('modelManager.height')}
-
-
-
- {({
- field,
- form,
- }: {
- field: FieldInputProps;
- form: FormikProps;
- }) => (
-
- form.setFieldValue(field.name, Number(value))
- }
- />
- )}
-
-
- {!!errors.height && touched.height ? (
-
- {errors.height}
-
- ) : (
-
- {t('modelManager.heightValidationMsg')}
-
- )}
-
-
-
-
-
- {t('modelManager.updateModel')}
-
-
-
+
+ >
+
+
+
+
+
+
+
+
+
+ {t('modelManager.updateModel')}
+
+
+
) : (
diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx
index 820ad546b3..56a668d0e2 100644
--- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx
+++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx
@@ -4,42 +4,28 @@ import {
Radio,
RadioGroup,
Text,
- UnorderedList,
Tooltip,
+ UnorderedList,
} from '@chakra-ui/react';
// import { convertToDiffusers } from 'app/socketio/actions';
-import { RootState } from 'app/store/store';
-import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { useAppDispatch } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
-import { useState, useEffect } from 'react';
+import { useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
+import { CheckpointModel } from './CheckpointModelEdit';
interface ModelConvertProps {
- model: string;
+ model: CheckpointModel;
}
export default function ModelConvert(props: ModelConvertProps) {
const { model } = props;
- const model_list = useAppSelector(
- (state: RootState) => state.system.model_list
- );
-
- const retrievedModel = model_list[model];
-
const dispatch = useAppDispatch();
const { t } = useTranslation();
- const isProcessing = useAppSelector(
- (state: RootState) => state.system.isProcessing
- );
-
- const isConnected = useAppSelector(
- (state: RootState) => state.system.isConnected
- );
-
const [saveLocation, setSaveLocation] = useState('same');
const [customSaveLocation, setCustomSaveLocation] = useState('');
@@ -65,7 +51,7 @@ export default function ModelConvert(props: ModelConvertProps) {
return (
diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx
index e1b3bbab1e..ab5fddd5ea 100644
--- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx
+++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx
@@ -1,5 +1,5 @@
import { DeleteIcon, EditIcon } from '@chakra-ui/icons';
-import { Box, Button, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react';
+import { Box, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react';
// import { deleteModel, requestModelChange } from 'app/socketio/actions';
import { RootState } from 'app/store/store';
@@ -30,10 +30,6 @@ export default function ModelListItem(props: ModelListItemProps) {
const { modelKey, name, description } = props;
- const handleChangeModel = () => {
- dispatch(requestModelChange(modelKey));
- };
-
const openModelHandler = () => {
dispatch(setOpenModel(modelKey));
};
@@ -43,17 +39,6 @@ export default function ModelListItem(props: ModelListItemProps) {
dispatch(setOpenModel(null));
};
- const statusTextColor = () => {
- switch (status) {
- case 'active':
- return 'ok.500';
- case 'cached':
- return 'warning.500';
- case 'not loaded':
- return 'inherit';
- }
- };
-
return (
-
-
}
size="sm"