added advanced import forms, not fully working yet

This commit is contained in:
Jennifer Player 2024-02-22 18:56:51 -05:00 committed by Brandon Rising
parent ccdb89534a
commit 2db252af31
14 changed files with 522 additions and 50 deletions

View File

@ -696,6 +696,7 @@
"addNewModel": "Add New Model", "addNewModel": "Add New Model",
"addSelected": "Add Selected", "addSelected": "Add Selected",
"advanced": "Advanced", "advanced": "Advanced",
"advancedImportInfo": "The advanced tab allows for manual configuration of core model settings. Only use this tab if you are confident that you know the correct model type and configuration for the selected model.",
"allModels": "All Models", "allModels": "All Models",
"alpha": "Alpha", "alpha": "Alpha",
"availableModels": "Available Models", "availableModels": "Available Models",

View File

@ -0,0 +1,56 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, Flex, FormLabel,Text } from '@invoke-ai/ui-library';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { z } from 'zod';
import { AdvancedImportCheckpoint } from './AdvancedImportCheckpoint';
import { AdvancedImportDiffusers } from './AdvancedImportDiffusers';
export const zManualAddMode = z.enum(['diffusers', 'checkpoint']);
export type ManualAddMode = z.infer<typeof zManualAddMode>;
export const isManualAddMode = (v: unknown): v is ManualAddMode => zManualAddMode.safeParse(v).success;
export const AdvancedImport = () => {
const [advancedAddMode, setAdvancedAddMode] = useState<ManualAddMode>('diffusers');
const { t } = useTranslation();
const handleChange: ComboboxOnChange = useCallback((v) => {
if (!isManualAddMode(v?.value)) {
return;
}
setAdvancedAddMode(v.value);
}, []);
const options: ComboboxOption[] = useMemo(
() => [
{ label: t('modelManager.diffusersModels'), value: 'diffusers' },
{ label: t('modelManager.checkpointOrSafetensors'), value: 'checkpoint' },
],
[t]
);
const value = useMemo(() => options.find((o) => o.value === advancedAddMode), [options, advancedAddMode]);
return (
<ScrollableContent>
<Flex flexDirection="column" gap={4} width="100%" pb={10}>
<Flex alignItems="flex-end" gap="4">
<Flex direction="column" gap="3" width="full">
<FormLabel>{t('modelManager.modelType')}</FormLabel>
<Combobox value={value} options={options} onChange={handleChange} />
</Flex>
<Text px="2" fontSize="xs" textAlign="center">
{t('modelManager.advancedImportInfo')}
</Text>
</Flex>
<Flex p={4} borderRadius={4} bg="base.850" height="100%">
{advancedAddMode === 'diffusers' && <AdvancedImportDiffusers />}
{advancedAddMode === 'checkpoint' && <AdvancedImportCheckpoint />}
</Flex>
</Flex>
</ScrollableContent>
);
};

View File

@ -0,0 +1,160 @@
import {
Button,
Checkbox,
Flex,
FormControl,
FormErrorMessage,
FormLabel,
Input,
Textarea,
} from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import type { CSSProperties } from 'react';
import { useCallback, useState } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
import type { CheckpointModelConfig, DiffusersModelConfig } from 'services/api/types';
import BaseModelSelect from './BaseModelSelect';
import CheckpointConfigsSelect from './CheckpointConfigsSelect';
import ModelVariantSelect from './ModelVariantSelect';
export const AdvancedImportCheckpoint = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const [addMainModel] = useAddMainModelsMutation();
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
const {
register,
handleSubmit,
control,
formState: { errors },
reset,
} = useForm<CheckpointModelConfig>({
defaultValues: {
name: '',
base: 'sd-1',
type: 'main',
path: '',
description: '',
format: 'checkpoint',
vae: '',
variant: 'normal',
config: '',
},
mode: 'onChange',
});
const onSubmit = useCallback<SubmitHandler<DiffusersModelConfig>>(
(values) => {
addMainModel({
body: values,
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelAdded', {
modelName: values.name,
}),
status: 'success',
})
)
);
reset();
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddFailed'),
status: 'error',
})
)
);
}
});
},
[addMainModel, dispatch, reset, t]
);
const handleChangeUseCustomConfig = useCallback(() => setUseCustomConfig((prev) => !prev), []);
return (
<form onSubmit={handleSubmit(onSubmit)} style={formStyles}>
<Flex flexDirection="column" gap={2}>
<FormControl isInvalid={Boolean(errors.name)}>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.name')}</FormLabel>
<Input
{...register('name', {
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
})}
/>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
<Flex>
<FormControl>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.description')}</FormLabel>
<Textarea size="sm" {...register('description')} />
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
</Flex>
<Flex gap="2">
<BaseModelSelect<CheckpointModelConfig> control={control} name="base" />
<ModelVariantSelect<CheckpointModelConfig> control={control} name="variant" />
</Flex>
<FormControl isInvalid={Boolean(errors.path)}>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
})}
/>
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
</Flex>
</FormControl>
<FormControl>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
<Input {...register('vae')} />
</Flex>
</FormControl>
<Flex flexDirection="column" width="100%" gap={2}>
{!useCustomConfig ? (
<CheckpointConfigsSelect control={control} name="config" />
) : (
<FormControl isRequired>
<FormLabel>{t('modelManager.config')}</FormLabel>
<Input {...register('config')} />
</FormControl>
)}
<FormControl>
<FormLabel>{t('modelManager.useCustomConfig')}</FormLabel>
<Checkbox isChecked={useCustomConfig} onChange={handleChangeUseCustomConfig} />
</FormControl>
<Button mt={2} type="submit">
{t('modelManager.addModel')}
</Button>
</Flex>
</Flex>
</form>
);
};
const formStyles: CSSProperties = {
width: '100%',
};

View File

@ -0,0 +1,132 @@
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import type { CSSProperties } from 'react';
import {useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
import type { DiffusersModelConfig } from 'services/api/types';
import BaseModelSelect from './BaseModelSelect';
import ModelVariantSelect from './ModelVariantSelect';
export const AdvancedImportDiffusers = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const [addMainModel] = useAddMainModelsMutation();
const {
register,
handleSubmit,
control,
formState: { errors },
reset,
} = useForm<DiffusersModelConfig>({
defaultValues: {
name: '',
base: 'sd-1',
type: 'main',
path: '',
description: '',
format: 'diffusers',
vae: '',
variant: 'normal',
},
mode: 'onChange',
});
const onSubmit = useCallback<SubmitHandler<DiffusersModelConfig>>(
(values) => {
addMainModel({
body: values,
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelAdded', {
modelName: values.name,
}),
status: 'success',
})
)
);
reset();
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddFailed'),
status: 'error',
})
)
);
}
});
},
[addMainModel, dispatch, reset, t]
);
return (
<form onSubmit={handleSubmit(onSubmit)} style={formStyles}>
<Flex flexDirection="column" gap={2}>
<FormControl isInvalid={Boolean(errors.name)}>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.name')}</FormLabel>
<Input
{...register('name', {
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
})}
/>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
<Flex>
<FormControl>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.description')}</FormLabel>
<Textarea size="sm" {...register('description')} />
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
</Flex>
<Flex gap="2">
<BaseModelSelect<DiffusersModelConfig> control={control} name="base" />
<ModelVariantSelect<DiffusersModelConfig> control={control} name="variant" />
</Flex>
<FormControl isInvalid={Boolean(errors.path)}>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
})}
/>
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
</Flex>
</FormControl>
<FormControl>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
<Input {...register('vae')} />
</Flex>
</FormControl>
<Button mt={2} type="submit">
{t('modelManager.addModel')}
</Button>
</Flex>
</form>
);
};
const formStyles: CSSProperties = {
width: '100%',
};

View File

@ -0,0 +1,38 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
const options: ComboboxOption[] = [
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
{ value: 'sdxl', label: MODEL_TYPE_MAP['sdxl'] },
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
];
const BaseModelSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { t } = useTranslation();
const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
field.onChange(v?.value);
},
[field]
);
return (
<FormControl>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
<Combobox value={value} options={options} onChange={onChange} />
</Flex>
</FormControl>
);
};
export default typedMemo(BaseModelSelect);

View File

@ -0,0 +1,32 @@
import type { ChakraProps, ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { memo, useCallback, useMemo } from 'react';
import { useController, type UseControllerProps } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useGetCheckpointConfigsQuery } from 'services/api/endpoints/models';
import type { CheckpointModelConfig } from 'services/api/types';
const sx: ChakraProps['sx'] = { w: 'full' };
const CheckpointConfigsSelect = (props: UseControllerProps<CheckpointModelConfig>) => {
const { data } = useGetCheckpointConfigsQuery();
const { t } = useTranslation();
const options = useMemo<ComboboxOption[]>(() => (data ? data.map((i) => ({ label: i, value: i })) : []), [data]);
const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value, options]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
field.onChange(v?.value);
},
[field]
);
return (
<FormControl>
<FormLabel>{t('modelManager.configFile')}</FormLabel>
<Combobox placeholder="Select A Config File" value={value} options={options} onChange={onChange} sx={sx} />
</FormControl>
);
};
export default memo(CheckpointConfigsSelect);

View File

@ -57,6 +57,10 @@ export const ImportQueueModel = (props: ModelListItemProps) => {
return `${bytes.toFixed(2)} ${units[i]}`; return `${bytes.toFixed(2)} ${units[i]}`;
}; };
const modelName = useMemo(() => {
return model.source.repo_id || model.source.url || model.source.path.substring(model.source.path.lastIndexOf('/') + 1);
}, [model.source]);
const progressValue = useMemo(() => { const progressValue = useMemo(() => {
return (model.bytes / model.total_bytes) * 100; return (model.bytes / model.total_bytes) * 100;
}, [model.bytes, model.total_bytes]); }, [model.bytes, model.total_bytes]);
@ -71,7 +75,7 @@ export const ImportQueueModel = (props: ModelListItemProps) => {
return ( return (
<Flex gap="2" w="full" alignItems="center" textAlign="center"> <Flex gap="2" w="full" alignItems="center" textAlign="center">
<Text w="20%" whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis"> <Text w="20%" whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis">
{model.source.repo_id} {modelName}
</Text> </Text>
<Progress <Progress
value={progressValue} value={progressValue}
@ -83,8 +87,7 @@ export const ImportQueueModel = (props: ModelListItemProps) => {
<Text minW="20%" fontSize="xs" w="20%"> <Text minW="20%" fontSize="xs" w="20%">
{progressString} {progressString}
</Text> </Text>
<Text w="15%">{model.status[0].toUpperCase() + <Text w="15%">{model.status[0].toUpperCase() + model.status.slice(1)}</Text>
model.status.slice(1)}</Text>
<Box w="10%"> <Box w="10%">
{(model.status === 'downloading' || model.status === 'waiting') && ( {(model.status === 'downloading' || model.status === 'waiting') && (
<IconButton <IconButton

View File

@ -0,0 +1,36 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { CheckpointModelConfig, DiffusersModelConfig } from 'services/api/types';
const options: ComboboxOption[] = [
{ value: 'normal', label: 'Normal' },
{ value: 'inpaint', label: 'Inpaint' },
{ value: 'depth', label: 'Depth' },
];
const ModelVariantSelect = <T extends CheckpointModelConfig | DiffusersModelConfig>(props: UseControllerProps<T>) => {
const { t } = useTranslation();
const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
field.onChange(v?.value);
},
[field]
);
return (
<FormControl>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.variant')}</FormLabel>
<Combobox value={value} options={options} onChange={onChange} />
</Flex>
</FormControl>
);
};
export default typedMemo(ModelVariantSelect);

View File

@ -1,5 +1,4 @@
import { ScanModelsForm } from './ScanModelsForm';
export const ScanModels = () => { export const ScanModels = () => {
return <ScanModelsForm />; return null;
}; };

View File

@ -1,9 +1,10 @@
import { Flex, FormControl, FormLabel, Input, Button, FormErrorMessage, Divider } from '@invoke-ai/ui-library'; import { Button,Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@invoke-ai/ui-library';
import { ChangeEventHandler, useCallback, useState } from 'react'; import type { ChangeEventHandler} from 'react';
import { useLazyScanModelsQuery } from '../../../../../services/api/endpoints/models'; import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useLazyScanModelsQuery } from 'services/api/endpoints/models';
import { ScanModelsResults } from './ScanModelsResults'; import { ScanModelsResults } from './ScanModelsResults';
import ScrollableContent from '../../../../../common/components/OverlayScrollbars/ScrollableContent';
export const ScanModelsForm = () => { export const ScanModelsForm = () => {
const [scanPath, setScanPath] = useState(''); const [scanPath, setScanPath] = useState('');

View File

@ -1,18 +1,18 @@
import { import {
Text, Divider,
Flex, Flex,
Heading, Heading,
IconButton, IconButton,
Input, Input,
InputGroup, InputGroup,
InputRightElement, InputRightElement,
Divider, Text,
Box,
} from '@invoke-ai/ui-library'; } from '@invoke-ai/ui-library';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { t } from 'i18next'; import { t } from 'i18next';
import { ChangeEventHandler, useCallback, useMemo, useState } from 'react'; import type { ChangeEventHandler} from 'react';
import { useCallback, useState } from 'react';
import { PiXBold } from 'react-icons/pi'; import { PiXBold } from 'react-icons/pi';
import ScrollableContent from '../../../../../common/components/OverlayScrollbars/ScrollableContent';
export const ScanModelsResults = ({ results }: { results: string[] }) => { export const ScanModelsResults = ({ results }: { results: string[] }) => {
const [searchTerm, setSearchTerm] = useState(''); const [searchTerm, setSearchTerm] = useState('');

View File

@ -1,10 +1,12 @@
import { Flex, FormControl, FormLabel, Input, Button } from '@invoke-ai/ui-library'; import { Button,Flex, FormControl, FormLabel, Input } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next'; import { t } from 'i18next';
import { useForm } from '@mantine/form'; import { useCallback } from 'react';
import { useAppDispatch } from '../../../../app/store/storeHooks'; import type { SubmitHandler} from 'react-hook-form';
import { useImportMainModelsMutation } from '../../../../services/api/endpoints/models'; import { useForm } from 'react-hook-form';
import { addToast } from '../../../system/store/systemSlice'; import { useImportMainModelsMutation } from 'services/api/endpoints/models';
import { makeToast } from '../../../system/util/makeToast';
type SimpleImportModelConfig = { type SimpleImportModelConfig = {
location: string; location: string;
@ -15,13 +17,19 @@ export const SimpleImport = () => {
const [importMainModel, { isLoading }] = useImportMainModelsMutation(); const [importMainModel, { isLoading }] = useImportMainModelsMutation();
const addModelForm = useForm({ const { register, handleSubmit, formState, reset } = useForm<SimpleImportModelConfig>({
initialValues: { defaultValues: {
location: '', location: '',
}, },
mode: 'onChange',
}); });
const handleAddModelSubmit = (values: SimpleImportModelConfig) => { const onSubmit = useCallback<SubmitHandler<SimpleImportModelConfig>>(
(values) => {
if (!values?.location) {
return;
}
importMainModel({ source: values.location, config: undefined }) importMainModel({ source: values.location, config: undefined })
.unwrap() .unwrap()
.then((_) => { .then((_) => {
@ -33,9 +41,10 @@ export const SimpleImport = () => {
}) })
) )
); );
addModelForm.reset(); reset();
}) })
.catch((error) => { .catch((error) => {
reset();
if (error) { if (error) {
dispatch( dispatch(
addToast( addToast(
@ -47,18 +56,20 @@ export const SimpleImport = () => {
); );
} }
}); });
}; },
[dispatch, reset, importMainModel]
);
return ( return (
<form onSubmit={addModelForm.onSubmit((v) => handleAddModelSubmit(v))}> <form onSubmit={handleSubmit(onSubmit)}>
<Flex gap={2} alignItems="flex-end" justifyContent="space-between"> <Flex gap={2} alignItems="flex-end" justifyContent="space-between">
<FormControl> <FormControl>
<Flex direction="column" w="full"> <Flex direction="column" w="full">
<FormLabel>{t('modelManager.modelLocation')}</FormLabel> <FormLabel>{t('modelManager.modelLocation')}</FormLabel>
<Input {...addModelForm.getInputProps('location')} /> <Input {...register('location')} />
</Flex> </Flex>
</FormControl> </FormControl>
<Button isDisabled={!addModelForm.values.location} isLoading={isLoading} type="submit"> <Button onClick={handleSubmit(onSubmit)} isDisabled={!formState.isDirty} isLoading={isLoading} type="submit">
{t('modelManager.addModel')} {t('modelManager.addModel')}
</Button> </Button>
</Flex> </Flex>

View File

@ -1,8 +1,9 @@
import { Box, Divider, Heading, Tab, Flex, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library'; import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { AdvancedImport } from './AddModelPanel/AdvancedImport';
import { ImportQueue } from './AddModelPanel/ImportQueue'; import { ImportQueue } from './AddModelPanel/ImportQueue';
import { SimpleImport } from './AddModelPanel/SimpleImport';
import { ScanModels } from './AddModelPanel/ScanModels/ScanModels'; import { ScanModels } from './AddModelPanel/ScanModels/ScanModels';
import { SimpleImport } from './AddModelPanel/SimpleImport';
export const ImportModels = () => { export const ImportModels = () => {
return ( return (
@ -21,7 +22,9 @@ export const ImportModels = () => {
<TabPanel> <TabPanel>
<SimpleImport /> <SimpleImport />
</TabPanel> </TabPanel>
<TabPanel>Advanced Import Placeholder</TabPanel> <TabPanel height="100%">
<AdvancedImport />
</TabPanel>
<TabPanel height="100%"> <TabPanel height="100%">
<ScanModels /> <ScanModels />
</TabPanel> </TabPanel>

View File

@ -206,7 +206,7 @@ export const modelsApi = api.injectEndpoints({
body: body, body: body,
}; };
}, },
invalidatesTags: ['Model'], invalidatesTags: ['Model', 'ModelImports'],
}), }),
deleteModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({ deleteModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
query: ({ key }) => { query: ({ key }) => {