feat: hook up model edit forms

This commit is contained in:
blessedcoolant 2023-06-26 18:40:35 +12:00 committed by psychedelicious
parent e73f774920
commit 0bb668b8a8
3 changed files with 83 additions and 111 deletions

View File

@ -1,12 +1,9 @@
import { createSelector } from '@reduxjs/toolkit';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import IAIInput from 'common/components/IAIInput';
import IAINumberInput from 'common/components/IAINumberInput'; import IAINumberInput from 'common/components/IAINumberInput';
import { useEffect, useState } from 'react'; import { useEffect, useState } from 'react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors';
import { import {
Flex, Flex,
@ -21,40 +18,29 @@ import {
import { Field, Formik } from 'formik'; import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { InvokeModelConfigProps } from 'app/types/invokeai';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import type { FieldInputProps, FormikProps } from 'formik'; import type { InvokeModelConfigProps } from 'app/types/invokeai';
import { isEqual, pickBy } from 'lodash-es';
import ModelConvert from './ModelConvert';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import IAIForm from 'common/components/IAIForm'; import IAIForm from 'common/components/IAIForm';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
const selector = createSelector( import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
[systemSelector], import type { FieldInputProps, FormikProps } from 'formik';
(system) => { import ModelConvert from './ModelConvert';
const { openModel, model_list } = system;
return {
model_list,
openModel,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const MIN_MODEL_SIZE = 64; const MIN_MODEL_SIZE = 64;
const MAX_MODEL_SIZE = 2048; const MAX_MODEL_SIZE = 2048;
export default function CheckpointModelEdit() { type CheckpointModelEditProps = {
const { openModel, model_list } = useAppSelector(selector); modelToEdit: string;
retrievedModel: any;
};
export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
const isProcessing = useAppSelector( const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing (state: RootState) => state.system.isProcessing
); );
const { modelToEdit, retrievedModel } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -69,27 +55,24 @@ export default function CheckpointModelEdit() {
width: 512, width: 512,
height: 512, height: 512,
default: false, default: false,
format: 'ckpt', model_format: 'ckpt',
}); });
useEffect(() => { useEffect(() => {
if (openModel) { if (modelToEdit) {
const retrievedModel = pickBy(model_list, (_val, key) => {
return isEqual(key, openModel);
});
setEditModelFormValues({ setEditModelFormValues({
name: openModel, name: modelToEdit,
description: retrievedModel[openModel]?.description, description: retrievedModel?.description,
config: retrievedModel[openModel]?.config, config: retrievedModel?.config,
weights: retrievedModel[openModel]?.weights, weights: retrievedModel?.weights,
vae: retrievedModel[openModel]?.vae, vae: retrievedModel?.vae,
width: retrievedModel[openModel]?.width, width: retrievedModel?.width,
height: retrievedModel[openModel]?.height, height: retrievedModel?.height,
default: retrievedModel[openModel]?.default, default: retrievedModel?.default,
format: 'ckpt', model_format: 'ckpt',
}); });
} }
}, [model_list, openModel]); }, [retrievedModel, modelToEdit]);
const editModelFormSubmitHandler = (values: InvokeModelConfigProps) => { const editModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
dispatch( dispatch(
@ -101,13 +84,13 @@ export default function CheckpointModelEdit() {
); );
}; };
return openModel ? ( return modelToEdit ? (
<Flex flexDirection="column" rowGap={4} width="100%"> <Flex flexDirection="column" rowGap={4} width="100%">
<Flex alignItems="center" gap={4} justifyContent="space-between"> <Flex alignItems="center" gap={4} justifyContent="space-between">
<Text fontSize="lg" fontWeight="bold"> <Text fontSize="lg" fontWeight="bold">
{openModel} {modelToEdit}
</Text> </Text>
<ModelConvert model={openModel} /> <ModelConvert model={modelToEdit} />
</Flex> </Flex>
<Flex <Flex
flexDirection="column" flexDirection="column"

View File

@ -1,11 +1,8 @@
import { createSelector } from '@reduxjs/toolkit';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import IAIInput from 'common/components/IAIInput';
import { useEffect, useState } from 'react'; import { useEffect, useState } from 'react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors';
import { Flex, FormControl, FormLabel, Text, VStack } from '@chakra-ui/react'; import { Flex, FormControl, FormLabel, Text, VStack } from '@chakra-ui/react';
@ -13,35 +10,24 @@ import { Flex, FormControl, FormLabel, Text, VStack } from '@chakra-ui/react';
import { Field, Formik } from 'formik'; import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { InvokeDiffusersModelConfigProps } from 'app/types/invokeai';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { isEqual, pickBy } from 'lodash-es'; import type { InvokeDiffusersModelConfigProps } from 'app/types/invokeai';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import IAIForm from 'common/components/IAIForm'; import IAIForm from 'common/components/IAIForm';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
const selector = createSelector( type DiffusersModelEditProps = {
[systemSelector], modelToEdit: string;
(system) => { retrievedModel: any;
const { openModel, model_list } = system; };
return {
model_list,
openModel,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
export default function DiffusersModelEdit() { export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
const { openModel, model_list } = useAppSelector(selector);
const isProcessing = useAppSelector( const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing (state: RootState) => state.system.isProcessing
); );
const { retrievedModel, modelToEdit } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -54,41 +40,31 @@ export default function DiffusersModelEdit() {
path: '', path: '',
vae: { repo_id: '', path: '' }, vae: { repo_id: '', path: '' },
default: false, default: false,
format: 'diffusers', model_format: 'diffusers',
}); });
useEffect(() => { useEffect(() => {
if (openModel) { setEditModelFormValues({
const retrievedModel = pickBy(model_list, (_val, key) => { name: modelToEdit,
return isEqual(key, openModel); description: retrievedModel?.description,
}); path:
retrievedModel?.path && retrievedModel?.path !== 'None'
setEditModelFormValues({ ? retrievedModel?.path
name: openModel, : '',
description: retrievedModel[openModel]?.description, repo_id:
path: retrievedModel?.repo_id && retrievedModel?.repo_id !== 'None'
retrievedModel[openModel]?.path && ? retrievedModel?.repo_id
retrievedModel[openModel]?.path !== 'None' : '',
? retrievedModel[openModel]?.path vae: {
: '', repo_id: retrievedModel?.vae?.repo_id
repo_id: ? retrievedModel?.vae?.repo_id
retrievedModel[openModel]?.repo_id && : '',
retrievedModel[openModel]?.repo_id !== 'None' path: retrievedModel?.vae?.path ? retrievedModel?.vae?.path : '',
? retrievedModel[openModel]?.repo_id },
: '', default: retrievedModel?.default,
vae: { model_format: 'diffusers',
repo_id: retrievedModel[openModel]?.vae?.repo_id });
? retrievedModel[openModel]?.vae?.repo_id }, [retrievedModel, modelToEdit]);
: '',
path: retrievedModel[openModel]?.vae?.path
? retrievedModel[openModel]?.vae?.path
: '',
},
default: retrievedModel[openModel]?.default,
format: 'diffusers',
});
}
}, [model_list, openModel]);
const editModelFormSubmitHandler = ( const editModelFormSubmitHandler = (
values: InvokeDiffusersModelConfigProps values: InvokeDiffusersModelConfigProps
@ -103,11 +79,11 @@ export default function DiffusersModelEdit() {
dispatch(addNewModel(values)); dispatch(addNewModel(values));
}; };
return openModel ? ( return modelToEdit ? (
<Flex flexDirection="column" rowGap={4} width="100%"> <Flex flexDirection="column" rowGap={4} width="100%">
<Flex alignItems="center"> <Flex alignItems="center">
<Text fontSize="lg" fontWeight="bold"> <Text fontSize="lg" fontWeight="bold">
{openModel} {retrievedModel.name}
</Text> </Text>
</Flex> </Flex>
<Flex flexDirection="column" overflowY="scroll" paddingInlineEnd={8}> <Flex flexDirection="column" overflowY="scroll" paddingInlineEnd={8}>

View File

@ -45,6 +45,26 @@ export default function ModelManagerModal({
const { t } = useTranslation(); const { t } = useTranslation();
const renderModelEditTabs = () => {
if (!openModel || !pipelineModels) return;
if (pipelineModels['entities'][openModel]['model_format'] === 'diffusers') {
return (
<DiffusersModelEdit
modelToEdit={openModel}
retrievedModel={pipelineModels['entities'][openModel]}
/>
);
} else {
return (
<CheckpointModelEdit
modelToEdit={openModel}
retrievedModel={pipelineModels['entities'][openModel]}
/>
);
}
};
return ( return (
<> <>
{cloneElement(children, { {cloneElement(children, {
@ -62,14 +82,7 @@ export default function ModelManagerModal({
<ModalBody> <ModalBody>
<Flex width="100%" columnGap={8}> <Flex width="100%" columnGap={8}>
<ModelList /> <ModelList />
{openModel && {renderModelEditTabs()}
pipelineModels &&
pipelineModels['entities'][openModel]['model_format'] ===
'diffusers' ? (
<DiffusersModelEdit />
) : (
<CheckpointModelEdit />
)}
</Flex> </Flex>
</ModalBody> </ModalBody>
<ModalFooter /> <ModalFooter />