add scan model endpoint, break add model into tabs

This commit is contained in:
Mary Hipp 2024-02-22 10:12:27 -05:00 committed by Brandon Rising
parent 0030606d99
commit a9b1f4b8c6
6 changed files with 134 additions and 97 deletions

View File

@ -1,4 +1,4 @@
import { Box, Button,Flex, Text } from '@invoke-ai/ui-library'; import { Box, Button, Flex, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
@ -7,7 +7,7 @@ import { useCallback, useMemo } from 'react';
import { RiSparklingFill } from 'react-icons/ri'; import { RiSparklingFill } from 'react-icons/ri';
import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models'; import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models';
import { ImportQueueModel } from './ImportQueueModel'; import { ImportQueueModel } from '../ImportQueueModel';
export const ImportQueue = () => { export const ImportQueue = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
@ -50,7 +50,7 @@ export const ImportQueue = () => {
}, [data]); }, [data]);
return ( return (
<> <Flex flexDir="column" p={3} h="full">
<Flex justifyContent="space-between"> <Flex justifyContent="space-between">
<Text>{t('modelManager.importQueue')}</Text> <Text>{t('modelManager.importQueue')}</Text>
<Button <Button
@ -67,6 +67,6 @@ export const ImportQueue = () => {
{data?.map((model) => <ImportQueueModel key={model.id} model={model} />)} {data?.map((model) => <ImportQueueModel key={model.id} model={model} />)}
</Flex> </Flex>
</Box> </Box>
</> </Flex>
); );
}; };

View File

@ -0,0 +1,67 @@
import { Flex, FormControl, FormLabel, Input, Button } from '@invoke-ai/ui-library';
import { t } from 'i18next';
import { useForm } from '@mantine/form';
import { useAppDispatch } from '../../../../app/store/storeHooks';
import { useImportMainModelsMutation } from '../../../../services/api/endpoints/models';
import { addToast } from '../../../system/store/systemSlice';
import { makeToast } from '../../../system/util/makeToast';
type SimpleImportModelConfig = {
location: string;
};
export const SimpleImport = () => {
const dispatch = useAppDispatch();
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
const addModelForm = useForm({
initialValues: {
location: '',
},
});
const handleAddModelSubmit = (values: SimpleImportModelConfig) => {
importMainModel({ source: values.location, config: undefined })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
addModelForm.reset();
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
};
return (
<form onSubmit={addModelForm.onSubmit((v) => handleAddModelSubmit(v))}>
<Flex gap={2} alignItems="flex-end" justifyContent="space-between">
<FormControl>
<Flex direction="column" w="full">
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
<Input {...addModelForm.getInputProps('location')} />
</Flex>
</FormControl>
<Button isDisabled={!addModelForm.values.location} isLoading={isLoading} type="submit">
{t('modelManager.addModel')}
</Button>
</Flex>
</form>
);
};

View File

@ -1,80 +1,30 @@
import { Box, Button, Divider,Flex, FormControl, FormLabel, Heading, Input } from '@invoke-ai/ui-library'; import { Box, Divider, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { useForm } from '@mantine/form';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import type { CSSProperties } from 'react';
import { useImportMainModelsMutation } from 'services/api/endpoints/models';
import { ImportQueue } from './ImportQueue'; import { ImportQueue } from './AddModelPanel/ImportQueue';
import { SimpleImport } from './AddModelPanel/SimpleImport';
const formStyles: CSSProperties = {
width: '100%',
};
type ExtendedImportModelConfig = {
location: string;
};
export const ImportModels = () => { export const ImportModels = () => {
const dispatch = useAppDispatch();
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
const addModelForm = useForm({
initialValues: {
location: '',
},
});
const handleAddModelSubmit = (values: ExtendedImportModelConfig) => {
importMainModel({ source: values.location, config: undefined })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
addModelForm.reset();
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
};
return ( return (
<Box layerStyle="first" p={3} borderRadius="base" w="full" h="full"> <Box layerStyle="first" p={3} borderRadius="base" w="full" h="full">
<Box w="full" p={4}> <Box w="full" p={4}>
<Heading fontSize="xl">Add Model</Heading> <Heading fontSize="xl">Add Model</Heading>
</Box> </Box>
<Box layerStyle="second" p={3} borderRadius="base" w="full" h="full"> <Box layerStyle="second" borderRadius="base" w="full" h="100vh">
<form onSubmit={addModelForm.onSubmit((v) => handleAddModelSubmit(v))} style={formStyles}> <Tabs variant="collapse">
<Flex gap={2} alignItems="flex-end" justifyContent="space-between"> <TabList>
<FormControl> <Tab>Simple</Tab>
<Flex direction="column" w="full"> <Tab>Advanced</Tab>
<FormLabel>{t('modelManager.modelLocation')}</FormLabel> <Tab>Scan</Tab>
<Input {...addModelForm.getInputProps('location')} /> </TabList>
</Flex> <TabPanels p={3}>
</FormControl> <TabPanel>
<Button isDisabled={!addModelForm.values.location} isLoading={isLoading} type="submit"> <SimpleImport />
{t('modelManager.addModel')} </TabPanel>
</Button> <TabPanel>Advanced Import Placeholder</TabPanel>
</Flex> <TabPanel>Scan Models Placeholder</TabPanel>
</form> </TabPanels>
</Tabs>
<Divider mt="5" mb="3" /> <Divider mt="5" mb="3" />
<ImportQueue /> <ImportQueue />
</Box> </Box>

View File

@ -25,10 +25,10 @@ type UpdateModelArg = {
}; };
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json']; type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
type GetModelMetadataResponse = type GetModelMetadataResponse =
paths['/api/v2/models/meta/i/{key}']['get']['responses']['200']['content']['application/json']; paths['/api/v2/models/i/{key}/meta']['get']['responses']['200']['content']['application/json'];
type GetModelResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>; type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
@ -78,16 +78,13 @@ type AddMainModelArg = {
type AddMainModelResponse = paths['/api/v2/models/add']['post']['responses']['201']['content']['application/json']; type AddMainModelResponse = paths['/api/v2/models/add']['post']['responses']['201']['content']['application/json'];
type SyncModelsResponse = paths['/api/v2/models/sync']['patch']['responses']['204']['content'] export type ScanFolderResponse =
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
export type SearchFolderResponse = type ScanFolderArg = operations['scan_for_models']['parameters']['query'];
paths['/api/v2/models/search']['get']['responses']['200']['content']['application/json'];
type CheckpointConfigsResponse = type CheckpointConfigsResponse =
paths['/api/v2/models/ckpt_confs']['get']['responses']['200']['content']['application/json']; paths['/api/v2/models/ckpt_confs']['get']['responses']['200']['content']['application/json'];
type SearchFolderArg = operations['search_for_models']['parameters']['query'];
export const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({ export const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
selectId: (entity) => entity.key, selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name), sortComparer: (a, b) => a.name.localeCompare(b.name),
@ -131,7 +128,6 @@ const buildProvidesTags =
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) => <TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
(result: EntityState<TEntity, string> | undefined) => { (result: EntityState<TEntity, string> | undefined) => {
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model']; const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
if (result) { if (result) {
tags.push( tags.push(
...result.ids.map((id) => ({ ...result.ids.map((id) => ({
@ -175,15 +171,9 @@ export const modelsApi = api.injectEndpoints({
providesTags: buildProvidesTags<MainModelConfig>('MainModel'), providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter), transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
}), }),
getModel: build.query<GetModelResponse, string>({
query: (key) => {
return buildModelsUrl(`i/${key}`);
},
providesTags: ['Model'],
}),
getModelMetadata: build.query<GetModelMetadataResponse, string>({ getModelMetadata: build.query<GetModelMetadataResponse, string>({
query: (key) => { query: (key) => {
return buildModelsUrl(`meta/i/${key}`); return buildModelsUrl(`i/${key}/meta`);
}, },
providesTags: ['Model'], providesTags: ['Model'],
}), }),
@ -247,7 +237,7 @@ export const modelsApi = api.injectEndpoints({
}, },
invalidatesTags: ['Model'], invalidatesTags: ['Model'],
}), }),
getModelConfig: build.query<AnyModelConfig, string>({ getModelConfig: build.query<GetModelConfigResponse, string>({
query: (key) => buildModelsUrl(`i/${key}`), query: (key) => buildModelsUrl(`i/${key}`),
providesTags: (result) => { providesTags: (result) => {
const tags: ApiTagDescription[] = ['Model']; const tags: ApiTagDescription[] = ['Model'];
@ -259,7 +249,7 @@ export const modelsApi = api.injectEndpoints({
return tags; return tags;
}, },
}), }),
syncModels: build.mutation<SyncModelsResponse, void>({ syncModels: build.mutation<void, void>({
query: () => { query: () => {
return { return {
url: buildModelsUrl('sync'), url: buildModelsUrl('sync'),
@ -298,16 +288,16 @@ export const modelsApi = api.injectEndpoints({
providesTags: buildProvidesTags<TextualInversionModelConfig>('TextualInversionModel'), providesTags: buildProvidesTags<TextualInversionModelConfig>('TextualInversionModel'),
transformResponse: buildTransformResponse<TextualInversionModelConfig>(textualInversionModelsAdapter), transformResponse: buildTransformResponse<TextualInversionModelConfig>(textualInversionModelsAdapter),
}), }),
getModelsInFolder: build.query<SearchFolderResponse, SearchFolderArg>({ scanModels: build.query<ScanFolderResponse, ScanFolderArg>({
query: (arg) => { query: (arg) => {
const folderQueryStr = queryString.stringify(arg, {}); const folderQueryStr = arg ? queryString.stringify(arg, {}) : '';
return { return {
url: buildModelsUrl(`search?${folderQueryStr}`), url: buildModelsUrl(`scan_folder?${folderQueryStr}`),
}; };
}, },
}), }),
getModelImports: build.query<ListImportModelsResponse, void>({ getModelImports: build.query<ListImportModelsResponse, void>({
query: (arg) => { query: () => {
return { return {
url: buildModelsUrl(`import`), url: buildModelsUrl(`import`),
}; };
@ -358,11 +348,10 @@ export const {
useConvertMainModelsMutation, useConvertMainModelsMutation,
useMergeMainModelsMutation, useMergeMainModelsMutation,
useSyncModelsMutation, useSyncModelsMutation,
useGetModelsInFolderQuery, useScanModelsQuery,
useGetCheckpointConfigsQuery, useGetCheckpointConfigsQuery,
useGetModelImportsQuery, useGetModelImportsQuery,
useGetModelMetadataQuery, useGetModelMetadataQuery,
useDeleteModelImportMutation, useDeleteModelImportMutation,
usePruneModelImportsMutation, usePruneModelImportsMutation,
useGetModelQuery,
} = modelsApi; } = modelsApi;

View File

@ -60,6 +60,10 @@ export type paths = {
*/ */
get: operations["list_tags"]; get: operations["list_tags"];
}; };
"/api/v2/models/scan_folder": {
/** Scan For Models */
get: operations["scan_for_models"];
};
"/api/v2/models/tags/search": { "/api/v2/models/tags/search": {
/** /**
* Search By Metadata Tags * Search By Metadata Tags
@ -11361,6 +11365,33 @@ export type operations = {
}; };
}; };
}; };
/** Scan For Models */
scan_for_models: {
parameters: {
query?: {
/** @description Directory path to search for models */
scan_path?: string;
};
};
responses: {
/** @description Directory scanned successfully */
200: {
content: {
"application/json": string[];
};
};
/** @description Invalid directory path */
400: {
content: never;
};
/** @description Validation Error */
422: {
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
/** /**
* Search By Metadata Tags * Search By Metadata Tags
* @description Get a list of models. * @description Get a list of models.