feat(ui): add starter models tab to MM

Lists all starter models with an install button if the model is not yet installed.
This commit is contained in:
psychedelicious 2024-03-19 16:02:19 +11:00
parent aa689e5384
commit bd3e8cbdfb
6 changed files with 187 additions and 1 deletions

View File

@ -688,6 +688,7 @@
"settings": "Settings",
"simpleModelPlaceholder": "URL or path to a local file or diffusers folder",
"source": "Source",
"starterModels": "Starter Models",
"syncModels": "Sync Models",
"triggerPhrases": "Trigger Phrases",
"loraTriggerPhrases": "LoRA Trigger Phrases",

View File

@ -0,0 +1,75 @@
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
import { useInstallModelMutation } from 'services/api/endpoints/models';
type Props = {
result: GetStarterModelsResponse[number];
};
export const StarterModelsResultItem = ({ result }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const allSources = useMemo(() => {
const _allSources = [result.source];
if (result.dependencies) {
_allSources.push(...result.dependencies);
}
return _allSources;
}, [result]);
const [installModel] = useInstallModelMutation();
const handleQuickAdd = useCallback(() => {
for (const source of allSources) {
installModel({ source })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
}
}, [allSources, installModel, dispatch, t]);
return (
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
<Flex fontSize="sm" flexDir="column">
<Flex gap={3}>
<Badge h="min-content">{result.type.replace('_', ' ')}</Badge>
<ModelBaseBadge base={result.base} />
<Text fontWeight="semibold">{result.name}</Text>
</Flex>
<Text variant="subtext">{result.description}</Text>
</Flex>
<Box>
{result.is_installed ? (
<Badge>{t('common.installed')}</Badge>
) : (
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={handleQuickAdd} size="sm" />
)}
</Box>
</Flex>
);
};

View File

@ -0,0 +1,16 @@
import { Flex } from '@invoke-ai/ui-library';
import { FetchingModelsLoader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/FetchingModelsLoader';
import { useGetStarterModelsQuery } from 'services/api/endpoints/models';
import { StarterModelsResults } from './StarterModelsResults';
export const StarterModelsForm = () => {
const { isLoading, data } = useGetStarterModelsQuery();
return (
<Flex flexDir="column" height="100%" gap={3}>
{isLoading && <FetchingModelsLoader loadingMessage="Loading Embeddings..." />}
{data && <StarterModelsResults results={data} />}
</Flex>
);
};

View File

@ -0,0 +1,72 @@
import { Flex, IconButton, Input, InputGroup, InputRightElement } from '@invoke-ai/ui-library';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import type { ChangeEventHandler } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
import { StarterModelsResultItem } from './StartModelsResultItem';
type StarterModelsResultsProps = {
results: NonNullable<GetStarterModelsResponse>;
};
export const StarterModelsResults = ({ results }: StarterModelsResultsProps) => {
const { t } = useTranslation();
const [searchTerm, setSearchTerm] = useState('');
const filteredResults = useMemo(() => {
return results.filter((result) => {
const name = result.name.toLowerCase();
const type = result.type.toLowerCase();
return name.includes(searchTerm.toLowerCase()) || type.includes(searchTerm.toLowerCase());
});
}, [results, searchTerm]);
const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
setSearchTerm(e.target.value.trim());
}, []);
const clearSearch = useCallback(() => {
setSearchTerm('');
}, []);
return (
<Flex flexDir="column" gap={3} height="100%">
<Flex justifyContent="flex-end" alignItems="center">
<InputGroup w={64} size="xs">
<Input
placeholder={t('modelManager.search')}
value={searchTerm}
data-testid="board-search-input"
onChange={handleSearch}
size="xs"
/>
{searchTerm && (
<InputRightElement h="full" pe={2}>
<IconButton
size="sm"
variant="link"
aria-label={t('boards.clearSearch')}
icon={<PiXBold />}
onClick={clearSearch}
flexShrink={0}
/>
</InputRightElement>
)}
</InputGroup>
</Flex>
<Flex height="100%" layerStyle="third" borderRadius="base" p={3}>
<ScrollableContent>
<Flex flexDir="column" gap={3}>
{filteredResults.map((result) => (
<StarterModelsResultItem key={result.source} result={result} />
))}
</Flex>
</ScrollableContent>
</Flex>
</Flex>
);
};

View File

@ -1,5 +1,8 @@
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { StarterModelsForm } from 'features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsForm';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useMainModels } from 'services/api/hooks/modelsByType';
import { HuggingFaceForm } from './AddModelPanel/HuggingFaceFolder/HuggingFaceForm';
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
@ -8,14 +11,23 @@ import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
export const InstallModels = () => {
const { t } = useTranslation();
const [mainModels, { data }] = useMainModels();
const defaultIndex = useMemo(() => {
if (data && mainModels.length) {
return 0;
}
return 3;
}, [data, mainModels.length]);
return (
<Flex layerStyle="first" borderRadius="base" w="full" h="full" flexDir="column" gap={4}>
<Heading fontSize="xl">{t('modelManager.addModel')}</Heading>
<Tabs variant="collapse" height="50%" display="flex" flexDir="column">
<Tabs variant="collapse" height="50%" display="flex" flexDir="column" defaultIndex={defaultIndex}>
<TabList>
<Tab>{t('modelManager.urlOrLocalPath')}</Tab>
<Tab>{t('modelManager.huggingFace')}</Tab>
<Tab>{t('modelManager.scanFolder')}</Tab>
<Tab>{t('modelManager.starterModels')}</Tab>
</TabList>
<TabPanels p={3} height="100%">
<TabPanel>
@ -27,6 +39,9 @@ export const InstallModels = () => {
<TabPanel height="100%">
<ScanModelsForm />
</TabPanel>
<TabPanel height="100%">
<StarterModelsForm />
</TabPanel>
</TabPanels>
</Tabs>
<Box layerStyle="second" borderRadius="base" h="50%">

View File

@ -27,6 +27,9 @@ type GetModelConfigsResponse = NonNullable<
paths['/api/v2/models/']['get']['responses']['200']['content']['application/json']
>;
export type GetStarterModelsResponse =
paths['/api/v2/models/starter_models']['get']['responses']['200']['content']['application/json'];
type DeleteModelArg = {
key: string;
};
@ -259,6 +262,9 @@ export const modelsApi = api.injectEndpoints({
});
},
}),
getStarterModels: build.query<GetStarterModelsResponse, void>({
query: () => buildModelsUrl('starter_models'),
}),
}),
});
@ -277,4 +283,5 @@ export const {
useListModelInstallsQuery,
useCancelModelInstallMutation,
usePruneCompletedModelInstallsMutation,
useGetStarterModelsQuery,
} = modelsApi;