feat: hook up model edit forms

This commit is contained in:
blessedcoolant 2023-06-26 18:40:35 +12:00 committed by psychedelicious
parent e73f774920
commit 0bb668b8a8
3 changed files with 83 additions and 111 deletions

View File

@ -1,12 +1,9 @@
import { createSelector } from '@reduxjs/toolkit';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAINumberInput from 'common/components/IAINumberInput';
import { useEffect, useState } from 'react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors';
import {
Flex,
@ -21,40 +18,29 @@ import {
import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next';
import type { InvokeModelConfigProps } from 'app/types/invokeai';
import type { RootState } from 'app/store/store';
import type { FieldInputProps, FormikProps } from 'formik';
import { isEqual, pickBy } from 'lodash-es';
import ModelConvert from './ModelConvert';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import type { InvokeModelConfigProps } from 'app/types/invokeai';
import IAIForm from 'common/components/IAIForm';
const selector = createSelector(
[systemSelector],
(system) => {
const { openModel, model_list } = system;
return {
model_list,
openModel,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
import type { FieldInputProps, FormikProps } from 'formik';
import ModelConvert from './ModelConvert';
const MIN_MODEL_SIZE = 64;
const MAX_MODEL_SIZE = 2048;
export default function CheckpointModelEdit() {
const { openModel, model_list } = useAppSelector(selector);
type CheckpointModelEditProps = {
modelToEdit: string;
retrievedModel: any;
};
export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const { modelToEdit, retrievedModel } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
@ -69,27 +55,24 @@ export default function CheckpointModelEdit() {
width: 512,
height: 512,
default: false,
format: 'ckpt',
model_format: 'ckpt',
});
useEffect(() => {
if (openModel) {
const retrievedModel = pickBy(model_list, (_val, key) => {
return isEqual(key, openModel);
});
if (modelToEdit) {
setEditModelFormValues({
name: openModel,
description: retrievedModel[openModel]?.description,
config: retrievedModel[openModel]?.config,
weights: retrievedModel[openModel]?.weights,
vae: retrievedModel[openModel]?.vae,
width: retrievedModel[openModel]?.width,
height: retrievedModel[openModel]?.height,
default: retrievedModel[openModel]?.default,
format: 'ckpt',
name: modelToEdit,
description: retrievedModel?.description,
config: retrievedModel?.config,
weights: retrievedModel?.weights,
vae: retrievedModel?.vae,
width: retrievedModel?.width,
height: retrievedModel?.height,
default: retrievedModel?.default,
model_format: 'ckpt',
});
}
}, [model_list, openModel]);
}, [retrievedModel, modelToEdit]);
const editModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
dispatch(
@ -101,13 +84,13 @@ export default function CheckpointModelEdit() {
);
};
return openModel ? (
return modelToEdit ? (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex alignItems="center" gap={4} justifyContent="space-between">
<Text fontSize="lg" fontWeight="bold">
{openModel}
{modelToEdit}
</Text>
<ModelConvert model={openModel} />
<ModelConvert model={modelToEdit} />
</Flex>
<Flex
flexDirection="column"

View File

@ -1,11 +1,8 @@
import { createSelector } from '@reduxjs/toolkit';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import { useEffect, useState } from 'react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors';
import { Flex, FormControl, FormLabel, Text, VStack } from '@chakra-ui/react';
@ -13,35 +10,24 @@ import { Flex, FormControl, FormLabel, Text, VStack } from '@chakra-ui/react';
import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next';
import type { InvokeDiffusersModelConfigProps } from 'app/types/invokeai';
import type { RootState } from 'app/store/store';
import { isEqual, pickBy } from 'lodash-es';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import type { InvokeDiffusersModelConfigProps } from 'app/types/invokeai';
import IAIForm from 'common/components/IAIForm';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
const selector = createSelector(
[systemSelector],
(system) => {
const { openModel, model_list } = system;
return {
model_list,
openModel,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
type DiffusersModelEditProps = {
modelToEdit: string;
retrievedModel: any;
};
export default function DiffusersModelEdit() {
const { openModel, model_list } = useAppSelector(selector);
export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const { retrievedModel, modelToEdit } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
@ -54,41 +40,31 @@ export default function DiffusersModelEdit() {
path: '',
vae: { repo_id: '', path: '' },
default: false,
format: 'diffusers',
model_format: 'diffusers',
});
useEffect(() => {
if (openModel) {
const retrievedModel = pickBy(model_list, (_val, key) => {
return isEqual(key, openModel);
});
setEditModelFormValues({
name: openModel,
description: retrievedModel[openModel]?.description,
path:
retrievedModel[openModel]?.path &&
retrievedModel[openModel]?.path !== 'None'
? retrievedModel[openModel]?.path
: '',
repo_id:
retrievedModel[openModel]?.repo_id &&
retrievedModel[openModel]?.repo_id !== 'None'
? retrievedModel[openModel]?.repo_id
: '',
vae: {
repo_id: retrievedModel[openModel]?.vae?.repo_id
? retrievedModel[openModel]?.vae?.repo_id
: '',
path: retrievedModel[openModel]?.vae?.path
? retrievedModel[openModel]?.vae?.path
: '',
},
default: retrievedModel[openModel]?.default,
format: 'diffusers',
});
}
}, [model_list, openModel]);
setEditModelFormValues({
name: modelToEdit,
description: retrievedModel?.description,
path:
retrievedModel?.path && retrievedModel?.path !== 'None'
? retrievedModel?.path
: '',
repo_id:
retrievedModel?.repo_id && retrievedModel?.repo_id !== 'None'
? retrievedModel?.repo_id
: '',
vae: {
repo_id: retrievedModel?.vae?.repo_id
? retrievedModel?.vae?.repo_id
: '',
path: retrievedModel?.vae?.path ? retrievedModel?.vae?.path : '',
},
default: retrievedModel?.default,
model_format: 'diffusers',
});
}, [retrievedModel, modelToEdit]);
const editModelFormSubmitHandler = (
values: InvokeDiffusersModelConfigProps
@ -103,11 +79,11 @@ export default function DiffusersModelEdit() {
dispatch(addNewModel(values));
};
return openModel ? (
return modelToEdit ? (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex alignItems="center">
<Text fontSize="lg" fontWeight="bold">
{openModel}
{retrievedModel.name}
</Text>
</Flex>
<Flex flexDirection="column" overflowY="scroll" paddingInlineEnd={8}>

View File

@ -45,6 +45,26 @@ export default function ModelManagerModal({
const { t } = useTranslation();
const renderModelEditTabs = () => {
if (!openModel || !pipelineModels) return;
if (pipelineModels['entities'][openModel]['model_format'] === 'diffusers') {
return (
<DiffusersModelEdit
modelToEdit={openModel}
retrievedModel={pipelineModels['entities'][openModel]}
/>
);
} else {
return (
<CheckpointModelEdit
modelToEdit={openModel}
retrievedModel={pipelineModels['entities'][openModel]}
/>
);
}
};
return (
<>
{cloneElement(children, {
@ -62,14 +82,7 @@ export default function ModelManagerModal({
<ModalBody>
<Flex width="100%" columnGap={8}>
<ModelList />
{openModel &&
pipelineModels &&
pipelineModels['entities'][openModel]['model_format'] ===
'diffusers' ? (
<DiffusersModelEdit />
) : (
<CheckpointModelEdit />
)}
{renderModelEditTabs()}
</Flex>
</ModalBody>
<ModalFooter />