mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactored and fixed issues with advanced import form
This commit is contained in:
parent
170d9bca98
commit
7f56e84a8d
@ -1,56 +1,238 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, Flex, FormLabel,Text } from '@invoke-ai/ui-library';
|
||||
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input, Text, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import BaseModelSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect';
|
||||
import BooleanSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BooleanSelect';
|
||||
import ModelFormatSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect';
|
||||
import ModelTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect';
|
||||
import ModelVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect';
|
||||
import PredictionTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect';
|
||||
import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { useCallback, useEffect } from 'react';
|
||||
import type { SubmitHandler} from 'react-hook-form';
|
||||
import {useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useImportAdvancedModelMutation } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { AdvancedImportCheckpoint } from './AdvancedImportCheckpoint';
|
||||
import { AdvancedImportDiffusers } from './AdvancedImportDiffusers';
|
||||
|
||||
export const zManualAddMode = z.enum(['diffusers', 'checkpoint']);
|
||||
export type ManualAddMode = z.infer<typeof zManualAddMode>;
|
||||
export const isManualAddMode = (v: unknown): v is ManualAddMode => zManualAddMode.safeParse(v).success;
|
||||
|
||||
export const AdvancedImport = () => {
|
||||
const [advancedAddMode, setAdvancedAddMode] = useState<ManualAddMode>('diffusers');
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [importAdvancedModel] = useImportAdvancedModelMutation();
|
||||
|
||||
const { t } = useTranslation();
|
||||
const handleChange: ComboboxOnChange = useCallback((v) => {
|
||||
if (!isManualAddMode(v?.value)) {
|
||||
return;
|
||||
}
|
||||
setAdvancedAddMode(v.value);
|
||||
}, []);
|
||||
|
||||
const options: ComboboxOption[] = useMemo(
|
||||
() => [
|
||||
{ label: t('modelManager.diffusersModels'), value: 'diffusers' },
|
||||
{ label: t('modelManager.checkpointOrSafetensors'), value: 'checkpoint' },
|
||||
],
|
||||
[t]
|
||||
const {
|
||||
register,
|
||||
handleSubmit,
|
||||
control,
|
||||
formState: { errors },
|
||||
setValue,
|
||||
resetField,
|
||||
reset,
|
||||
watch,
|
||||
} = useForm<AnyModelConfig>({
|
||||
defaultValues: {
|
||||
name: '',
|
||||
base: 'sd-1',
|
||||
type: 'main',
|
||||
path: '',
|
||||
description: '',
|
||||
format: 'diffusers',
|
||||
vae: '',
|
||||
variant: 'normal',
|
||||
},
|
||||
mode: 'onChange',
|
||||
});
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
|
||||
(values) => {
|
||||
const cleanValues = Object.fromEntries(
|
||||
Object.entries(values).filter(([value]) => value !== null && value !== undefined)
|
||||
);
|
||||
importAdvancedModel({
|
||||
source: {
|
||||
path: cleanValues.path as string,
|
||||
type: 'local',
|
||||
},
|
||||
config: cleanValues,
|
||||
})
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelAdded', {
|
||||
modelName: values.name,
|
||||
}),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
reset();
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.modelAddFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
},
|
||||
[dispatch, reset, t, importAdvancedModel]
|
||||
);
|
||||
|
||||
const value = useMemo(() => options.find((o) => o.value === advancedAddMode), [options, advancedAddMode]);
|
||||
const watchedModelType = watch('type');
|
||||
const watchedModelFormat = watch('format');
|
||||
|
||||
useEffect(() => {
|
||||
if (watchedModelType === 'main') {
|
||||
setValue('format', 'diffusers');
|
||||
setValue('repo_variant', '');
|
||||
setValue('variant', 'normal');
|
||||
}
|
||||
if (watchedModelType === 'lora') {
|
||||
setValue('format', 'lycoris');
|
||||
} else if (watchedModelType === 'embedding') {
|
||||
setValue('format', 'embedding_file');
|
||||
} else if (watchedModelType === 'ip_adapter') {
|
||||
setValue('format', 'invokeai');
|
||||
} else {
|
||||
setValue('format', 'diffusers');
|
||||
}
|
||||
resetField('upcast_attention');
|
||||
resetField('ztsnr_training');
|
||||
resetField('vae');
|
||||
resetField('config');
|
||||
resetField('prediction_type');
|
||||
resetField('image_encoder_model_id');
|
||||
}, [watchedModelType, resetField, setValue]);
|
||||
|
||||
return (
|
||||
<ScrollableContent>
|
||||
<Flex flexDirection="column" gap={4} width="100%" pb={10}>
|
||||
<Flex alignItems="flex-end" gap="4">
|
||||
<Flex direction="column" gap="3" width="full">
|
||||
<FormLabel>{t('modelManager.modelType')}</FormLabel>
|
||||
<Combobox value={value} options={options} onChange={handleChange} />
|
||||
<form onSubmit={handleSubmit(onSubmit)}>
|
||||
<Flex flexDirection="column" gap={4} width="100%" pb={10}>
|
||||
<Flex alignItems="flex-end" gap="4">
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>Model Type</FormLabel>
|
||||
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
|
||||
</FormControl>
|
||||
<Text px="2" fontSize="xs" textAlign="center">
|
||||
{t('modelManager.advancedImportInfo')}
|
||||
</Text>
|
||||
</Flex>
|
||||
<Text px="2" fontSize="xs" textAlign="center">
|
||||
{t('modelManager.advancedImportInfo')}
|
||||
</Text>
|
||||
</Flex>
|
||||
|
||||
<Flex p={4} borderRadius={4} bg="base.850" height="100%">
|
||||
{advancedAddMode === 'diffusers' && <AdvancedImportDiffusers />}
|
||||
{advancedAddMode === 'checkpoint' && <AdvancedImportCheckpoint />}
|
||||
<Flex p={4} borderRadius={4} bg="base.850" height="100%" direction="column" gap="3">
|
||||
<FormControl isInvalid={Boolean(errors.name)}>
|
||||
<Flex direction="column" width="full">
|
||||
<FormLabel>{t('modelManager.name')}</FormLabel>
|
||||
<Input
|
||||
{...register('name', {
|
||||
validate: (value) => value.trim().length >= 3 || 'Must be at least 3 characters',
|
||||
})}
|
||||
/>
|
||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
<Flex>
|
||||
<FormControl>
|
||||
<Flex direction="column" width="full">
|
||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||
<Textarea size="sm" {...register('description')} />
|
||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>Base Model</FormLabel>
|
||||
<BaseModelSelect<AnyModelConfig> control={control} name="base" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>Format</FormLabel>
|
||||
<ModelFormatSelect<AnyModelConfig> control={control} name="format" />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
|
||||
<FormLabel>Path</FormLabel>
|
||||
<Input
|
||||
{...register('path', {
|
||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
||||
})}
|
||||
/>
|
||||
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
</Flex>
|
||||
{watchedModelType === 'main' && (
|
||||
<>
|
||||
<Flex gap={4}>
|
||||
{watchedModelFormat === 'diffusers' && (
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>Repo Variant</FormLabel>
|
||||
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
|
||||
</FormControl>
|
||||
)}
|
||||
{watchedModelFormat === 'checkpoint' && (
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>Config Path</FormLabel>
|
||||
<Input {...register('config')} />
|
||||
</FormControl>
|
||||
)}
|
||||
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>Variant</FormLabel>
|
||||
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>Prediction Type</FormLabel>
|
||||
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>Upcast Attention</FormLabel>
|
||||
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>ZTSNR Training</FormLabel>
|
||||
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>VAE Path</FormLabel>
|
||||
<Input {...register('vae')} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
{watchedModelType === 'ip_adapter' && (
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>Image Encoder Model ID</FormLabel>
|
||||
<Input {...register('image_encoder_model_id')} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
)}
|
||||
<Button mt={2} type="submit">
|
||||
{t('modelManager.addModel')}
|
||||
</Button>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</form>
|
||||
</ScrollableContent>
|
||||
);
|
||||
};
|
||||
|
@ -1,160 +0,0 @@
|
||||
import {
|
||||
Button,
|
||||
Checkbox,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormErrorMessage,
|
||||
FormLabel,
|
||||
Input,
|
||||
Textarea,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { useCallback, useState } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { CheckpointModelConfig, DiffusersModelConfig } from 'services/api/types';
|
||||
|
||||
import BaseModelSelect from './BaseModelSelect';
|
||||
import CheckpointConfigsSelect from './CheckpointConfigsSelect';
|
||||
import ModelVariantSelect from './ModelVariantSelect';
|
||||
|
||||
export const AdvancedImportCheckpoint = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [addMainModel] = useAddMainModelsMutation();
|
||||
|
||||
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
|
||||
|
||||
const {
|
||||
register,
|
||||
handleSubmit,
|
||||
control,
|
||||
formState: { errors },
|
||||
reset,
|
||||
} = useForm<CheckpointModelConfig>({
|
||||
defaultValues: {
|
||||
name: '',
|
||||
base: 'sd-1',
|
||||
type: 'main',
|
||||
path: '',
|
||||
description: '',
|
||||
format: 'checkpoint',
|
||||
vae: '',
|
||||
variant: 'normal',
|
||||
config: '',
|
||||
},
|
||||
mode: 'onChange',
|
||||
});
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<DiffusersModelConfig>>(
|
||||
(values) => {
|
||||
addMainModel({
|
||||
body: values,
|
||||
})
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelAdded', {
|
||||
modelName: values.name,
|
||||
}),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
reset();
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.modelAddFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
},
|
||||
[addMainModel, dispatch, reset, t]
|
||||
);
|
||||
|
||||
const handleChangeUseCustomConfig = useCallback(() => setUseCustomConfig((prev) => !prev), []);
|
||||
|
||||
return (
|
||||
<form onSubmit={handleSubmit(onSubmit)} style={formStyles}>
|
||||
<Flex flexDirection="column" gap={2}>
|
||||
<FormControl isInvalid={Boolean(errors.name)}>
|
||||
<Flex direction="column" width="full">
|
||||
<FormLabel>{t('modelManager.name')}</FormLabel>
|
||||
<Input
|
||||
{...register('name', {
|
||||
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
||||
})}
|
||||
/>
|
||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
<Flex>
|
||||
<FormControl>
|
||||
<Flex direction="column" width="full">
|
||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||
<Textarea size="sm" {...register('description')} />
|
||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap="2">
|
||||
<BaseModelSelect<CheckpointModelConfig> control={control} name="base" />
|
||||
<ModelVariantSelect<CheckpointModelConfig> control={control} name="variant" />
|
||||
</Flex>
|
||||
<FormControl isInvalid={Boolean(errors.path)}>
|
||||
<Flex direction="column" width="full">
|
||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
||||
<Input
|
||||
{...register('path', {
|
||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
||||
})}
|
||||
/>
|
||||
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<Flex direction="column" width="full">
|
||||
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
|
||||
<Input {...register('vae')} />
|
||||
</Flex>
|
||||
</FormControl>
|
||||
<Flex flexDirection="column" width="100%" gap={2}>
|
||||
{!useCustomConfig ? (
|
||||
<CheckpointConfigsSelect control={control} name="config" />
|
||||
) : (
|
||||
<FormControl isRequired>
|
||||
<FormLabel>{t('modelManager.config')}</FormLabel>
|
||||
<Input {...register('config')} />
|
||||
</FormControl>
|
||||
)}
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.useCustomConfig')}</FormLabel>
|
||||
<Checkbox isChecked={useCustomConfig} onChange={handleChangeUseCustomConfig} />
|
||||
</FormControl>
|
||||
<Button mt={2} type="submit">
|
||||
{t('modelManager.addModel')}
|
||||
</Button>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</form>
|
||||
);
|
||||
};
|
||||
|
||||
const formStyles: CSSProperties = {
|
||||
width: '100%',
|
||||
};
|
@ -1,132 +0,0 @@
|
||||
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import type { CSSProperties } from 'react';
|
||||
import {useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { DiffusersModelConfig } from 'services/api/types';
|
||||
|
||||
import BaseModelSelect from './BaseModelSelect';
|
||||
import ModelVariantSelect from './ModelVariantSelect';
|
||||
|
||||
export const AdvancedImportDiffusers = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [addMainModel] = useAddMainModelsMutation();
|
||||
|
||||
const {
|
||||
register,
|
||||
handleSubmit,
|
||||
control,
|
||||
formState: { errors },
|
||||
reset,
|
||||
} = useForm<DiffusersModelConfig>({
|
||||
defaultValues: {
|
||||
name: '',
|
||||
base: 'sd-1',
|
||||
type: 'main',
|
||||
path: '',
|
||||
description: '',
|
||||
format: 'diffusers',
|
||||
vae: '',
|
||||
variant: 'normal',
|
||||
},
|
||||
mode: 'onChange',
|
||||
});
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<DiffusersModelConfig>>(
|
||||
(values) => {
|
||||
addMainModel({
|
||||
body: values,
|
||||
})
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelAdded', {
|
||||
modelName: values.name,
|
||||
}),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
reset();
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.modelAddFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
},
|
||||
[addMainModel, dispatch, reset, t]
|
||||
);
|
||||
|
||||
return (
|
||||
<form onSubmit={handleSubmit(onSubmit)} style={formStyles}>
|
||||
<Flex flexDirection="column" gap={2}>
|
||||
<FormControl isInvalid={Boolean(errors.name)}>
|
||||
<Flex direction="column" width="full">
|
||||
<FormLabel>{t('modelManager.name')}</FormLabel>
|
||||
<Input
|
||||
{...register('name', {
|
||||
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
||||
})}
|
||||
/>
|
||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
<Flex>
|
||||
<FormControl>
|
||||
<Flex direction="column" width="full">
|
||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||
<Textarea size="sm" {...register('description')} />
|
||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap="2">
|
||||
<BaseModelSelect<DiffusersModelConfig> control={control} name="base" />
|
||||
<ModelVariantSelect<DiffusersModelConfig> control={control} name="variant" />
|
||||
</Flex>
|
||||
<FormControl isInvalid={Boolean(errors.path)}>
|
||||
<Flex direction="column" width="full">
|
||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
||||
<Input
|
||||
{...register('path', {
|
||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
||||
})}
|
||||
/>
|
||||
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<Flex direction="column" width="full">
|
||||
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
|
||||
<Input {...register('vae')} />
|
||||
</Flex>
|
||||
</FormControl>
|
||||
|
||||
<Button mt={2} type="submit">
|
||||
{t('modelManager.addModel')}
|
||||
</Button>
|
||||
</Flex>
|
||||
</form>
|
||||
);
|
||||
};
|
||||
|
||||
const formStyles: CSSProperties = {
|
||||
width: '100%',
|
||||
};
|
@ -4,11 +4,12 @@ import { typedMemo } from 'common/util/typedMemo';
|
||||
import { LORA_MODEL_FORMAT_MAP } from 'features/parameters/types/constants';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useController, useWatch } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
const ModelFormatSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field, formState } = useController(props);
|
||||
const type = useWatch({ control: props.control, name: 'type' });
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
@ -18,17 +19,18 @@ const ModelFormatSelect = <T extends AnyModelConfig>(props: UseControllerProps<T
|
||||
);
|
||||
|
||||
const options: ComboboxOption[] = useMemo(() => {
|
||||
if (formState.defaultValues?.type === 'lora') {
|
||||
const modelType = type || formState.defaultValues?.type;
|
||||
if (modelType === 'lora') {
|
||||
return Object.keys(LORA_MODEL_FORMAT_MAP).map((format) => ({
|
||||
value: format,
|
||||
label: LORA_MODEL_FORMAT_MAP[format],
|
||||
})) as ComboboxOption[];
|
||||
} else if (formState.defaultValues?.type === 'embedding') {
|
||||
} else if (modelType === 'embedding') {
|
||||
return [
|
||||
{ value: 'embedding_file', label: 'Embedding File' },
|
||||
{ value: 'embedding_folder', label: 'Embedding Folder' },
|
||||
];
|
||||
} else if (formState.defaultValues?.type === 'ip_adapter') {
|
||||
} else if (modelType === 'ip_adapter') {
|
||||
return [{ value: 'invokeai', label: 'invokeai' }];
|
||||
} else {
|
||||
return [
|
||||
@ -36,7 +38,7 @@ const ModelFormatSelect = <T extends AnyModelConfig>(props: UseControllerProps<T
|
||||
{ value: 'checkpoint', label: 'Checkpoint' },
|
||||
];
|
||||
}
|
||||
}, [formState.defaultValues?.type]);
|
||||
}, [type, formState.defaultValues?.type]);
|
||||
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [options, field.value]);
|
||||
|
||||
|
@ -72,11 +72,12 @@ type DeleteImportModelsResponse =
|
||||
type PruneModelImportsResponse =
|
||||
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
|
||||
|
||||
type AddMainModelArg = {
|
||||
body: MainModelConfig;
|
||||
type ImportAdvancedModelArg = {
|
||||
source: NonNullable<operations['import_model']['requestBody']['content']['application/json']['source']>;
|
||||
config: NonNullable<operations['import_model']['requestBody']['content']['application/json']['config']>;
|
||||
};
|
||||
|
||||
type AddMainModelResponse = paths['/api/v2/models/add']['post']['responses']['201']['content']['application/json'];
|
||||
type ImportAdvancedModelResponse = paths['/api/v2/models/import']['post']['responses']['201']['content']['application/json'];
|
||||
|
||||
export type ScanFolderResponse =
|
||||
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
|
||||
@ -198,12 +199,12 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['Model', 'ModelImports'],
|
||||
}),
|
||||
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
|
||||
query: ({ body }) => {
|
||||
importAdvancedModel: build.mutation<ImportAdvancedModelResponse, ImportAdvancedModelArg>({
|
||||
query: ({ source, config}) => {
|
||||
return {
|
||||
url: buildModelsUrl('add'),
|
||||
url: buildModelsUrl('install'),
|
||||
method: 'POST',
|
||||
body: body,
|
||||
body: { source, config },
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['Model', 'ModelImports'],
|
||||
@ -344,7 +345,7 @@ export const {
|
||||
useDeleteModelsMutation,
|
||||
useUpdateModelsMutation,
|
||||
useImportMainModelsMutation,
|
||||
useAddMainModelsMutation,
|
||||
useImportAdvancedModelMutation,
|
||||
useConvertMainModelsMutation,
|
||||
useMergeMainModelsMutation,
|
||||
useSyncModelsMutation,
|
||||
|
Loading…
Reference in New Issue
Block a user