refactored and fixed issues with advanced import form

This commit is contained in:
Jennifer Player 2024-02-23 16:16:42 -05:00 committed by Brandon Rising
parent 170d9bca98
commit 7f56e84a8d
5 changed files with 231 additions and 338 deletions

View File

@ -1,56 +1,238 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, Flex, FormLabel,Text } from '@invoke-ai/ui-library';
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input, Text, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { useCallback, useMemo, useState } from 'react';
import BaseModelSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect';
import BooleanSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BooleanSelect';
import ModelFormatSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect';
import ModelTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect';
import ModelVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect';
import PredictionTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect';
import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback, useEffect } from 'react';
import type { SubmitHandler} from 'react-hook-form';
import {useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useImportAdvancedModelMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import { z } from 'zod';
import { AdvancedImportCheckpoint } from './AdvancedImportCheckpoint';
import { AdvancedImportDiffusers } from './AdvancedImportDiffusers';
export const zManualAddMode = z.enum(['diffusers', 'checkpoint']);
export type ManualAddMode = z.infer<typeof zManualAddMode>;
export const isManualAddMode = (v: unknown): v is ManualAddMode => zManualAddMode.safeParse(v).success;
export const AdvancedImport = () => {
const [advancedAddMode, setAdvancedAddMode] = useState<ManualAddMode>('diffusers');
const dispatch = useAppDispatch();
const [importAdvancedModel] = useImportAdvancedModelMutation();
const { t } = useTranslation();
const handleChange: ComboboxOnChange = useCallback((v) => {
if (!isManualAddMode(v?.value)) {
return;
}
setAdvancedAddMode(v.value);
}, []);
const options: ComboboxOption[] = useMemo(
() => [
{ label: t('modelManager.diffusersModels'), value: 'diffusers' },
{ label: t('modelManager.checkpointOrSafetensors'), value: 'checkpoint' },
],
[t]
const {
register,
handleSubmit,
control,
formState: { errors },
setValue,
resetField,
reset,
watch,
} = useForm<AnyModelConfig>({
defaultValues: {
name: '',
base: 'sd-1',
type: 'main',
path: '',
description: '',
format: 'diffusers',
vae: '',
variant: 'normal',
},
mode: 'onChange',
});
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
(values) => {
const cleanValues = Object.fromEntries(
Object.entries(values).filter(([value]) => value !== null && value !== undefined)
);
importAdvancedModel({
source: {
path: cleanValues.path as string,
type: 'local',
},
config: cleanValues,
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelAdded', {
modelName: values.name,
}),
status: 'success',
})
)
);
reset();
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddFailed'),
status: 'error',
})
)
);
}
});
},
[dispatch, reset, t, importAdvancedModel]
);
const value = useMemo(() => options.find((o) => o.value === advancedAddMode), [options, advancedAddMode]);
const watchedModelType = watch('type');
const watchedModelFormat = watch('format');
useEffect(() => {
if (watchedModelType === 'main') {
setValue('format', 'diffusers');
setValue('repo_variant', '');
setValue('variant', 'normal');
}
if (watchedModelType === 'lora') {
setValue('format', 'lycoris');
} else if (watchedModelType === 'embedding') {
setValue('format', 'embedding_file');
} else if (watchedModelType === 'ip_adapter') {
setValue('format', 'invokeai');
} else {
setValue('format', 'diffusers');
}
resetField('upcast_attention');
resetField('ztsnr_training');
resetField('vae');
resetField('config');
resetField('prediction_type');
resetField('image_encoder_model_id');
}, [watchedModelType, resetField, setValue]);
return (
<ScrollableContent>
<Flex flexDirection="column" gap={4} width="100%" pb={10}>
<Flex alignItems="flex-end" gap="4">
<Flex direction="column" gap="3" width="full">
<FormLabel>{t('modelManager.modelType')}</FormLabel>
<Combobox value={value} options={options} onChange={handleChange} />
<form onSubmit={handleSubmit(onSubmit)}>
<Flex flexDirection="column" gap={4} width="100%" pb={10}>
<Flex alignItems="flex-end" gap="4">
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Model Type</FormLabel>
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
</FormControl>
<Text px="2" fontSize="xs" textAlign="center">
{t('modelManager.advancedImportInfo')}
</Text>
</Flex>
<Text px="2" fontSize="xs" textAlign="center">
{t('modelManager.advancedImportInfo')}
</Text>
</Flex>
<Flex p={4} borderRadius={4} bg="base.850" height="100%">
{advancedAddMode === 'diffusers' && <AdvancedImportDiffusers />}
{advancedAddMode === 'checkpoint' && <AdvancedImportCheckpoint />}
<Flex p={4} borderRadius={4} bg="base.850" height="100%" direction="column" gap="3">
<FormControl isInvalid={Boolean(errors.name)}>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.name')}</FormLabel>
<Input
{...register('name', {
validate: (value) => value.trim().length >= 3 || 'Must be at least 3 characters',
})}
/>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
<Flex>
<FormControl>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.description')}</FormLabel>
<Textarea size="sm" {...register('description')} />
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Base Model</FormLabel>
<BaseModelSelect<AnyModelConfig> control={control} name="base" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Format</FormLabel>
<ModelFormatSelect<AnyModelConfig> control={control} name="format" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
<FormLabel>Path</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
})}
/>
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
</FormControl>
</Flex>
{watchedModelType === 'main' && (
<>
<Flex gap={4}>
{watchedModelFormat === 'diffusers' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Repo Variant</FormLabel>
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
</FormControl>
)}
{watchedModelFormat === 'checkpoint' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Config Path</FormLabel>
<Input {...register('config')} />
</FormControl>
)}
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Variant</FormLabel>
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Prediction Type</FormLabel>
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Upcast Attention</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>ZTSNR Training</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>VAE Path</FormLabel>
<Input {...register('vae')} />
</FormControl>
</Flex>
</>
)}
{watchedModelType === 'ip_adapter' && (
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>Image Encoder Model ID</FormLabel>
<Input {...register('image_encoder_model_id')} />
</FormControl>
</Flex>
)}
<Button mt={2} type="submit">
{t('modelManager.addModel')}
</Button>
</Flex>
</Flex>
</Flex>
</form>
</ScrollableContent>
);
};

View File

@ -1,160 +0,0 @@
import {
Button,
Checkbox,
Flex,
FormControl,
FormErrorMessage,
FormLabel,
Input,
Textarea,
} from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import type { CSSProperties } from 'react';
import { useCallback, useState } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
import type { CheckpointModelConfig, DiffusersModelConfig } from 'services/api/types';
import BaseModelSelect from './BaseModelSelect';
import CheckpointConfigsSelect from './CheckpointConfigsSelect';
import ModelVariantSelect from './ModelVariantSelect';
export const AdvancedImportCheckpoint = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const [addMainModel] = useAddMainModelsMutation();
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
const {
register,
handleSubmit,
control,
formState: { errors },
reset,
} = useForm<CheckpointModelConfig>({
defaultValues: {
name: '',
base: 'sd-1',
type: 'main',
path: '',
description: '',
format: 'checkpoint',
vae: '',
variant: 'normal',
config: '',
},
mode: 'onChange',
});
const onSubmit = useCallback<SubmitHandler<DiffusersModelConfig>>(
(values) => {
addMainModel({
body: values,
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelAdded', {
modelName: values.name,
}),
status: 'success',
})
)
);
reset();
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddFailed'),
status: 'error',
})
)
);
}
});
},
[addMainModel, dispatch, reset, t]
);
const handleChangeUseCustomConfig = useCallback(() => setUseCustomConfig((prev) => !prev), []);
return (
<form onSubmit={handleSubmit(onSubmit)} style={formStyles}>
<Flex flexDirection="column" gap={2}>
<FormControl isInvalid={Boolean(errors.name)}>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.name')}</FormLabel>
<Input
{...register('name', {
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
})}
/>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
<Flex>
<FormControl>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.description')}</FormLabel>
<Textarea size="sm" {...register('description')} />
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
</Flex>
<Flex gap="2">
<BaseModelSelect<CheckpointModelConfig> control={control} name="base" />
<ModelVariantSelect<CheckpointModelConfig> control={control} name="variant" />
</Flex>
<FormControl isInvalid={Boolean(errors.path)}>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
})}
/>
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
</Flex>
</FormControl>
<FormControl>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
<Input {...register('vae')} />
</Flex>
</FormControl>
<Flex flexDirection="column" width="100%" gap={2}>
{!useCustomConfig ? (
<CheckpointConfigsSelect control={control} name="config" />
) : (
<FormControl isRequired>
<FormLabel>{t('modelManager.config')}</FormLabel>
<Input {...register('config')} />
</FormControl>
)}
<FormControl>
<FormLabel>{t('modelManager.useCustomConfig')}</FormLabel>
<Checkbox isChecked={useCustomConfig} onChange={handleChangeUseCustomConfig} />
</FormControl>
<Button mt={2} type="submit">
{t('modelManager.addModel')}
</Button>
</Flex>
</Flex>
</form>
);
};
const formStyles: CSSProperties = {
width: '100%',
};

View File

@ -1,132 +0,0 @@
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import type { CSSProperties } from 'react';
import {useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
import type { DiffusersModelConfig } from 'services/api/types';
import BaseModelSelect from './BaseModelSelect';
import ModelVariantSelect from './ModelVariantSelect';
export const AdvancedImportDiffusers = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const [addMainModel] = useAddMainModelsMutation();
const {
register,
handleSubmit,
control,
formState: { errors },
reset,
} = useForm<DiffusersModelConfig>({
defaultValues: {
name: '',
base: 'sd-1',
type: 'main',
path: '',
description: '',
format: 'diffusers',
vae: '',
variant: 'normal',
},
mode: 'onChange',
});
const onSubmit = useCallback<SubmitHandler<DiffusersModelConfig>>(
(values) => {
addMainModel({
body: values,
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelAdded', {
modelName: values.name,
}),
status: 'success',
})
)
);
reset();
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddFailed'),
status: 'error',
})
)
);
}
});
},
[addMainModel, dispatch, reset, t]
);
return (
<form onSubmit={handleSubmit(onSubmit)} style={formStyles}>
<Flex flexDirection="column" gap={2}>
<FormControl isInvalid={Boolean(errors.name)}>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.name')}</FormLabel>
<Input
{...register('name', {
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
})}
/>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
<Flex>
<FormControl>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.description')}</FormLabel>
<Textarea size="sm" {...register('description')} />
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
</Flex>
<Flex gap="2">
<BaseModelSelect<DiffusersModelConfig> control={control} name="base" />
<ModelVariantSelect<DiffusersModelConfig> control={control} name="variant" />
</Flex>
<FormControl isInvalid={Boolean(errors.path)}>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
})}
/>
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
</Flex>
</FormControl>
<FormControl>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
<Input {...register('vae')} />
</Flex>
</FormControl>
<Button mt={2} type="submit">
{t('modelManager.addModel')}
</Button>
</Flex>
</form>
);
};
const formStyles: CSSProperties = {
width: '100%',
};

View File

@ -4,11 +4,12 @@ import { typedMemo } from 'common/util/typedMemo';
import { LORA_MODEL_FORMAT_MAP } from 'features/parameters/types/constants';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useController, useWatch } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
const ModelFormatSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field, formState } = useController(props);
const type = useWatch({ control: props.control, name: 'type' });
const onChange = useCallback<ComboboxOnChange>(
(v) => {
@ -18,17 +19,18 @@ const ModelFormatSelect = <T extends AnyModelConfig>(props: UseControllerProps<T
);
const options: ComboboxOption[] = useMemo(() => {
if (formState.defaultValues?.type === 'lora') {
const modelType = type || formState.defaultValues?.type;
if (modelType === 'lora') {
return Object.keys(LORA_MODEL_FORMAT_MAP).map((format) => ({
value: format,
label: LORA_MODEL_FORMAT_MAP[format],
})) as ComboboxOption[];
} else if (formState.defaultValues?.type === 'embedding') {
} else if (modelType === 'embedding') {
return [
{ value: 'embedding_file', label: 'Embedding File' },
{ value: 'embedding_folder', label: 'Embedding Folder' },
];
} else if (formState.defaultValues?.type === 'ip_adapter') {
} else if (modelType === 'ip_adapter') {
return [{ value: 'invokeai', label: 'invokeai' }];
} else {
return [
@ -36,7 +38,7 @@ const ModelFormatSelect = <T extends AnyModelConfig>(props: UseControllerProps<T
{ value: 'checkpoint', label: 'Checkpoint' },
];
}
}, [formState.defaultValues?.type]);
}, [type, formState.defaultValues?.type]);
const value = useMemo(() => options.find((o) => o.value === field.value), [options, field.value]);

View File

@ -72,11 +72,12 @@ type DeleteImportModelsResponse =
type PruneModelImportsResponse =
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
type AddMainModelArg = {
body: MainModelConfig;
type ImportAdvancedModelArg = {
source: NonNullable<operations['import_model']['requestBody']['content']['application/json']['source']>;
config: NonNullable<operations['import_model']['requestBody']['content']['application/json']['config']>;
};
type AddMainModelResponse = paths['/api/v2/models/add']['post']['responses']['201']['content']['application/json'];
type ImportAdvancedModelResponse = paths['/api/v2/models/import']['post']['responses']['201']['content']['application/json'];
export type ScanFolderResponse =
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
@ -198,12 +199,12 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model', 'ModelImports'],
}),
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
query: ({ body }) => {
importAdvancedModel: build.mutation<ImportAdvancedModelResponse, ImportAdvancedModelArg>({
query: ({ source, config}) => {
return {
url: buildModelsUrl('add'),
url: buildModelsUrl('install'),
method: 'POST',
body: body,
body: { source, config },
};
},
invalidatesTags: ['Model', 'ModelImports'],
@ -344,7 +345,7 @@ export const {
useDeleteModelsMutation,
useUpdateModelsMutation,
useImportMainModelsMutation,
useAddMainModelsMutation,
useImportAdvancedModelMutation,
useConvertMainModelsMutation,
useMergeMainModelsMutation,
useSyncModelsMutation,