Only install starter models if not already installed

This commit is contained in:
Brandon Rising 2024-08-26 12:54:28 -04:00 committed by Brandon
parent bbf934d980
commit cf633e4ef2
3 changed files with 28 additions and 7 deletions

View File

@ -5,11 +5,13 @@ import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi'; import { PiPlusBold } from 'react-icons/pi';
import type { GetStarterModelsResponse } from 'services/api/endpoints/models'; import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
type Props = { type Props = {
result: GetStarterModelsResponse[number]; result: GetStarterModelsResponse[number];
modelList: AnyModelConfig[];
}; };
export const StarterModelsResultItem = memo(({ result }: Props) => { export const StarterModelsResultItem = memo(({ result, modelList }: Props) => {
const { t } = useTranslation(); const { t } = useTranslation();
const allSources = useMemo(() => { const allSources = useMemo(() => {
const _allSources = [ const _allSources = [
@ -38,9 +40,12 @@ export const StarterModelsResultItem = memo(({ result }: Props) => {
const onClick = useCallback(() => { const onClick = useCallback(() => {
for (const { config, source } of allSources) { for (const { config, source } of allSources) {
if (modelList.some((mc) => config.base === mc.base && config.name === mc.name && config.type === mc.type)) {
continue;
}
installModel({ config, source }); installModel({ config, source });
} }
}, [allSources, installModel]); }, [modelList, allSources, installModel]);
return ( return (
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}> <Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>

View File

@ -1,17 +1,31 @@
import { Flex } from '@invoke-ai/ui-library'; import { Flex } from '@invoke-ai/ui-library';
import { EMPTY_ARRAY } from 'app/store/constants';
import { FetchingModelsLoader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/FetchingModelsLoader'; import { FetchingModelsLoader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/FetchingModelsLoader';
import { memo } from 'react'; import { memo, useMemo } from 'react';
import { useGetStarterModelsQuery } from 'services/api/endpoints/models'; import {
modelConfigsAdapterSelectors,
useGetModelConfigsQuery,
useGetStarterModelsQuery,
} from 'services/api/endpoints/models';
import { StarterModelsResults } from './StarterModelsResults'; import { StarterModelsResults } from './StarterModelsResults';
export const StarterModelsForm = memo(() => { export const StarterModelsForm = memo(() => {
const { isLoading, data } = useGetStarterModelsQuery(); const { isLoading, data } = useGetStarterModelsQuery();
const { data: modelListRes } = useGetModelConfigsQuery();
const modelList = useMemo(() => {
if (!modelListRes) {
return EMPTY_ARRAY;
}
return modelConfigsAdapterSelectors.selectAll(modelListRes);
}, [modelListRes]);
return ( return (
<Flex flexDir="column" height="100%" gap={3}> <Flex flexDir="column" height="100%" gap={3}>
{isLoading && <FetchingModelsLoader loadingMessage="Loading Embeddings..." />} {isLoading && <FetchingModelsLoader loadingMessage="Loading Embeddings..." />}
{data && <StarterModelsResults results={data} />} {data && <StarterModelsResults results={data} modelList={modelList} />}
</Flex> </Flex>
); );
}); });

View File

@ -5,14 +5,16 @@ import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi'; import { PiXBold } from 'react-icons/pi';
import type { GetStarterModelsResponse } from 'services/api/endpoints/models'; import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import { StarterModelsResultItem } from './StartModelsResultItem'; import { StarterModelsResultItem } from './StartModelsResultItem';
type StarterModelsResultsProps = { type StarterModelsResultsProps = {
results: NonNullable<GetStarterModelsResponse>; results: NonNullable<GetStarterModelsResponse>;
modelList: AnyModelConfig[];
}; };
export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps) => { export const StarterModelsResults = memo(({ results, modelList }: StarterModelsResultsProps) => {
const { t } = useTranslation(); const { t } = useTranslation();
const [searchTerm, setSearchTerm] = useState(''); const [searchTerm, setSearchTerm] = useState('');
@ -72,7 +74,7 @@ export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps
<ScrollableContent> <ScrollableContent>
<Flex flexDir="column" gap={3}> <Flex flexDir="column" gap={3}>
{filteredResults.map((result) => ( {filteredResults.map((result) => (
<StarterModelsResultItem key={result.source} result={result} /> <StarterModelsResultItem key={result.source} result={result} modelList={modelList} />
))} ))}
</Flex> </Flex>
</ScrollableContent> </ScrollableContent>