mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: hook up model edit forms
This commit is contained in:
parent
e73f774920
commit
0bb668b8a8
@ -1,12 +1,9 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
|
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import IAIInput from 'common/components/IAIInput';
|
import IAIInput from 'common/components/IAIInput';
|
||||||
import IAINumberInput from 'common/components/IAINumberInput';
|
import IAINumberInput from 'common/components/IAINumberInput';
|
||||||
import { useEffect, useState } from 'react';
|
import { useEffect, useState } from 'react';
|
||||||
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
|
||||||
|
|
||||||
import {
|
import {
|
||||||
Flex,
|
Flex,
|
||||||
@ -21,40 +18,29 @@ import {
|
|||||||
import { Field, Formik } from 'formik';
|
import { Field, Formik } from 'formik';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
import type { InvokeModelConfigProps } from 'app/types/invokeai';
|
|
||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import type { FieldInputProps, FormikProps } from 'formik';
|
import type { InvokeModelConfigProps } from 'app/types/invokeai';
|
||||||
import { isEqual, pickBy } from 'lodash-es';
|
|
||||||
import ModelConvert from './ModelConvert';
|
|
||||||
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
|
|
||||||
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
|
|
||||||
import IAIForm from 'common/components/IAIForm';
|
import IAIForm from 'common/components/IAIForm';
|
||||||
|
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
|
||||||
const selector = createSelector(
|
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
|
||||||
[systemSelector],
|
import type { FieldInputProps, FormikProps } from 'formik';
|
||||||
(system) => {
|
import ModelConvert from './ModelConvert';
|
||||||
const { openModel, model_list } = system;
|
|
||||||
return {
|
|
||||||
model_list,
|
|
||||||
openModel,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const MIN_MODEL_SIZE = 64;
|
const MIN_MODEL_SIZE = 64;
|
||||||
const MAX_MODEL_SIZE = 2048;
|
const MAX_MODEL_SIZE = 2048;
|
||||||
|
|
||||||
export default function CheckpointModelEdit() {
|
type CheckpointModelEditProps = {
|
||||||
const { openModel, model_list } = useAppSelector(selector);
|
modelToEdit: string;
|
||||||
|
retrievedModel: any;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
|
||||||
const isProcessing = useAppSelector(
|
const isProcessing = useAppSelector(
|
||||||
(state: RootState) => state.system.isProcessing
|
(state: RootState) => state.system.isProcessing
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const { modelToEdit, retrievedModel } = props;
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
@ -69,27 +55,24 @@ export default function CheckpointModelEdit() {
|
|||||||
width: 512,
|
width: 512,
|
||||||
height: 512,
|
height: 512,
|
||||||
default: false,
|
default: false,
|
||||||
format: 'ckpt',
|
model_format: 'ckpt',
|
||||||
});
|
});
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (openModel) {
|
if (modelToEdit) {
|
||||||
const retrievedModel = pickBy(model_list, (_val, key) => {
|
|
||||||
return isEqual(key, openModel);
|
|
||||||
});
|
|
||||||
setEditModelFormValues({
|
setEditModelFormValues({
|
||||||
name: openModel,
|
name: modelToEdit,
|
||||||
description: retrievedModel[openModel]?.description,
|
description: retrievedModel?.description,
|
||||||
config: retrievedModel[openModel]?.config,
|
config: retrievedModel?.config,
|
||||||
weights: retrievedModel[openModel]?.weights,
|
weights: retrievedModel?.weights,
|
||||||
vae: retrievedModel[openModel]?.vae,
|
vae: retrievedModel?.vae,
|
||||||
width: retrievedModel[openModel]?.width,
|
width: retrievedModel?.width,
|
||||||
height: retrievedModel[openModel]?.height,
|
height: retrievedModel?.height,
|
||||||
default: retrievedModel[openModel]?.default,
|
default: retrievedModel?.default,
|
||||||
format: 'ckpt',
|
model_format: 'ckpt',
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}, [model_list, openModel]);
|
}, [retrievedModel, modelToEdit]);
|
||||||
|
|
||||||
const editModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
|
const editModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
|
||||||
dispatch(
|
dispatch(
|
||||||
@ -101,13 +84,13 @@ export default function CheckpointModelEdit() {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
return openModel ? (
|
return modelToEdit ? (
|
||||||
<Flex flexDirection="column" rowGap={4} width="100%">
|
<Flex flexDirection="column" rowGap={4} width="100%">
|
||||||
<Flex alignItems="center" gap={4} justifyContent="space-between">
|
<Flex alignItems="center" gap={4} justifyContent="space-between">
|
||||||
<Text fontSize="lg" fontWeight="bold">
|
<Text fontSize="lg" fontWeight="bold">
|
||||||
{openModel}
|
{modelToEdit}
|
||||||
</Text>
|
</Text>
|
||||||
<ModelConvert model={openModel} />
|
<ModelConvert model={modelToEdit} />
|
||||||
</Flex>
|
</Flex>
|
||||||
<Flex
|
<Flex
|
||||||
flexDirection="column"
|
flexDirection="column"
|
||||||
|
@ -1,11 +1,8 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
|
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import IAIInput from 'common/components/IAIInput';
|
import IAIInput from 'common/components/IAIInput';
|
||||||
import { useEffect, useState } from 'react';
|
import { useEffect, useState } from 'react';
|
||||||
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
|
||||||
|
|
||||||
import { Flex, FormControl, FormLabel, Text, VStack } from '@chakra-ui/react';
|
import { Flex, FormControl, FormLabel, Text, VStack } from '@chakra-ui/react';
|
||||||
|
|
||||||
@ -13,35 +10,24 @@ import { Flex, FormControl, FormLabel, Text, VStack } from '@chakra-ui/react';
|
|||||||
import { Field, Formik } from 'formik';
|
import { Field, Formik } from 'formik';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
import type { InvokeDiffusersModelConfigProps } from 'app/types/invokeai';
|
|
||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { isEqual, pickBy } from 'lodash-es';
|
import type { InvokeDiffusersModelConfigProps } from 'app/types/invokeai';
|
||||||
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
|
|
||||||
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
|
|
||||||
import IAIForm from 'common/components/IAIForm';
|
import IAIForm from 'common/components/IAIForm';
|
||||||
|
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
|
||||||
|
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
|
||||||
|
|
||||||
const selector = createSelector(
|
type DiffusersModelEditProps = {
|
||||||
[systemSelector],
|
modelToEdit: string;
|
||||||
(system) => {
|
retrievedModel: any;
|
||||||
const { openModel, model_list } = system;
|
};
|
||||||
return {
|
|
||||||
model_list,
|
|
||||||
openModel,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
export default function DiffusersModelEdit() {
|
export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
|
||||||
const { openModel, model_list } = useAppSelector(selector);
|
|
||||||
const isProcessing = useAppSelector(
|
const isProcessing = useAppSelector(
|
||||||
(state: RootState) => state.system.isProcessing
|
(state: RootState) => state.system.isProcessing
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const { retrievedModel, modelToEdit } = props;
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
@ -54,41 +40,31 @@ export default function DiffusersModelEdit() {
|
|||||||
path: '',
|
path: '',
|
||||||
vae: { repo_id: '', path: '' },
|
vae: { repo_id: '', path: '' },
|
||||||
default: false,
|
default: false,
|
||||||
format: 'diffusers',
|
model_format: 'diffusers',
|
||||||
});
|
});
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (openModel) {
|
setEditModelFormValues({
|
||||||
const retrievedModel = pickBy(model_list, (_val, key) => {
|
name: modelToEdit,
|
||||||
return isEqual(key, openModel);
|
description: retrievedModel?.description,
|
||||||
});
|
path:
|
||||||
|
retrievedModel?.path && retrievedModel?.path !== 'None'
|
||||||
setEditModelFormValues({
|
? retrievedModel?.path
|
||||||
name: openModel,
|
: '',
|
||||||
description: retrievedModel[openModel]?.description,
|
repo_id:
|
||||||
path:
|
retrievedModel?.repo_id && retrievedModel?.repo_id !== 'None'
|
||||||
retrievedModel[openModel]?.path &&
|
? retrievedModel?.repo_id
|
||||||
retrievedModel[openModel]?.path !== 'None'
|
: '',
|
||||||
? retrievedModel[openModel]?.path
|
vae: {
|
||||||
: '',
|
repo_id: retrievedModel?.vae?.repo_id
|
||||||
repo_id:
|
? retrievedModel?.vae?.repo_id
|
||||||
retrievedModel[openModel]?.repo_id &&
|
: '',
|
||||||
retrievedModel[openModel]?.repo_id !== 'None'
|
path: retrievedModel?.vae?.path ? retrievedModel?.vae?.path : '',
|
||||||
? retrievedModel[openModel]?.repo_id
|
},
|
||||||
: '',
|
default: retrievedModel?.default,
|
||||||
vae: {
|
model_format: 'diffusers',
|
||||||
repo_id: retrievedModel[openModel]?.vae?.repo_id
|
});
|
||||||
? retrievedModel[openModel]?.vae?.repo_id
|
}, [retrievedModel, modelToEdit]);
|
||||||
: '',
|
|
||||||
path: retrievedModel[openModel]?.vae?.path
|
|
||||||
? retrievedModel[openModel]?.vae?.path
|
|
||||||
: '',
|
|
||||||
},
|
|
||||||
default: retrievedModel[openModel]?.default,
|
|
||||||
format: 'diffusers',
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}, [model_list, openModel]);
|
|
||||||
|
|
||||||
const editModelFormSubmitHandler = (
|
const editModelFormSubmitHandler = (
|
||||||
values: InvokeDiffusersModelConfigProps
|
values: InvokeDiffusersModelConfigProps
|
||||||
@ -103,11 +79,11 @@ export default function DiffusersModelEdit() {
|
|||||||
dispatch(addNewModel(values));
|
dispatch(addNewModel(values));
|
||||||
};
|
};
|
||||||
|
|
||||||
return openModel ? (
|
return modelToEdit ? (
|
||||||
<Flex flexDirection="column" rowGap={4} width="100%">
|
<Flex flexDirection="column" rowGap={4} width="100%">
|
||||||
<Flex alignItems="center">
|
<Flex alignItems="center">
|
||||||
<Text fontSize="lg" fontWeight="bold">
|
<Text fontSize="lg" fontWeight="bold">
|
||||||
{openModel}
|
{retrievedModel.name}
|
||||||
</Text>
|
</Text>
|
||||||
</Flex>
|
</Flex>
|
||||||
<Flex flexDirection="column" overflowY="scroll" paddingInlineEnd={8}>
|
<Flex flexDirection="column" overflowY="scroll" paddingInlineEnd={8}>
|
||||||
|
@ -45,6 +45,26 @@ export default function ModelManagerModal({
|
|||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const renderModelEditTabs = () => {
|
||||||
|
if (!openModel || !pipelineModels) return;
|
||||||
|
|
||||||
|
if (pipelineModels['entities'][openModel]['model_format'] === 'diffusers') {
|
||||||
|
return (
|
||||||
|
<DiffusersModelEdit
|
||||||
|
modelToEdit={openModel}
|
||||||
|
retrievedModel={pipelineModels['entities'][openModel]}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
return (
|
||||||
|
<CheckpointModelEdit
|
||||||
|
modelToEdit={openModel}
|
||||||
|
retrievedModel={pipelineModels['entities'][openModel]}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
{cloneElement(children, {
|
{cloneElement(children, {
|
||||||
@ -62,14 +82,7 @@ export default function ModelManagerModal({
|
|||||||
<ModalBody>
|
<ModalBody>
|
||||||
<Flex width="100%" columnGap={8}>
|
<Flex width="100%" columnGap={8}>
|
||||||
<ModelList />
|
<ModelList />
|
||||||
{openModel &&
|
{renderModelEditTabs()}
|
||||||
pipelineModels &&
|
|
||||||
pipelineModels['entities'][openModel]['model_format'] ===
|
|
||||||
'diffusers' ? (
|
|
||||||
<DiffusersModelEdit />
|
|
||||||
) : (
|
|
||||||
<CheckpointModelEdit />
|
|
||||||
)}
|
|
||||||
</Flex>
|
</Flex>
|
||||||
</ModalBody>
|
</ModalBody>
|
||||||
<ModalFooter />
|
<ModalFooter />
|
||||||
|
Loading…
Reference in New Issue
Block a user