mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactored and fixed issues with advanced import form
This commit is contained in:
parent
98be81354a
commit
b3beaefa04
@ -1,56 +1,238 @@
|
|||||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input, Text, Textarea } from '@invoke-ai/ui-library';
|
||||||
import { Combobox, Flex, FormLabel,Text } from '@invoke-ai/ui-library';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||||
import { useCallback, useMemo, useState } from 'react';
|
import BaseModelSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect';
|
||||||
|
import BooleanSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BooleanSelect';
|
||||||
|
import ModelFormatSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect';
|
||||||
|
import ModelTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect';
|
||||||
|
import ModelVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect';
|
||||||
|
import PredictionTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect';
|
||||||
|
import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
|
import { useCallback, useEffect } from 'react';
|
||||||
|
import type { SubmitHandler} from 'react-hook-form';
|
||||||
|
import {useForm } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useImportAdvancedModelMutation } from 'services/api/endpoints/models';
|
||||||
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
import { z } from 'zod';
|
import { z } from 'zod';
|
||||||
|
|
||||||
import { AdvancedImportCheckpoint } from './AdvancedImportCheckpoint';
|
|
||||||
import { AdvancedImportDiffusers } from './AdvancedImportDiffusers';
|
|
||||||
|
|
||||||
export const zManualAddMode = z.enum(['diffusers', 'checkpoint']);
|
export const zManualAddMode = z.enum(['diffusers', 'checkpoint']);
|
||||||
export type ManualAddMode = z.infer<typeof zManualAddMode>;
|
export type ManualAddMode = z.infer<typeof zManualAddMode>;
|
||||||
export const isManualAddMode = (v: unknown): v is ManualAddMode => zManualAddMode.safeParse(v).success;
|
export const isManualAddMode = (v: unknown): v is ManualAddMode => zManualAddMode.safeParse(v).success;
|
||||||
|
|
||||||
export const AdvancedImport = () => {
|
export const AdvancedImport = () => {
|
||||||
const [advancedAddMode, setAdvancedAddMode] = useState<ManualAddMode>('diffusers');
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const [importAdvancedModel] = useImportAdvancedModelMutation();
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const handleChange: ComboboxOnChange = useCallback((v) => {
|
|
||||||
if (!isManualAddMode(v?.value)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
setAdvancedAddMode(v.value);
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const options: ComboboxOption[] = useMemo(
|
const {
|
||||||
() => [
|
register,
|
||||||
{ label: t('modelManager.diffusersModels'), value: 'diffusers' },
|
handleSubmit,
|
||||||
{ label: t('modelManager.checkpointOrSafetensors'), value: 'checkpoint' },
|
control,
|
||||||
],
|
formState: { errors },
|
||||||
[t]
|
setValue,
|
||||||
|
resetField,
|
||||||
|
reset,
|
||||||
|
watch,
|
||||||
|
} = useForm<AnyModelConfig>({
|
||||||
|
defaultValues: {
|
||||||
|
name: '',
|
||||||
|
base: 'sd-1',
|
||||||
|
type: 'main',
|
||||||
|
path: '',
|
||||||
|
description: '',
|
||||||
|
format: 'diffusers',
|
||||||
|
vae: '',
|
||||||
|
variant: 'normal',
|
||||||
|
},
|
||||||
|
mode: 'onChange',
|
||||||
|
});
|
||||||
|
|
||||||
|
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
|
||||||
|
(values) => {
|
||||||
|
const cleanValues = Object.fromEntries(
|
||||||
|
Object.entries(values).filter(([value]) => value !== null && value !== undefined)
|
||||||
|
);
|
||||||
|
importAdvancedModel({
|
||||||
|
source: {
|
||||||
|
path: cleanValues.path as string,
|
||||||
|
type: 'local',
|
||||||
|
},
|
||||||
|
config: cleanValues,
|
||||||
|
})
|
||||||
|
.unwrap()
|
||||||
|
.then((_) => {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: t('modelManager.modelAdded', {
|
||||||
|
modelName: values.name,
|
||||||
|
}),
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
reset();
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
if (error) {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: t('toast.modelAddFailed'),
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[dispatch, reset, t, importAdvancedModel]
|
||||||
);
|
);
|
||||||
|
|
||||||
const value = useMemo(() => options.find((o) => o.value === advancedAddMode), [options, advancedAddMode]);
|
const watchedModelType = watch('type');
|
||||||
|
const watchedModelFormat = watch('format');
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (watchedModelType === 'main') {
|
||||||
|
setValue('format', 'diffusers');
|
||||||
|
setValue('repo_variant', '');
|
||||||
|
setValue('variant', 'normal');
|
||||||
|
}
|
||||||
|
if (watchedModelType === 'lora') {
|
||||||
|
setValue('format', 'lycoris');
|
||||||
|
} else if (watchedModelType === 'embedding') {
|
||||||
|
setValue('format', 'embedding_file');
|
||||||
|
} else if (watchedModelType === 'ip_adapter') {
|
||||||
|
setValue('format', 'invokeai');
|
||||||
|
} else {
|
||||||
|
setValue('format', 'diffusers');
|
||||||
|
}
|
||||||
|
resetField('upcast_attention');
|
||||||
|
resetField('ztsnr_training');
|
||||||
|
resetField('vae');
|
||||||
|
resetField('config');
|
||||||
|
resetField('prediction_type');
|
||||||
|
resetField('image_encoder_model_id');
|
||||||
|
}, [watchedModelType, resetField, setValue]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<ScrollableContent>
|
<ScrollableContent>
|
||||||
<Flex flexDirection="column" gap={4} width="100%" pb={10}>
|
<form onSubmit={handleSubmit(onSubmit)}>
|
||||||
<Flex alignItems="flex-end" gap="4">
|
<Flex flexDirection="column" gap={4} width="100%" pb={10}>
|
||||||
<Flex direction="column" gap="3" width="full">
|
<Flex alignItems="flex-end" gap="4">
|
||||||
<FormLabel>{t('modelManager.modelType')}</FormLabel>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
<Combobox value={value} options={options} onChange={handleChange} />
|
<FormLabel>Model Type</FormLabel>
|
||||||
|
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
|
||||||
|
</FormControl>
|
||||||
|
<Text px="2" fontSize="xs" textAlign="center">
|
||||||
|
{t('modelManager.advancedImportInfo')}
|
||||||
|
</Text>
|
||||||
</Flex>
|
</Flex>
|
||||||
<Text px="2" fontSize="xs" textAlign="center">
|
|
||||||
{t('modelManager.advancedImportInfo')}
|
|
||||||
</Text>
|
|
||||||
</Flex>
|
|
||||||
|
|
||||||
<Flex p={4} borderRadius={4} bg="base.850" height="100%">
|
<Flex p={4} borderRadius={4} bg="base.850" height="100%" direction="column" gap="3">
|
||||||
{advancedAddMode === 'diffusers' && <AdvancedImportDiffusers />}
|
<FormControl isInvalid={Boolean(errors.name)}>
|
||||||
{advancedAddMode === 'checkpoint' && <AdvancedImportCheckpoint />}
|
<Flex direction="column" width="full">
|
||||||
|
<FormLabel>{t('modelManager.name')}</FormLabel>
|
||||||
|
<Input
|
||||||
|
{...register('name', {
|
||||||
|
validate: (value) => value.trim().length >= 3 || 'Must be at least 3 characters',
|
||||||
|
})}
|
||||||
|
/>
|
||||||
|
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||||
|
</Flex>
|
||||||
|
</FormControl>
|
||||||
|
<Flex>
|
||||||
|
<FormControl>
|
||||||
|
<Flex direction="column" width="full">
|
||||||
|
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||||
|
<Textarea size="sm" {...register('description')} />
|
||||||
|
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||||
|
</Flex>
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
<Flex gap={4}>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Base Model</FormLabel>
|
||||||
|
<BaseModelSelect<AnyModelConfig> control={control} name="base" />
|
||||||
|
</FormControl>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Format</FormLabel>
|
||||||
|
<ModelFormatSelect<AnyModelConfig> control={control} name="format" />
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
<Flex gap={4}>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
|
||||||
|
<FormLabel>Path</FormLabel>
|
||||||
|
<Input
|
||||||
|
{...register('path', {
|
||||||
|
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
||||||
|
})}
|
||||||
|
/>
|
||||||
|
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
{watchedModelType === 'main' && (
|
||||||
|
<>
|
||||||
|
<Flex gap={4}>
|
||||||
|
{watchedModelFormat === 'diffusers' && (
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Repo Variant</FormLabel>
|
||||||
|
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
|
||||||
|
</FormControl>
|
||||||
|
)}
|
||||||
|
{watchedModelFormat === 'checkpoint' && (
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Config Path</FormLabel>
|
||||||
|
<Input {...register('config')} />
|
||||||
|
</FormControl>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Variant</FormLabel>
|
||||||
|
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
<Flex gap={4}>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Prediction Type</FormLabel>
|
||||||
|
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
|
||||||
|
</FormControl>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Upcast Attention</FormLabel>
|
||||||
|
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
<Flex gap={4}>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>ZTSNR Training</FormLabel>
|
||||||
|
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
|
||||||
|
</FormControl>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>VAE Path</FormLabel>
|
||||||
|
<Input {...register('vae')} />
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
{watchedModelType === 'ip_adapter' && (
|
||||||
|
<Flex gap={4}>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Image Encoder Model ID</FormLabel>
|
||||||
|
<Input {...register('image_encoder_model_id')} />
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
<Button mt={2} type="submit">
|
||||||
|
{t('modelManager.addModel')}
|
||||||
|
</Button>
|
||||||
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</form>
|
||||||
</ScrollableContent>
|
</ScrollableContent>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1,160 +0,0 @@
|
|||||||
import {
|
|
||||||
Button,
|
|
||||||
Checkbox,
|
|
||||||
Flex,
|
|
||||||
FormControl,
|
|
||||||
FormErrorMessage,
|
|
||||||
FormLabel,
|
|
||||||
Input,
|
|
||||||
Textarea,
|
|
||||||
} from '@invoke-ai/ui-library';
|
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import type { CSSProperties } from 'react';
|
|
||||||
import { useCallback, useState } from 'react';
|
|
||||||
import type { SubmitHandler } from 'react-hook-form';
|
|
||||||
import { useForm } from 'react-hook-form';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
|
|
||||||
import type { CheckpointModelConfig, DiffusersModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
import BaseModelSelect from './BaseModelSelect';
|
|
||||||
import CheckpointConfigsSelect from './CheckpointConfigsSelect';
|
|
||||||
import ModelVariantSelect from './ModelVariantSelect';
|
|
||||||
|
|
||||||
export const AdvancedImportCheckpoint = () => {
|
|
||||||
const { t } = useTranslation();
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const [addMainModel] = useAddMainModelsMutation();
|
|
||||||
|
|
||||||
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
|
|
||||||
|
|
||||||
const {
|
|
||||||
register,
|
|
||||||
handleSubmit,
|
|
||||||
control,
|
|
||||||
formState: { errors },
|
|
||||||
reset,
|
|
||||||
} = useForm<CheckpointModelConfig>({
|
|
||||||
defaultValues: {
|
|
||||||
name: '',
|
|
||||||
base: 'sd-1',
|
|
||||||
type: 'main',
|
|
||||||
path: '',
|
|
||||||
description: '',
|
|
||||||
format: 'checkpoint',
|
|
||||||
vae: '',
|
|
||||||
variant: 'normal',
|
|
||||||
config: '',
|
|
||||||
},
|
|
||||||
mode: 'onChange',
|
|
||||||
});
|
|
||||||
|
|
||||||
const onSubmit = useCallback<SubmitHandler<DiffusersModelConfig>>(
|
|
||||||
(values) => {
|
|
||||||
addMainModel({
|
|
||||||
body: values,
|
|
||||||
})
|
|
||||||
.unwrap()
|
|
||||||
.then((_) => {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('modelManager.modelAdded', {
|
|
||||||
modelName: values.name,
|
|
||||||
}),
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
reset();
|
|
||||||
})
|
|
||||||
.catch((error) => {
|
|
||||||
if (error) {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('toast.modelAddFailed'),
|
|
||||||
status: 'error',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
},
|
|
||||||
[addMainModel, dispatch, reset, t]
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleChangeUseCustomConfig = useCallback(() => setUseCustomConfig((prev) => !prev), []);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<form onSubmit={handleSubmit(onSubmit)} style={formStyles}>
|
|
||||||
<Flex flexDirection="column" gap={2}>
|
|
||||||
<FormControl isInvalid={Boolean(errors.name)}>
|
|
||||||
<Flex direction="column" width="full">
|
|
||||||
<FormLabel>{t('modelManager.name')}</FormLabel>
|
|
||||||
<Input
|
|
||||||
{...register('name', {
|
|
||||||
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
|
||||||
})}
|
|
||||||
/>
|
|
||||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
|
||||||
</Flex>
|
|
||||||
</FormControl>
|
|
||||||
<Flex>
|
|
||||||
<FormControl>
|
|
||||||
<Flex direction="column" width="full">
|
|
||||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
|
||||||
<Textarea size="sm" {...register('description')} />
|
|
||||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
|
||||||
</Flex>
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
<Flex gap="2">
|
|
||||||
<BaseModelSelect<CheckpointModelConfig> control={control} name="base" />
|
|
||||||
<ModelVariantSelect<CheckpointModelConfig> control={control} name="variant" />
|
|
||||||
</Flex>
|
|
||||||
<FormControl isInvalid={Boolean(errors.path)}>
|
|
||||||
<Flex direction="column" width="full">
|
|
||||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
|
||||||
<Input
|
|
||||||
{...register('path', {
|
|
||||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
|
||||||
})}
|
|
||||||
/>
|
|
||||||
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
|
|
||||||
</Flex>
|
|
||||||
</FormControl>
|
|
||||||
<FormControl>
|
|
||||||
<Flex direction="column" width="full">
|
|
||||||
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
|
|
||||||
<Input {...register('vae')} />
|
|
||||||
</Flex>
|
|
||||||
</FormControl>
|
|
||||||
<Flex flexDirection="column" width="100%" gap={2}>
|
|
||||||
{!useCustomConfig ? (
|
|
||||||
<CheckpointConfigsSelect control={control} name="config" />
|
|
||||||
) : (
|
|
||||||
<FormControl isRequired>
|
|
||||||
<FormLabel>{t('modelManager.config')}</FormLabel>
|
|
||||||
<Input {...register('config')} />
|
|
||||||
</FormControl>
|
|
||||||
)}
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.useCustomConfig')}</FormLabel>
|
|
||||||
<Checkbox isChecked={useCustomConfig} onChange={handleChangeUseCustomConfig} />
|
|
||||||
</FormControl>
|
|
||||||
<Button mt={2} type="submit">
|
|
||||||
{t('modelManager.addModel')}
|
|
||||||
</Button>
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
</form>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const formStyles: CSSProperties = {
|
|
||||||
width: '100%',
|
|
||||||
};
|
|
@ -1,132 +0,0 @@
|
|||||||
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input, Textarea } from '@invoke-ai/ui-library';
|
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import type { CSSProperties } from 'react';
|
|
||||||
import {useCallback } from 'react';
|
|
||||||
import type { SubmitHandler } from 'react-hook-form';
|
|
||||||
import { useForm } from 'react-hook-form';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
|
|
||||||
import type { DiffusersModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
import BaseModelSelect from './BaseModelSelect';
|
|
||||||
import ModelVariantSelect from './ModelVariantSelect';
|
|
||||||
|
|
||||||
export const AdvancedImportDiffusers = () => {
|
|
||||||
const { t } = useTranslation();
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const [addMainModel] = useAddMainModelsMutation();
|
|
||||||
|
|
||||||
const {
|
|
||||||
register,
|
|
||||||
handleSubmit,
|
|
||||||
control,
|
|
||||||
formState: { errors },
|
|
||||||
reset,
|
|
||||||
} = useForm<DiffusersModelConfig>({
|
|
||||||
defaultValues: {
|
|
||||||
name: '',
|
|
||||||
base: 'sd-1',
|
|
||||||
type: 'main',
|
|
||||||
path: '',
|
|
||||||
description: '',
|
|
||||||
format: 'diffusers',
|
|
||||||
vae: '',
|
|
||||||
variant: 'normal',
|
|
||||||
},
|
|
||||||
mode: 'onChange',
|
|
||||||
});
|
|
||||||
|
|
||||||
const onSubmit = useCallback<SubmitHandler<DiffusersModelConfig>>(
|
|
||||||
(values) => {
|
|
||||||
addMainModel({
|
|
||||||
body: values,
|
|
||||||
})
|
|
||||||
.unwrap()
|
|
||||||
.then((_) => {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('modelManager.modelAdded', {
|
|
||||||
modelName: values.name,
|
|
||||||
}),
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
reset();
|
|
||||||
})
|
|
||||||
.catch((error) => {
|
|
||||||
if (error) {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('toast.modelAddFailed'),
|
|
||||||
status: 'error',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
},
|
|
||||||
[addMainModel, dispatch, reset, t]
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<form onSubmit={handleSubmit(onSubmit)} style={formStyles}>
|
|
||||||
<Flex flexDirection="column" gap={2}>
|
|
||||||
<FormControl isInvalid={Boolean(errors.name)}>
|
|
||||||
<Flex direction="column" width="full">
|
|
||||||
<FormLabel>{t('modelManager.name')}</FormLabel>
|
|
||||||
<Input
|
|
||||||
{...register('name', {
|
|
||||||
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
|
||||||
})}
|
|
||||||
/>
|
|
||||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
|
||||||
</Flex>
|
|
||||||
</FormControl>
|
|
||||||
<Flex>
|
|
||||||
<FormControl>
|
|
||||||
<Flex direction="column" width="full">
|
|
||||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
|
||||||
<Textarea size="sm" {...register('description')} />
|
|
||||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
|
||||||
</Flex>
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
<Flex gap="2">
|
|
||||||
<BaseModelSelect<DiffusersModelConfig> control={control} name="base" />
|
|
||||||
<ModelVariantSelect<DiffusersModelConfig> control={control} name="variant" />
|
|
||||||
</Flex>
|
|
||||||
<FormControl isInvalid={Boolean(errors.path)}>
|
|
||||||
<Flex direction="column" width="full">
|
|
||||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
|
||||||
<Input
|
|
||||||
{...register('path', {
|
|
||||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
|
||||||
})}
|
|
||||||
/>
|
|
||||||
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
|
|
||||||
</Flex>
|
|
||||||
</FormControl>
|
|
||||||
<FormControl>
|
|
||||||
<Flex direction="column" width="full">
|
|
||||||
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
|
|
||||||
<Input {...register('vae')} />
|
|
||||||
</Flex>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
<Button mt={2} type="submit">
|
|
||||||
{t('modelManager.addModel')}
|
|
||||||
</Button>
|
|
||||||
</Flex>
|
|
||||||
</form>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const formStyles: CSSProperties = {
|
|
||||||
width: '100%',
|
|
||||||
};
|
|
@ -4,11 +4,12 @@ import { typedMemo } from 'common/util/typedMemo';
|
|||||||
import { LORA_MODEL_FORMAT_MAP } from 'features/parameters/types/constants';
|
import { LORA_MODEL_FORMAT_MAP } from 'features/parameters/types/constants';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
import { useController } from 'react-hook-form';
|
import { useController, useWatch } from 'react-hook-form';
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
const ModelFormatSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
const ModelFormatSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||||
const { field, formState } = useController(props);
|
const { field, formState } = useController(props);
|
||||||
|
const type = useWatch({ control: props.control, name: 'type' });
|
||||||
|
|
||||||
const onChange = useCallback<ComboboxOnChange>(
|
const onChange = useCallback<ComboboxOnChange>(
|
||||||
(v) => {
|
(v) => {
|
||||||
@ -18,17 +19,18 @@ const ModelFormatSelect = <T extends AnyModelConfig>(props: UseControllerProps<T
|
|||||||
);
|
);
|
||||||
|
|
||||||
const options: ComboboxOption[] = useMemo(() => {
|
const options: ComboboxOption[] = useMemo(() => {
|
||||||
if (formState.defaultValues?.type === 'lora') {
|
const modelType = type || formState.defaultValues?.type;
|
||||||
|
if (modelType === 'lora') {
|
||||||
return Object.keys(LORA_MODEL_FORMAT_MAP).map((format) => ({
|
return Object.keys(LORA_MODEL_FORMAT_MAP).map((format) => ({
|
||||||
value: format,
|
value: format,
|
||||||
label: LORA_MODEL_FORMAT_MAP[format],
|
label: LORA_MODEL_FORMAT_MAP[format],
|
||||||
})) as ComboboxOption[];
|
})) as ComboboxOption[];
|
||||||
} else if (formState.defaultValues?.type === 'embedding') {
|
} else if (modelType === 'embedding') {
|
||||||
return [
|
return [
|
||||||
{ value: 'embedding_file', label: 'Embedding File' },
|
{ value: 'embedding_file', label: 'Embedding File' },
|
||||||
{ value: 'embedding_folder', label: 'Embedding Folder' },
|
{ value: 'embedding_folder', label: 'Embedding Folder' },
|
||||||
];
|
];
|
||||||
} else if (formState.defaultValues?.type === 'ip_adapter') {
|
} else if (modelType === 'ip_adapter') {
|
||||||
return [{ value: 'invokeai', label: 'invokeai' }];
|
return [{ value: 'invokeai', label: 'invokeai' }];
|
||||||
} else {
|
} else {
|
||||||
return [
|
return [
|
||||||
@ -36,7 +38,7 @@ const ModelFormatSelect = <T extends AnyModelConfig>(props: UseControllerProps<T
|
|||||||
{ value: 'checkpoint', label: 'Checkpoint' },
|
{ value: 'checkpoint', label: 'Checkpoint' },
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
}, [formState.defaultValues?.type]);
|
}, [type, formState.defaultValues?.type]);
|
||||||
|
|
||||||
const value = useMemo(() => options.find((o) => o.value === field.value), [options, field.value]);
|
const value = useMemo(() => options.find((o) => o.value === field.value), [options, field.value]);
|
||||||
|
|
||||||
|
@ -72,11 +72,12 @@ type DeleteImportModelsResponse =
|
|||||||
type PruneModelImportsResponse =
|
type PruneModelImportsResponse =
|
||||||
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
|
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
|
||||||
|
|
||||||
type AddMainModelArg = {
|
type ImportAdvancedModelArg = {
|
||||||
body: MainModelConfig;
|
source: NonNullable<operations['import_model']['requestBody']['content']['application/json']['source']>;
|
||||||
|
config: NonNullable<operations['import_model']['requestBody']['content']['application/json']['config']>;
|
||||||
};
|
};
|
||||||
|
|
||||||
type AddMainModelResponse = paths['/api/v2/models/add']['post']['responses']['201']['content']['application/json'];
|
type ImportAdvancedModelResponse = paths['/api/v2/models/import']['post']['responses']['201']['content']['application/json'];
|
||||||
|
|
||||||
export type ScanFolderResponse =
|
export type ScanFolderResponse =
|
||||||
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
|
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
|
||||||
@ -198,12 +199,12 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
invalidatesTags: ['Model', 'ModelImports'],
|
invalidatesTags: ['Model', 'ModelImports'],
|
||||||
}),
|
}),
|
||||||
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
|
importAdvancedModel: build.mutation<ImportAdvancedModelResponse, ImportAdvancedModelArg>({
|
||||||
query: ({ body }) => {
|
query: ({ source, config}) => {
|
||||||
return {
|
return {
|
||||||
url: buildModelsUrl('add'),
|
url: buildModelsUrl('install'),
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
body: body,
|
body: { source, config },
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
invalidatesTags: ['Model', 'ModelImports'],
|
invalidatesTags: ['Model', 'ModelImports'],
|
||||||
@ -344,7 +345,7 @@ export const {
|
|||||||
useDeleteModelsMutation,
|
useDeleteModelsMutation,
|
||||||
useUpdateModelsMutation,
|
useUpdateModelsMutation,
|
||||||
useImportMainModelsMutation,
|
useImportMainModelsMutation,
|
||||||
useAddMainModelsMutation,
|
useImportAdvancedModelMutation,
|
||||||
useConvertMainModelsMutation,
|
useConvertMainModelsMutation,
|
||||||
useMergeMainModelsMutation,
|
useMergeMainModelsMutation,
|
||||||
useSyncModelsMutation,
|
useSyncModelsMutation,
|
||||||
|
Loading…
Reference in New Issue
Block a user