mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): add starter models tab to MM
Lists all starter models with an install button if the model is not yet installed.
This commit is contained in:
parent
aa689e5384
commit
bd3e8cbdfb
@ -688,6 +688,7 @@
|
|||||||
"settings": "Settings",
|
"settings": "Settings",
|
||||||
"simpleModelPlaceholder": "URL or path to a local file or diffusers folder",
|
"simpleModelPlaceholder": "URL or path to a local file or diffusers folder",
|
||||||
"source": "Source",
|
"source": "Source",
|
||||||
|
"starterModels": "Starter Models",
|
||||||
"syncModels": "Sync Models",
|
"syncModels": "Sync Models",
|
||||||
"triggerPhrases": "Trigger Phrases",
|
"triggerPhrases": "Trigger Phrases",
|
||||||
"loraTriggerPhrases": "LoRA Trigger Phrases",
|
"loraTriggerPhrases": "LoRA Trigger Phrases",
|
||||||
|
@ -0,0 +1,75 @@
|
|||||||
|
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
|
import { useCallback, useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { PiPlusBold } from 'react-icons/pi';
|
||||||
|
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
|
||||||
|
import { useInstallModelMutation } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
result: GetStarterModelsResponse[number];
|
||||||
|
};
|
||||||
|
export const StarterModelsResultItem = ({ result }: Props) => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const allSources = useMemo(() => {
|
||||||
|
const _allSources = [result.source];
|
||||||
|
if (result.dependencies) {
|
||||||
|
_allSources.push(...result.dependencies);
|
||||||
|
}
|
||||||
|
return _allSources;
|
||||||
|
}, [result]);
|
||||||
|
const [installModel] = useInstallModelMutation();
|
||||||
|
|
||||||
|
const handleQuickAdd = useCallback(() => {
|
||||||
|
for (const source of allSources) {
|
||||||
|
installModel({ source })
|
||||||
|
.unwrap()
|
||||||
|
.then((_) => {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: t('toast.modelAddedSimple'),
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
if (error) {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: `${error.data.detail} `,
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [allSources, installModel, dispatch, t]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
|
||||||
|
<Flex fontSize="sm" flexDir="column">
|
||||||
|
<Flex gap={3}>
|
||||||
|
<Badge h="min-content">{result.type.replace('_', ' ')}</Badge>
|
||||||
|
<ModelBaseBadge base={result.base} />
|
||||||
|
<Text fontWeight="semibold">{result.name}</Text>
|
||||||
|
</Flex>
|
||||||
|
<Text variant="subtext">{result.description}</Text>
|
||||||
|
</Flex>
|
||||||
|
<Box>
|
||||||
|
{result.is_installed ? (
|
||||||
|
<Badge>{t('common.installed')}</Badge>
|
||||||
|
) : (
|
||||||
|
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={handleQuickAdd} size="sm" />
|
||||||
|
)}
|
||||||
|
</Box>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
@ -0,0 +1,16 @@
|
|||||||
|
import { Flex } from '@invoke-ai/ui-library';
|
||||||
|
import { FetchingModelsLoader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/FetchingModelsLoader';
|
||||||
|
import { useGetStarterModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
|
import { StarterModelsResults } from './StarterModelsResults';
|
||||||
|
|
||||||
|
export const StarterModelsForm = () => {
|
||||||
|
const { isLoading, data } = useGetStarterModelsQuery();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex flexDir="column" height="100%" gap={3}>
|
||||||
|
{isLoading && <FetchingModelsLoader loadingMessage="Loading Embeddings..." />}
|
||||||
|
{data && <StarterModelsResults results={data} />}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
@ -0,0 +1,72 @@
|
|||||||
|
import { Flex, IconButton, Input, InputGroup, InputRightElement } from '@invoke-ai/ui-library';
|
||||||
|
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||||
|
import type { ChangeEventHandler } from 'react';
|
||||||
|
import { useCallback, useMemo, useState } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { PiXBold } from 'react-icons/pi';
|
||||||
|
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
|
import { StarterModelsResultItem } from './StartModelsResultItem';
|
||||||
|
|
||||||
|
type StarterModelsResultsProps = {
|
||||||
|
results: NonNullable<GetStarterModelsResponse>;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const StarterModelsResults = ({ results }: StarterModelsResultsProps) => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const [searchTerm, setSearchTerm] = useState('');
|
||||||
|
|
||||||
|
const filteredResults = useMemo(() => {
|
||||||
|
return results.filter((result) => {
|
||||||
|
const name = result.name.toLowerCase();
|
||||||
|
const type = result.type.toLowerCase();
|
||||||
|
return name.includes(searchTerm.toLowerCase()) || type.includes(searchTerm.toLowerCase());
|
||||||
|
});
|
||||||
|
}, [results, searchTerm]);
|
||||||
|
|
||||||
|
const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
|
||||||
|
setSearchTerm(e.target.value.trim());
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const clearSearch = useCallback(() => {
|
||||||
|
setSearchTerm('');
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex flexDir="column" gap={3} height="100%">
|
||||||
|
<Flex justifyContent="flex-end" alignItems="center">
|
||||||
|
<InputGroup w={64} size="xs">
|
||||||
|
<Input
|
||||||
|
placeholder={t('modelManager.search')}
|
||||||
|
value={searchTerm}
|
||||||
|
data-testid="board-search-input"
|
||||||
|
onChange={handleSearch}
|
||||||
|
size="xs"
|
||||||
|
/>
|
||||||
|
|
||||||
|
{searchTerm && (
|
||||||
|
<InputRightElement h="full" pe={2}>
|
||||||
|
<IconButton
|
||||||
|
size="sm"
|
||||||
|
variant="link"
|
||||||
|
aria-label={t('boards.clearSearch')}
|
||||||
|
icon={<PiXBold />}
|
||||||
|
onClick={clearSearch}
|
||||||
|
flexShrink={0}
|
||||||
|
/>
|
||||||
|
</InputRightElement>
|
||||||
|
)}
|
||||||
|
</InputGroup>
|
||||||
|
</Flex>
|
||||||
|
<Flex height="100%" layerStyle="third" borderRadius="base" p={3}>
|
||||||
|
<ScrollableContent>
|
||||||
|
<Flex flexDir="column" gap={3}>
|
||||||
|
{filteredResults.map((result) => (
|
||||||
|
<StarterModelsResultItem key={result.source} result={result} />
|
||||||
|
))}
|
||||||
|
</Flex>
|
||||||
|
</ScrollableContent>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
@ -1,5 +1,8 @@
|
|||||||
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
||||||
|
import { StarterModelsForm } from 'features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsForm';
|
||||||
|
import { useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useMainModels } from 'services/api/hooks/modelsByType';
|
||||||
|
|
||||||
import { HuggingFaceForm } from './AddModelPanel/HuggingFaceFolder/HuggingFaceForm';
|
import { HuggingFaceForm } from './AddModelPanel/HuggingFaceFolder/HuggingFaceForm';
|
||||||
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
|
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
|
||||||
@ -8,14 +11,23 @@ import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
|
|||||||
|
|
||||||
export const InstallModels = () => {
|
export const InstallModels = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
const [mainModels, { data }] = useMainModels();
|
||||||
|
const defaultIndex = useMemo(() => {
|
||||||
|
if (data && mainModels.length) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
return 3;
|
||||||
|
}, [data, mainModels.length]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex layerStyle="first" borderRadius="base" w="full" h="full" flexDir="column" gap={4}>
|
<Flex layerStyle="first" borderRadius="base" w="full" h="full" flexDir="column" gap={4}>
|
||||||
<Heading fontSize="xl">{t('modelManager.addModel')}</Heading>
|
<Heading fontSize="xl">{t('modelManager.addModel')}</Heading>
|
||||||
<Tabs variant="collapse" height="50%" display="flex" flexDir="column">
|
<Tabs variant="collapse" height="50%" display="flex" flexDir="column" defaultIndex={defaultIndex}>
|
||||||
<TabList>
|
<TabList>
|
||||||
<Tab>{t('modelManager.urlOrLocalPath')}</Tab>
|
<Tab>{t('modelManager.urlOrLocalPath')}</Tab>
|
||||||
<Tab>{t('modelManager.huggingFace')}</Tab>
|
<Tab>{t('modelManager.huggingFace')}</Tab>
|
||||||
<Tab>{t('modelManager.scanFolder')}</Tab>
|
<Tab>{t('modelManager.scanFolder')}</Tab>
|
||||||
|
<Tab>{t('modelManager.starterModels')}</Tab>
|
||||||
</TabList>
|
</TabList>
|
||||||
<TabPanels p={3} height="100%">
|
<TabPanels p={3} height="100%">
|
||||||
<TabPanel>
|
<TabPanel>
|
||||||
@ -27,6 +39,9 @@ export const InstallModels = () => {
|
|||||||
<TabPanel height="100%">
|
<TabPanel height="100%">
|
||||||
<ScanModelsForm />
|
<ScanModelsForm />
|
||||||
</TabPanel>
|
</TabPanel>
|
||||||
|
<TabPanel height="100%">
|
||||||
|
<StarterModelsForm />
|
||||||
|
</TabPanel>
|
||||||
</TabPanels>
|
</TabPanels>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
<Box layerStyle="second" borderRadius="base" h="50%">
|
<Box layerStyle="second" borderRadius="base" h="50%">
|
||||||
|
@ -27,6 +27,9 @@ type GetModelConfigsResponse = NonNullable<
|
|||||||
paths['/api/v2/models/']['get']['responses']['200']['content']['application/json']
|
paths['/api/v2/models/']['get']['responses']['200']['content']['application/json']
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
export type GetStarterModelsResponse =
|
||||||
|
paths['/api/v2/models/starter_models']['get']['responses']['200']['content']['application/json'];
|
||||||
|
|
||||||
type DeleteModelArg = {
|
type DeleteModelArg = {
|
||||||
key: string;
|
key: string;
|
||||||
};
|
};
|
||||||
@ -259,6 +262,9 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
});
|
});
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
|
getStarterModels: build.query<GetStarterModelsResponse, void>({
|
||||||
|
query: () => buildModelsUrl('starter_models'),
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -277,4 +283,5 @@ export const {
|
|||||||
useListModelInstallsQuery,
|
useListModelInstallsQuery,
|
||||||
useCancelModelInstallMutation,
|
useCancelModelInstallMutation,
|
||||||
usePruneCompletedModelInstallsMutation,
|
usePruneCompletedModelInstallsMutation,
|
||||||
|
useGetStarterModelsQuery,
|
||||||
} = modelsApi;
|
} = modelsApi;
|
||||||
|
Loading…
Reference in New Issue
Block a user