feat(ui): partial rebuild of model manager internal logic

This commit is contained in:
psychedelicious
2023-12-29 20:43:20 +11:00
committed by Kent Keirsey
parent 2a661450c3
commit 52f9749bf5
16 changed files with 457 additions and 317 deletions

View File

@ -1,5 +1,4 @@
import { Flex } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { useAppDispatch } from 'app/store/storeHooks';
import { InvButton } from 'common/components/InvButton/InvButton';
import { InvCheckbox } from 'common/components/InvCheckbox/wrapper';
@ -13,6 +12,8 @@ import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import type { CSSProperties, FocusEventHandler } from 'react';
import { memo, useCallback, useState } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
import type { CheckpointModelConfig } from 'services/api/types';
@ -28,8 +29,16 @@ const AdvancedAddCheckpoint = (props: AdvancedAddCheckpointProps) => {
const dispatch = useAppDispatch();
const { model_path } = props;
const advancedAddCheckpointForm = useForm<CheckpointModelConfig>({
initialValues: {
const {
register,
handleSubmit,
control,
getValues,
setValue,
formState: { errors },
reset,
} = useForm<CheckpointModelConfig>({
defaultValues: {
model_name: model_path ? getModelName(model_path) : '',
base_model: 'sd-1',
model_type: 'main',
@ -41,64 +50,64 @@ const AdvancedAddCheckpoint = (props: AdvancedAddCheckpointProps) => {
variant: 'normal',
config: 'configs\\stable-diffusion\\v1-inference.yaml',
},
mode: 'onChange',
});
const [addMainModel] = useAddMainModelsMutation();
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
const advancedAddCheckpointFormHandler = (values: CheckpointModelConfig) => {
addMainModel({
body: values,
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelAdded', {
modelName: values.model_name,
}),
status: 'success',
})
)
);
advancedAddCheckpointForm.reset();
// Close Advanced Panel in Scan Models tab
if (model_path) {
dispatch(setAdvancedAddScanModel(null));
}
const onSubmit = useCallback<SubmitHandler<CheckpointModelConfig>>(
(values) => {
addMainModel({
body: values,
})
.catch((error) => {
if (error) {
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddFailed'),
status: 'error',
title: t('modelManager.modelAdded', {
modelName: values.model_name,
}),
status: 'success',
})
)
);
}
});
};
reset();
const handleBlurModelLocation: FocusEventHandler<HTMLInputElement> =
useCallback(
(e) => {
if (advancedAddCheckpointForm.values['model_name'] === '') {
const modelName = getModelName(e.currentTarget.value);
if (modelName) {
advancedAddCheckpointForm.setFieldValue(
'model_name',
modelName as string
// Close Advanced Panel in Scan Models tab
if (model_path) {
dispatch(setAdvancedAddScanModel(null));
}
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddFailed'),
status: 'error',
})
)
);
}
});
},
[addMainModel, dispatch, model_path, reset, t]
);
const onBlur: FocusEventHandler<HTMLInputElement> = useCallback(
(e) => {
if (getValues().model_name === '') {
const modelName = getModelName(e.currentTarget.value);
if (modelName) {
setValue('model_name', modelName as string);
}
},
[advancedAddCheckpointForm]
);
}
},
[getValues, setValue]
);
const handleChangeUseCustomConfig = useCallback(
() => setUseCustomConfig((prev) => !prev),
@ -106,56 +115,53 @@ const AdvancedAddCheckpoint = (props: AdvancedAddCheckpointProps) => {
);
return (
<form
onSubmit={advancedAddCheckpointForm.onSubmit((v) =>
advancedAddCheckpointFormHandler(v)
)}
style={formStyles}
>
<form onSubmit={handleSubmit(onSubmit)} style={formStyles}>
<Flex flexDirection="column" gap={2}>
<InvControl label={t('modelManager.model')} isRequired>
<InvControl
label={t('modelManager.model')}
isInvalid={Boolean(errors.model_name)}
error={errors.model_name?.message}
>
<InvInput
{...advancedAddCheckpointForm.getInputProps('model_name')}
{...register('model_name', {
validate: (value) =>
value.trim().length > 3 || 'Must be at least 3 characters',
})}
/>
</InvControl>
<InvControl label={t('modelManager.baseModel')}>
<BaseModelSelect
{...advancedAddCheckpointForm.getInputProps('base_model')}
/>
</InvControl>
<InvControl label={t('modelManager.modelLocation')} isRequired>
<BaseModelSelect<CheckpointModelConfig>
control={control}
name="base_model"
/>
<InvControl
label={t('modelManager.modelLocation')}
isInvalid={Boolean(errors.path)}
error={errors.path?.message}
>
<InvInput
{...advancedAddCheckpointForm.getInputProps('path')}
onBlur={handleBlurModelLocation}
{...register('path', {
validate: (value) =>
value.trim().length > 0 || 'Must provide a path',
onBlur,
})}
/>
</InvControl>
<InvControl label={t('modelManager.description')}>
<InvInput
{...advancedAddCheckpointForm.getInputProps('description')}
/>
<InvInput {...register('description')} />
</InvControl>
<InvControl label={t('modelManager.vaeLocation')}>
<InvInput {...advancedAddCheckpointForm.getInputProps('vae')} />
</InvControl>
<InvControl label={t('modelManager.variant')}>
<ModelVariantSelect
{...advancedAddCheckpointForm.getInputProps('variant')}
/>
<InvInput {...register('vae')} />
</InvControl>
<ModelVariantSelect<CheckpointModelConfig>
control={control}
name="variant"
/>
<Flex flexDirection="column" width="100%" gap={2}>
{!useCustomConfig ? (
<CheckpointConfigsSelect
required
{...advancedAddCheckpointForm.getInputProps('config')}
/>
<CheckpointConfigsSelect control={control} name="config" />
) : (
<InvControl
label={t('modelManager.customConfigFileLocation')}
isRequired
>
<InvInput
{...advancedAddCheckpointForm.getInputProps('config')}
/>
<InvControl isRequired label={t('modelManager.config')}>
<InvInput {...register('config')} />
</InvControl>
)}
<InvControl label={t('modelManager.useCustomConfig')}>

View File

@ -1,5 +1,4 @@
import { Flex } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { useAppDispatch } from 'app/store/storeHooks';
import { InvButton } from 'common/components/InvButton/InvButton';
import { InvControl } from 'common/components/InvControl/InvControl';
@ -11,6 +10,8 @@ import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import type { CSSProperties, FocusEventHandler } from 'react';
import { memo, useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
import type { DiffusersModelConfig } from 'services/api/types';
@ -28,8 +29,16 @@ const AdvancedAddDiffusers = (props: AdvancedAddDiffusersProps) => {
const [addMainModel] = useAddMainModelsMutation();
const advancedAddDiffusersForm = useForm<DiffusersModelConfig>({
initialValues: {
const {
register,
handleSubmit,
control,
getValues,
setValue,
formState: { errors },
reset,
} = useForm<DiffusersModelConfig>({
defaultValues: {
model_name: model_path ? getModelName(model_path, false) : '',
base_model: 'sd-1',
model_type: 'main',
@ -40,96 +49,104 @@ const AdvancedAddDiffusers = (props: AdvancedAddDiffusersProps) => {
vae: '',
variant: 'normal',
},
mode: 'onChange',
});
const advancedAddDiffusersFormHandler = (values: DiffusersModelConfig) => {
addMainModel({
body: values,
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelAdded', {
modelName: values.model_name,
}),
status: 'success',
})
)
);
advancedAddDiffusersForm.reset();
// Close Advanced Panel in Scan Models tab
if (model_path) {
dispatch(setAdvancedAddScanModel(null));
}
const onSubmit = useCallback<SubmitHandler<DiffusersModelConfig>>(
(values) => {
addMainModel({
body: values,
})
.catch((error) => {
if (error) {
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddFailed'),
status: 'error',
title: t('modelManager.modelAdded', {
modelName: values.model_name,
}),
status: 'success',
})
)
);
}
});
};
const handleBlurModelLocation: FocusEventHandler<HTMLInputElement> =
useCallback(
(e) => {
if (advancedAddDiffusersForm.values['model_name'] === '') {
const modelName = getModelName(e.currentTarget.value, false);
if (modelName) {
advancedAddDiffusersForm.setFieldValue(
'model_name',
modelName as string
reset();
// Close Advanced Panel in Scan Models tab
if (model_path) {
dispatch(setAdvancedAddScanModel(null));
}
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddFailed'),
status: 'error',
})
)
);
}
});
},
[addMainModel, dispatch, model_path, reset, t]
);
const onBlur: FocusEventHandler<HTMLInputElement> = useCallback(
(e) => {
if (getValues().model_name === '') {
const modelName = getModelName(e.currentTarget.value, false);
if (modelName) {
setValue('model_name', modelName as string);
}
},
[advancedAddDiffusersForm]
);
}
},
[getValues, setValue]
);
return (
<form
onSubmit={advancedAddDiffusersForm.onSubmit((v) =>
advancedAddDiffusersFormHandler(v)
)}
style={formStyles}
>
<form onSubmit={handleSubmit(onSubmit)} style={formStyles}>
<Flex flexDirection="column" gap={2}>
<InvControl isRequired label={t('modelManager.model')}>
<InvInput {...advancedAddDiffusersForm.getInputProps('model_name')} />
</InvControl>
<InvControl label={t('modelManager.baseModel')}>
<BaseModelSelect
{...advancedAddDiffusersForm.getInputProps('base_model')}
<InvControl
label={t('modelManager.name')}
isInvalid={Boolean(errors.model_name)}
error={errors.model_name?.message}
>
<InvInput
{...register('model_name', {
validate: (value) =>
value.trim().length > 3 || 'Must be at least 3 characters',
})}
/>
</InvControl>
<InvControl isRequired label={t('modelManager.modelLocation')}>
<InvControl label={t('modelManager.baseModel')}>
<BaseModelSelect<DiffusersModelConfig>
control={control}
name="base_model"
/>
</InvControl>
<InvControl
label={t('modelManager.modelLocation')}
isInvalid={Boolean(errors.path)}
error={errors.path?.message}
>
<InvInput
placeholder={t('modelManager.modelLocationValidationMsg')}
{...advancedAddDiffusersForm.getInputProps('path')}
onBlur={handleBlurModelLocation}
{...register('path', {
validate: (value) =>
value.trim().length > 0 || 'Must provide a path',
onBlur,
})}
/>
</InvControl>
<InvControl label={t('modelManager.description')}>
<InvInput
{...advancedAddDiffusersForm.getInputProps('description')}
/>
<InvInput {...register('description')} />
</InvControl>
<InvControl label={t('modelManager.vaeLocation')}>
<InvInput {...advancedAddDiffusersForm.getInputProps('vae')} />
</InvControl>
<InvControl label={t('modelManager.variant')}>
<ModelVariantSelect
{...advancedAddDiffusersForm.getInputProps('variant')}
/>
<InvInput {...register('vae')} />
</InvControl>
<ModelVariantSelect<DiffusersModelConfig>
control={control}
name="variant"
/>
<InvButton mt={2} type="submit">
{t('modelManager.addModel')}

View File

@ -7,15 +7,10 @@ import SearchFolderForm from './SearchFolderForm';
const ScanModels = () => {
return (
<Flex flexDirection="column" w="100%" gap={4}>
<Flex flexDirection="column" w="100%" h="full" gap={4}>
<SearchFolderForm />
<Flex gap={4}>
<Flex
maxHeight="calc(100vh - 300px)"
overflow="scroll"
gap={4}
w="100%"
>
<Flex overflow="scroll" gap={4} w="100%" h="full">
<FoundModelsList />
</Flex>
<ScanAdvancedAddModels />

View File

@ -17,7 +17,7 @@ const ImportModelsPanel = () => {
const handleClickScanTab = useCallback(() => setAddModelTab('scan'), []);
return (
<Flex flexDirection="column" gap={4}>
<Flex flexDirection="column" gap={4} h="full">
<InvButtonGroup>
<InvButton
onClick={handleClickAddTab}

View File

@ -1,5 +1,4 @@
import { Badge, Divider, Flex } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { useAppDispatch } from 'app/store/storeHooks';
import { InvButton } from 'common/components/InvButton/InvButton';
import { InvCheckbox } from 'common/components/InvCheckbox/wrapper';
@ -13,6 +12,8 @@ import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback, useEffect, useState } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { CheckpointModelConfigEntity } from 'services/api/endpoints/models';
import {
@ -44,8 +45,14 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const checkpointEditForm = useForm<CheckpointModelConfig>({
initialValues: {
const {
register,
handleSubmit,
control,
formState: { errors },
reset,
} = useForm<CheckpointModelConfig>({
defaultValues: {
model_name: model.model_name ? model.model_name : '',
base_model: model.base_model,
model_type: 'main',
@ -56,10 +63,7 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
config: model.config ? model.config : '',
variant: model.variant,
},
validate: {
path: (value) =>
value.trim().length === 0 ? 'Must provide a path' : null,
},
mode: 'onChange',
});
const handleChangeUseCustomConfig = useCallback(
@ -67,8 +71,8 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
[]
);
const editModelFormSubmitHandler = useCallback(
(values: CheckpointModelConfig) => {
const onSubmit = useCallback<SubmitHandler<CheckpointModelConfig>>(
(values) => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
@ -77,7 +81,7 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
updateMainModel(responseBody)
.unwrap()
.then((payload) => {
checkpointEditForm.setValues(payload as CheckpointModelConfig);
reset(payload as CheckpointModelConfig, { keepDefaultValues: true });
dispatch(
addToast(
makeToast({
@ -88,7 +92,7 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
);
})
.catch((_) => {
checkpointEditForm.reset();
reset();
dispatch(
addToast(
makeToast({
@ -99,14 +103,7 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
);
});
},
[
checkpointEditForm,
dispatch,
model.base_model,
model.model_name,
t,
updateMainModel,
]
[dispatch, model.base_model, model.model_name, reset, t, updateMainModel]
);
return (
@ -135,42 +132,53 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
maxHeight={window.innerHeight - 270}
overflowY="scroll"
>
<form
onSubmit={checkpointEditForm.onSubmit((values) =>
editModelFormSubmitHandler(values)
)}
>
<form onSubmit={handleSubmit(onSubmit)}>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<InvControl label={t('modelManager.name')}>
<InvInput {...checkpointEditForm.getInputProps('model_name')} />
<InvControl
label={t('modelManager.name')}
isInvalid={Boolean(errors.model_name)}
error={errors.model_name?.message}
>
<InvInput
{...register('model_name', {
validate: (value) =>
value.trim().length > 3 || 'Must be at least 3 characters',
})}
/>
</InvControl>
<InvControl label={t('modelManager.description')}>
<InvInput {...checkpointEditForm.getInputProps('description')} />
<InvInput {...register('description')} />
</InvControl>
<BaseModelSelect
required
{...checkpointEditForm.getInputProps('base_model')}
<BaseModelSelect<CheckpointModelConfig>
control={control}
name="base_model"
/>
<ModelVariantSelect
required
{...checkpointEditForm.getInputProps('variant')}
<ModelVariantSelect<CheckpointModelConfig>
control={control}
name="variant"
/>
<InvControl isRequired label={t('modelManager.modelLocation')}>
<InvInput {...checkpointEditForm.getInputProps('path')} />
<InvControl
label={t('modelManager.modelLocation')}
isInvalid={Boolean(errors.path)}
error={errors.path?.message}
>
<InvInput
{...register('path', {
validate: (value) =>
value.trim().length > 0 || 'Must provide a path',
})}
/>
</InvControl>
<InvControl label={t('modelManager.vaeLocation')}>
<InvInput {...checkpointEditForm.getInputProps('vae')} />
<InvInput {...register('vae')} />
</InvControl>
<Flex flexDirection="column" gap={2}>
{!useCustomConfig ? (
<CheckpointConfigsSelect
required
{...checkpointEditForm.getInputProps('config')}
/>
<CheckpointConfigsSelect control={control} name="config" />
) : (
<InvControl isRequired label={t('modelManager.config')}>
<InvInput {...checkpointEditForm.getInputProps('config')} />
<InvInput {...register('config')} />
</InvControl>
)}
<InvControl label="Use Custom Config">

View File

@ -1,5 +1,4 @@
import { Divider, Flex } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { useAppDispatch } from 'app/store/storeHooks';
import { InvButton } from 'common/components/InvButton/InvButton';
import { InvControl } from 'common/components/InvControl/InvControl';
@ -11,6 +10,8 @@ import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { DiffusersModelConfigEntity } from 'services/api/endpoints/models';
import { useUpdateMainModelsMutation } from 'services/api/endpoints/models';
@ -28,8 +29,14 @@ const DiffusersModelEdit = (props: DiffusersModelEditProps) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const diffusersEditForm = useForm<DiffusersModelConfig>({
initialValues: {
const {
register,
handleSubmit,
control,
formState: { errors },
reset,
} = useForm<DiffusersModelConfig>({
defaultValues: {
model_name: model.model_name ? model.model_name : '',
base_model: model.base_model,
model_type: 'main',
@ -39,14 +46,11 @@ const DiffusersModelEdit = (props: DiffusersModelEditProps) => {
vae: model.vae ? model.vae : '',
variant: model.variant,
},
validate: {
path: (value) =>
value.trim().length === 0 ? 'Must provide a path' : null,
},
mode: 'onChange',
});
const editModelFormSubmitHandler = useCallback(
(values: DiffusersModelConfig) => {
const onSubmit = useCallback<SubmitHandler<DiffusersModelConfig>>(
(values) => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
@ -56,7 +60,7 @@ const DiffusersModelEdit = (props: DiffusersModelEditProps) => {
updateMainModel(responseBody)
.unwrap()
.then((payload) => {
diffusersEditForm.setValues(payload as DiffusersModelConfig);
reset(payload as DiffusersModelConfig, { keepDefaultValues: true });
dispatch(
addToast(
makeToast({
@ -67,7 +71,7 @@ const DiffusersModelEdit = (props: DiffusersModelEditProps) => {
);
})
.catch((_) => {
diffusersEditForm.reset();
reset();
dispatch(
addToast(
makeToast({
@ -78,14 +82,7 @@ const DiffusersModelEdit = (props: DiffusersModelEditProps) => {
);
});
},
[
diffusersEditForm,
dispatch,
model.base_model,
model.model_name,
t,
updateMainModel,
]
[dispatch, model.base_model, model.model_name, reset, t, updateMainModel]
);
return (
@ -100,31 +97,45 @@ const DiffusersModelEdit = (props: DiffusersModelEditProps) => {
</Flex>
<Divider />
<form
onSubmit={diffusersEditForm.onSubmit((values) =>
editModelFormSubmitHandler(values)
)}
>
<form onSubmit={handleSubmit(onSubmit)}>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<InvControl label={t('modelManager.name')}>
<InvInput {...diffusersEditForm.getInputProps('model_name')} />
<InvControl
label={t('modelManager.name')}
isInvalid={Boolean(errors.model_name)}
error={errors.model_name?.message}
>
<InvInput
{...register('model_name', {
validate: (value) =>
value.trim().length > 3 || 'Must be at least 3 characters',
})}
/>
</InvControl>
<InvControl label={t('modelManager.description')}>
<InvInput {...diffusersEditForm.getInputProps('description')} />
<InvInput {...register('description')} />
</InvControl>
<BaseModelSelect
required
{...diffusersEditForm.getInputProps('base_model')}
<BaseModelSelect<DiffusersModelConfig>
control={control}
name="base_model"
/>
<ModelVariantSelect
required
{...diffusersEditForm.getInputProps('variant')}
<ModelVariantSelect<DiffusersModelConfig>
control={control}
name="variant"
/>
<InvControl isRequired label={t('modelManager.modelLocation')}>
<InvInput {...diffusersEditForm.getInputProps('path')} />
<InvControl
label={t('modelManager.modelLocation')}
isInvalid={Boolean(errors.path)}
error={errors.path?.message}
>
<InvInput
{...register('path', {
validate: (value) =>
value.trim().length > 0 || 'Must provide a path',
})}
/>
</InvControl>
<InvControl label={t('modelManager.vaeLocation')}>
<InvInput {...diffusersEditForm.getInputProps('vae')} />
<InvInput {...register('vae')} />
</InvControl>
<InvButton type="submit" isLoading={isLoading}>
{t('modelManager.updateModel')}

View File

@ -1,5 +1,4 @@
import { Divider, Flex } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { useAppDispatch } from 'app/store/storeHooks';
import { InvButton } from 'common/components/InvButton/InvButton';
import { InvControl } from 'common/components/InvControl/InvControl';
@ -13,6 +12,8 @@ import {
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { LoRAModelConfigEntity } from 'services/api/endpoints/models';
import { useUpdateLoRAModelsMutation } from 'services/api/endpoints/models';
@ -30,8 +31,14 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const loraEditForm = useForm<LoRAModelConfig>({
initialValues: {
const {
register,
handleSubmit,
control,
formState: { errors },
reset,
} = useForm<LoRAModelConfig>({
defaultValues: {
model_name: model.model_name ? model.model_name : '',
base_model: model.base_model,
model_type: 'lora',
@ -39,14 +46,11 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
description: model.description ? model.description : '',
model_format: model.model_format,
},
validate: {
path: (value) =>
value.trim().length === 0 ? 'Must provide a path' : null,
},
mode: 'onChange',
});
const editModelFormSubmitHandler = useCallback(
(values: LoRAModelConfig) => {
const onSubmit = useCallback<SubmitHandler<LoRAModelConfig>>(
(values) => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
@ -56,7 +60,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
updateLoRAModel(responseBody)
.unwrap()
.then((payload) => {
loraEditForm.setValues(payload as LoRAModelConfig);
reset(payload as LoRAModelConfig, { keepDefaultValues: true });
dispatch(
addToast(
makeToast({
@ -67,7 +71,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
);
})
.catch((_) => {
loraEditForm.reset();
reset();
dispatch(
addToast(
makeToast({
@ -78,14 +82,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
);
});
},
[
dispatch,
loraEditForm,
model.base_model,
model.model_name,
t,
updateLoRAModel,
]
[dispatch, model.base_model, model.model_name, reset, t, updateLoRAModel]
);
return (
@ -101,21 +98,39 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
</Flex>
<Divider />
<form
onSubmit={loraEditForm.onSubmit((values) =>
editModelFormSubmitHandler(values)
)}
>
<form onSubmit={handleSubmit(onSubmit)}>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<InvControl label={t('modelManager.name')}>
<InvInput {...loraEditForm.getInputProps('model_name')} />
<InvControl
label={t('modelManager.name')}
isInvalid={Boolean(errors.model_name)}
error={errors.model_name?.message}
>
<InvInput
{...register('model_name', {
validate: (value) =>
value.trim().length > 3 || 'Must be at least 3 characters',
})}
/>
</InvControl>
<InvControl label={t('modelManager.description')}>
<InvInput {...loraEditForm.getInputProps('description')} />
<InvInput {...register('description')} />
</InvControl>
<BaseModelSelect {...loraEditForm.getInputProps('base_model')} />
<InvControl label={t('modelManager.modelLocation')}>
<InvInput {...loraEditForm.getInputProps('path')} />
<BaseModelSelect<LoRAModelConfig>
control={control}
name="base_model"
/>
<InvControl
label={t('modelManager.modelLocation')}
isInvalid={Boolean(errors.path)}
error={errors.path?.message}
>
<InvInput
{...register('path', {
validate: (value) =>
value.trim().length > 0 || 'Must provide a path',
})}
/>
</InvControl>
<InvButton type="submit" isLoading={isLoading}>
{t('modelManager.updateModel')}

View File

@ -1,12 +1,16 @@
import { InvControl } from 'common/components/InvControl/InvControl';
import { InvSelect } from 'common/components/InvSelect/InvSelect';
import type {
InvSelectOnChange,
InvSelectOption,
InvSelectProps,
} from 'common/components/InvSelect/types';
import { typedMemo } from 'common/util/typedMemo';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { memo } from 'react';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
const options: InvSelectOption[] = [
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
@ -15,15 +19,26 @@ const options: InvSelectOption[] = [
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
];
type BaseModelSelectProps = Omit<InvSelectProps, 'options'>;
const BaseModelSelect = (props: BaseModelSelectProps) => {
const BaseModelSelect = <T extends AnyModelConfig>(
props: UseControllerProps<T>
) => {
const { t } = useTranslation();
const { field } = useController(props);
const value = useMemo(
() => options.find((o) => o.value === field.value),
[field.value]
);
const onChange = useCallback<InvSelectOnChange>(
(v) => {
field.onChange(v?.value);
},
[field]
);
return (
<InvControl label={t('modelManager.baseModel')}>
<InvSelect options={options} {...props} />
<InvSelect value={value} options={options} onChange={onChange} />
</InvControl>
);
};
export default memo(BaseModelSelect);
export default typedMemo(BaseModelSelect);

View File

@ -2,29 +2,44 @@ import type { ChakraProps } from '@chakra-ui/react';
import { InvControl } from 'common/components/InvControl/InvControl';
import { InvSelect } from 'common/components/InvSelect/InvSelect';
import type {
InvSelectOnChange,
InvSelectOption,
InvSelectProps,
} from 'common/components/InvSelect/types';
import { memo, useMemo } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useController, type UseControllerProps } from 'react-hook-form';
import { useGetCheckpointConfigsQuery } from 'services/api/endpoints/models';
type CheckpointConfigSelectProps = Omit<InvSelectProps, 'options'>;
import type { CheckpointModelConfig } from 'services/api/types';
const sx: ChakraProps['sx'] = { w: 'full' };
const CheckpointConfigsSelect = (props: CheckpointConfigSelectProps) => {
const CheckpointConfigsSelect = (
props: UseControllerProps<CheckpointModelConfig>
) => {
const { data } = useGetCheckpointConfigsQuery();
const options = useMemo<InvSelectOption[]>(
() => (data ? data.map((i) => ({ label: i, value: i })) : []),
[data]
);
const { field } = useController(props);
const value = useMemo(
() => options.find((o) => o.value === field.value),
[field.value, options]
);
const onChange = useCallback<InvSelectOnChange>(
(v) => {
field.onChange(v?.value);
},
[field]
);
return (
<InvControl label="Config File">
<InvSelect
placeholder="Select A Config File"
value={value}
options={options}
onChange={onChange}
sx={sx}
{...props}
/>
</InvControl>
);

View File

@ -1,11 +1,18 @@
import { InvControl } from 'common/components/InvControl/InvControl';
import { InvSelect } from 'common/components/InvSelect/InvSelect';
import type {
InvSelectOnChange,
InvSelectOption,
InvSelectProps,
} from 'common/components/InvSelect/types';
import { memo } from 'react';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type {
CheckpointModelConfig,
DiffusersModelConfig,
} from 'services/api/types';
const options: InvSelectOption[] = [
{ value: 'normal', label: 'Normal' },
@ -13,15 +20,28 @@ const options: InvSelectOption[] = [
{ value: 'depth', label: 'Depth' },
];
type VariantSelectProps = Omit<InvSelectProps, 'options'>;
const ModelVariantSelect = (props: VariantSelectProps) => {
const ModelVariantSelect = <
T extends CheckpointModelConfig | DiffusersModelConfig,
>(
props: UseControllerProps<T>
) => {
const { t } = useTranslation();
const { field } = useController(props);
const value = useMemo(
() => options.find((o) => o.value === field.value),
[field.value]
);
const onChange = useCallback<InvSelectOnChange>(
(v) => {
field.onChange(v?.value);
},
[field]
);
return (
<InvControl label={t('modelManager.variant')}>
<InvSelect options={options} {...props} />
<InvSelect value={value} options={options} onChange={onChange} />
</InvControl>
);
};
export default memo(ModelVariantSelect);
export default typedMemo(ModelVariantSelect);