tidy(ui): split install model into helper hook

This was duplicated like 7 times or so
This commit is contained in:
psychedelicious 2024-05-21 19:55:59 +10:00
parent a66b3497e0
commit f2b9684de8
7 changed files with 90 additions and 167 deletions

View File

@ -0,0 +1,48 @@
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useInstallModelMutation } from 'services/api/endpoints/models';
type InstallModelArg = {
source: string;
inplace?: boolean;
onSuccess?: () => void;
onError?: (error: unknown) => void;
};
export const useInstallModel = () => {
const { t } = useTranslation();
const [_installModel, request] = useInstallModelMutation();
const installModel = useCallback(
({ source, inplace, onSuccess, onError }: InstallModelArg) => {
_installModel({ source, inplace })
.unwrap()
.then((_) => {
if (onSuccess) {
onSuccess();
}
toast({
id: 'MODEL_INSTALL_QUEUED',
title: t('toast.modelAddedSimple'),
status: 'success',
});
})
.catch((error) => {
if (onError) {
onError(error);
}
if (error) {
toast({
id: 'MODEL_INSTALL_QUEUE_FAILED',
title: `${error.data.detail} `,
status: 'error',
});
}
});
},
[_installModel, t]
);
return [installModel, request] as const;
};

View File

@ -1,9 +1,9 @@
import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library'; import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
import { toast, ToastID } from 'features/toast/toast'; import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import type { ChangeEventHandler } from 'react'; import type { ChangeEventHandler } from 'react';
import { useCallback, useState } from 'react'; import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useInstallModelMutation, useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models'; import { useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
import { HuggingFaceResults } from './HuggingFaceResults'; import { HuggingFaceResults } from './HuggingFaceResults';
@ -14,41 +14,17 @@ export const HuggingFaceForm = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const [_getHuggingFaceModels, { isLoading, data }] = useLazyGetHuggingFaceModelsQuery(); const [_getHuggingFaceModels, { isLoading, data }] = useLazyGetHuggingFaceModelsQuery();
const [installModel] = useInstallModelMutation(); const [installModel] = useInstallModel();
const handleInstallModel = useCallback(
(source: string) => {
installModel({ source })
.unwrap()
.then((_) => {
toast({
id: ToastID.MODEL_INSTALL_QUEUED,
title: t('toast.modelAddedSimple'),
status: 'success',
});
})
.catch((error) => {
if (error) {
toast({
id: ToastID.MODEL_INSTALL_QUEUE_FAILED,
title: `${error.data.detail} `,
status: 'error',
});
}
});
},
[installModel, t]
);
const getModels = useCallback(async () => { const getModels = useCallback(async () => {
_getHuggingFaceModels(huggingFaceRepo) _getHuggingFaceModels(huggingFaceRepo)
.unwrap() .unwrap()
.then((response) => { .then((response) => {
if (response.is_diffusers) { if (response.is_diffusers) {
handleInstallModel(huggingFaceRepo); installModel({ source: huggingFaceRepo });
setDisplayResults(false); setDisplayResults(false);
} else if (response.urls?.length === 1 && response.urls[0]) { } else if (response.urls?.length === 1 && response.urls[0]) {
handleInstallModel(response.urls[0]); installModel({ source: response.urls[0] });
setDisplayResults(false); setDisplayResults(false);
} else { } else {
setDisplayResults(true); setDisplayResults(true);
@ -57,7 +33,7 @@ export const HuggingFaceForm = () => {
.catch((error) => { .catch((error) => {
setErrorMessage(error.data.detail || ''); setErrorMessage(error.data.detail || '');
}); });
}, [_getHuggingFaceModels, handleInstallModel, huggingFaceRepo]); }, [_getHuggingFaceModels, installModel, huggingFaceRepo]);
const handleSetHuggingFaceRepo: ChangeEventHandler<HTMLInputElement> = useCallback((e) => { const handleSetHuggingFaceRepo: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
setHuggingFaceRepo(e.target.value); setHuggingFaceRepo(e.target.value);

View File

@ -1,9 +1,8 @@
import { Flex, IconButton, Text } from '@invoke-ai/ui-library'; import { Flex, IconButton, Text } from '@invoke-ai/ui-library';
import { toast, ToastID } from 'features/toast/toast'; import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi'; import { PiPlusBold } from 'react-icons/pi';
import { useInstallModelMutation } from 'services/api/endpoints/models';
type Props = { type Props = {
result: string; result: string;
@ -11,28 +10,11 @@ type Props = {
export const HuggingFaceResultItem = ({ result }: Props) => { export const HuggingFaceResultItem = ({ result }: Props) => {
const { t } = useTranslation(); const { t } = useTranslation();
const [installModel] = useInstallModelMutation(); const [installModel] = useInstallModel();
const handleInstall = useCallback(() => { const onClick = useCallback(() => {
installModel({ source: result }) installModel({ source: result });
.unwrap() }, [installModel, result]);
.then((_) => {
toast({
id: ToastID.MODEL_INSTALL_QUEUED,
title: t('toast.modelAddedSimple'),
status: 'success',
});
})
.catch((error) => {
if (error) {
toast({
id: ToastID.MODEL_INSTALL_QUEUE_FAILED,
title: `${error.data.detail} `,
status: 'error',
});
}
});
}, [installModel, result, t]);
return ( return (
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}> <Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
@ -42,7 +24,7 @@ export const HuggingFaceResultItem = ({ result }: Props) => {
{result} {result}
</Text> </Text>
</Flex> </Flex>
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={handleInstall} size="sm" /> <IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={onClick} size="sm" />
</Flex> </Flex>
); );
}; };

View File

@ -9,12 +9,11 @@ import {
InputRightElement, InputRightElement,
} from '@invoke-ai/ui-library'; } from '@invoke-ai/ui-library';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { toast, ToastID } from 'features/toast/toast'; import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import type { ChangeEventHandler } from 'react'; import type { ChangeEventHandler } from 'react';
import { useCallback, useMemo, useState } from 'react'; import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi'; import { PiXBold } from 'react-icons/pi';
import { useInstallModelMutation } from 'services/api/endpoints/models';
import { HuggingFaceResultItem } from './HuggingFaceResultItem'; import { HuggingFaceResultItem } from './HuggingFaceResultItem';
@ -26,7 +25,7 @@ export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => {
const { t } = useTranslation(); const { t } = useTranslation();
const [searchTerm, setSearchTerm] = useState(''); const [searchTerm, setSearchTerm] = useState('');
const [installModel] = useInstallModelMutation(); const [installModel] = useInstallModel();
const filteredResults = useMemo(() => { const filteredResults = useMemo(() => {
return results.filter((result) => { return results.filter((result) => {
@ -43,28 +42,11 @@ export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => {
setSearchTerm(''); setSearchTerm('');
}, []); }, []);
const handleAddAll = useCallback(() => { const onClickAddAll = useCallback(() => {
for (const result of filteredResults) { for (const result of filteredResults) {
installModel({ source: result }) installModel({ source: result });
.unwrap()
.then((_) => {
toast({
id: ToastID.MODEL_INSTALL_QUEUED,
title: t('toast.modelAddedSimple'),
status: 'success',
});
})
.catch((error) => {
if (error) {
toast({
id: ToastID.MODEL_INSTALL_QUEUE_FAILED,
title: `${error.data.detail} `,
status: 'error',
});
}
});
} }
}, [filteredResults, installModel, t]); }, [filteredResults, installModel]);
return ( return (
<> <>
@ -73,7 +55,7 @@ export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => {
<Flex justifyContent="space-between" alignItems="center"> <Flex justifyContent="space-between" alignItems="center">
<Heading size="sm">{t('modelManager.availableModels')}</Heading> <Heading size="sm">{t('modelManager.availableModels')}</Heading>
<Flex alignItems="center" gap={3}> <Flex alignItems="center" gap={3}>
<Button size="sm" onClick={handleAddAll} isDisabled={results.length === 0} flexShrink={0}> <Button size="sm" onClick={onClickAddAll} isDisabled={results.length === 0} flexShrink={0}>
{t('modelManager.installAll')} {t('modelManager.installAll')}
</Button> </Button>
<InputGroup w={64} size="xs"> <InputGroup w={64} size="xs">

View File

@ -1,10 +1,9 @@
import { Button, Checkbox, Flex, FormControl, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library'; import { Button, Checkbox, Flex, FormControl, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
import { toast, ToastID } from 'features/toast/toast'; import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import { t } from 'i18next'; import { t } from 'i18next';
import { useCallback } from 'react'; import { useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form'; import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form'; import { useForm } from 'react-hook-form';
import { useInstallModelMutation } from 'services/api/endpoints/models';
type SimpleImportModelConfig = { type SimpleImportModelConfig = {
location: string; location: string;
@ -12,7 +11,7 @@ type SimpleImportModelConfig = {
}; };
export const InstallModelForm = () => { export const InstallModelForm = () => {
const [installModel, { isLoading }] = useInstallModelMutation(); const [installModel, { isLoading }] = useInstallModel();
const { register, handleSubmit, formState, reset } = useForm<SimpleImportModelConfig>({ const { register, handleSubmit, formState, reset } = useForm<SimpleImportModelConfig>({
defaultValues: { defaultValues: {
@ -22,34 +21,22 @@ export const InstallModelForm = () => {
mode: 'onChange', mode: 'onChange',
}); });
const resetForm = useCallback(() => reset(undefined, { keepValues: true }), [reset]);
const onSubmit = useCallback<SubmitHandler<SimpleImportModelConfig>>( const onSubmit = useCallback<SubmitHandler<SimpleImportModelConfig>>(
(values) => { (values) => {
if (!values?.location) { if (!values?.location) {
return; return;
} }
installModel({ source: values.location, inplace: values.inplace }) installModel({
.unwrap() source: values.location,
.then((_) => { inplace: values.inplace,
toast({ onSuccess: resetForm,
id: ToastID.MODEL_INSTALL_QUEUED, onError: resetForm,
title: t('toast.modelAddedSimple'), });
status: 'success',
});
reset(undefined, { keepValues: true });
})
.catch((error) => {
reset(undefined, { keepValues: true });
if (error) {
toast({
id: ToastID.MODEL_INSTALL_QUEUE_FAILED,
title: `${error.data.detail} `,
status: 'error',
});
}
});
}, },
[reset, installModel] [installModel, resetForm]
); );
return ( return (

View File

@ -12,12 +12,12 @@ import {
InputRightElement, InputRightElement,
} from '@invoke-ai/ui-library'; } from '@invoke-ai/ui-library';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { toast, ToastID } from 'features/toast/toast'; import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import type { ChangeEvent, ChangeEventHandler } from 'react'; import type { ChangeEvent, ChangeEventHandler } from 'react';
import { useCallback, useMemo, useState } from 'react'; import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi'; import { PiXBold } from 'react-icons/pi';
import { type ScanFolderResponse, useInstallModelMutation } from 'services/api/endpoints/models'; import type { ScanFolderResponse } from 'services/api/endpoints/models';
import { ScanModelResultItem } from './ScanFolderResultItem'; import { ScanModelResultItem } from './ScanFolderResultItem';
@ -29,7 +29,7 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
const { t } = useTranslation(); const { t } = useTranslation();
const [searchTerm, setSearchTerm] = useState(''); const [searchTerm, setSearchTerm] = useState('');
const [inplace, setInplace] = useState(true); const [inplace, setInplace] = useState(true);
const [installModel] = useInstallModelMutation(); const [installModel] = useInstallModel();
const filteredResults = useMemo(() => { const filteredResults = useMemo(() => {
return results.filter((result) => { return results.filter((result) => {
@ -55,49 +55,15 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
if (result.is_installed) { if (result.is_installed) {
continue; continue;
} }
installModel({ source: result.path, inplace }) installModel({ source: result.path, inplace });
.unwrap()
.then((_) => {
toast({
id: ToastID.MODEL_INSTALL_QUEUED,
title: t('toast.modelAddedSimple'),
status: 'success',
});
})
.catch((error) => {
if (error) {
toast({
id: ToastID.MODEL_INSTALL_QUEUE_FAILED,
title: `${error.data.detail} `,
status: 'error',
});
}
});
} }
}, [filteredResults, installModel, inplace, t]); }, [filteredResults, installModel, inplace]);
const handleInstallOne = useCallback( const handleInstallOne = useCallback(
(source: string) => { (source: string) => {
installModel({ source, inplace }) installModel({ source, inplace });
.unwrap()
.then((_) => {
toast({
id: ToastID.MODEL_INSTALL_QUEUED,
title: t('toast.modelAddedSimple'),
status: 'success',
});
})
.catch((error) => {
if (error) {
toast({
id: ToastID.MODEL_INSTALL_QUEUE_FAILED,
title: `${error.data.detail} `,
status: 'error',
});
}
});
}, },
[installModel, inplace, t] [installModel, inplace]
); );
return ( return (

View File

@ -1,11 +1,10 @@
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library'; import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge'; import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
import { toast, ToastID } from 'features/toast/toast';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi'; import { PiPlusBold } from 'react-icons/pi';
import type { GetStarterModelsResponse } from 'services/api/endpoints/models'; import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
import { useInstallModelMutation } from 'services/api/endpoints/models';
type Props = { type Props = {
result: GetStarterModelsResponse[number]; result: GetStarterModelsResponse[number];
@ -19,30 +18,13 @@ export const StarterModelsResultItem = ({ result }: Props) => {
} }
return _allSources; return _allSources;
}, [result]); }, [result]);
const [installModel] = useInstallModelMutation(); const [installModel] = useInstallModel();
const handleQuickAdd = useCallback(() => { const onClick = useCallback(() => {
for (const source of allSources) { for (const source of allSources) {
installModel({ source }) installModel({ source });
.unwrap()
.then((_) => {
toast({
id: ToastID.MODEL_INSTALL_QUEUED,
title: t('toast.modelAddedSimple'),
status: 'success',
});
})
.catch((error) => {
if (error) {
toast({
id: ToastID.MODEL_INSTALL_QUEUE_FAILED,
title: `${error.data.detail} `,
status: 'error',
});
}
});
} }
}, [allSources, installModel, t]); }, [allSources, installModel]);
return ( return (
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}> <Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
@ -58,7 +40,7 @@ export const StarterModelsResultItem = ({ result }: Props) => {
{result.is_installed ? ( {result.is_installed ? (
<Badge>{t('common.installed')}</Badge> <Badge>{t('common.installed')}</Badge>
) : ( ) : (
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={handleQuickAdd} size="sm" /> <IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={onClick} size="sm" />
)} )}
</Box> </Box>
</Flex> </Flex>