fix types for ImportQueue, add QuickAdd for scan models

This commit is contained in:
Mary Hipp 2024-02-23 10:31:27 -05:00 committed by psychedelicious
parent b3beaefa04
commit 7785e8ff79
9 changed files with 261 additions and 154 deletions

View File

@ -4,10 +4,9 @@ import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next'; import { t } from 'i18next';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { RiSparklingFill } from 'react-icons/ri';
import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models'; import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models';
import { ImportQueueModel } from './ImportQueueModel'; import { ImportQueueItem } from './ImportQueueItem';
export const ImportQueue = () => { export const ImportQueue = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
@ -51,20 +50,15 @@ export const ImportQueue = () => {
return ( return (
<Flex flexDir="column" p={3} h="full"> <Flex flexDir="column" p={3} h="full">
<Flex justifyContent="space-between"> <Flex justifyContent="space-between" alignItems="center">
<Text>{t('modelManager.importQueue')}</Text> <Text>{t('modelManager.importQueue')}</Text>
<Button <Button size="sm" isDisabled={!pruneAvailable} onClick={pruneQueue} tooltip={t('modelManager.pruneTooltip')}>
isDisabled={!pruneAvailable}
onClick={pruneQueue}
tooltip={t('modelManager.pruneTooltip')}
rightIcon={<RiSparklingFill />}
>
{t('modelManager.prune')} {t('modelManager.prune')}
</Button> </Button>
</Flex> </Flex>
<Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full"> <Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full">
<Flex direction="column" gap="2"> <Flex flexDir="column-reverse" gap="2">
{data?.map((model) => <ImportQueueModel key={model.id} model={model} />)} {data?.map((model) => <ImportQueueItem key={model.id} model={model} />)}
</Flex> </Flex>
</Box> </Box>
</Flex> </Flex>

View File

@ -0,0 +1,28 @@
import { Badge, Tooltip } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { ModelInstallStatus } from '../../../../../services/api/types';
const STATUSES = {
waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' },
downloading: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
running: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
completed: { colorScheme: 'green', translationKey: 'queue.completed' },
error: { colorScheme: 'red', translationKey: 'queue.failed' },
cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' },
};
const ImportQueueBadge = ({ status, detail }: { status?: ModelInstallStatus; detail?: string }) => {
const { t } = useTranslation();
if (!status) {
return <></>;
}
return (
<Tooltip label={detail}>
<Badge colorScheme={STATUSES[status].colorScheme}>{t(STATUSES[status].translationKey)}</Badge>
</Tooltip>
);
};
export default memo(ImportQueueBadge);

View File

@ -0,0 +1,136 @@
import { Box, Flex, IconButton, Progress, Tag, Text, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { useCallback, useMemo } from 'react';
import { PiXBold } from 'react-icons/pi';
import { useDeleteModelImportMutation } from 'services/api/endpoints/models';
import type { ModelInstallJob, HFModelSource, LocalModelSource, URLModelSource } from 'services/api/types';
import ImportQueueBadge from './ImportQueueBadge';
type ModelListItemProps = {
model: ModelInstallJob;
};
const formatBytes = (bytes: number) => {
const units = ['b', 'kb', 'mb', 'gb', 'tb'];
let i = 0;
for (i; bytes >= 1024 && i < 4; i++) {
bytes /= 1024;
}
return `${bytes.toFixed(2)} ${units[i]}`;
};
export const ImportQueueItem = (props: ModelListItemProps) => {
const { model } = props;
const dispatch = useAppDispatch();
const [deleteImportModel] = useDeleteModelImportMutation();
const source = useMemo(() => {
if (model.source.type === 'hf') {
return model.source as HFModelSource;
} else if (model.source.type === 'local') {
return model.source as LocalModelSource;
} else if (model.source.type === 'url') {
return model.source as URLModelSource;
} else {
return model.source as LocalModelSource;
}
}, [model.source]);
const handleDeleteModelImport = useCallback(() => {
deleteImportModel(model.id)
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelImportCanceled'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
}, [deleteImportModel, model, dispatch]);
const modelName = useMemo(() => {
switch (source.type) {
case 'hf':
return source.repo_id;
case 'url':
return source.url;
case 'local':
return source.path.substring(source.path.lastIndexOf('/') + 1);
default:
return '';
}
}, [source]);
const progressValue = useMemo(() => {
if (model.bytes === undefined || model.total_bytes === undefined) {
return 0;
}
return (model.bytes / model.total_bytes) * 100;
}, [model.bytes, model.total_bytes]);
const progressString = useMemo(() => {
if (model.status !== 'downloading' || model.bytes === undefined || model.total_bytes === undefined) {
return '';
}
return `${formatBytes(model.bytes)} / ${formatBytes(model.total_bytes)}`;
}, [model.bytes, model.total_bytes, model.status]);
return (
<Flex gap="2" w="full" alignItems="center" textAlign="center">
<Tooltip label={modelName}>
<Text w="30%" whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis">
{modelName}
</Text>
</Tooltip>
<Flex flexDir="column" w="50%">
<Tooltip label={progressString}>
<Progress
value={progressValue}
isIndeterminate={progressValue === undefined}
aria-label={t('accessibility.invokeProgressBar')}
h={2}
/>
</Tooltip>
</Flex>
<Box w="15%">
<ImportQueueBadge status={model.status} />
</Box>
<Box w="10%">
{(model.status === 'downloading' || model.status === 'waiting') && (
<IconButton
isRound={true}
size="xs"
tooltip={t('modelManager.cancel')}
aria-label={t('modelManager.cancel')}
icon={<PiXBold />}
onClick={handleDeleteModelImport}
/>
)}
</Box>
</Flex>
);
};

View File

@ -1,105 +0,0 @@
import { Box, Flex, IconButton, Progress, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { useCallback, useMemo } from 'react';
import { PiXBold } from 'react-icons/pi';
import { useDeleteModelImportMutation } from 'services/api/endpoints/models';
import type { ImportModelConfig } from 'services/api/types';
type ModelListItemProps = {
model: ImportModelConfig;
};
export const ImportQueueModel = (props: ModelListItemProps) => {
const { model } = props;
const dispatch = useAppDispatch();
const [deleteImportModel] = useDeleteModelImportMutation();
const handleDeleteModelImport = useCallback(() => {
deleteImportModel({ key: model.id })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelImportCanceled'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
}, [deleteImportModel, model, dispatch]);
const formatBytes = (bytes: number) => {
const units = ['b', 'kb', 'mb', 'gb', 'tb'];
let i = 0;
for (i; bytes >= 1024 && i < 4; i++) {
bytes /= 1024;
}
return `${bytes.toFixed(2)} ${units[i]}`;
};
const modelName = useMemo(() => {
return model.source.repo_id || model.source.url || model.source.path.substring(model.source.path.lastIndexOf('/') + 1);
}, [model.source]);
const progressValue = useMemo(() => {
return (model.bytes / model.total_bytes) * 100;
}, [model.bytes, model.total_bytes]);
const progressString = useMemo(() => {
if (model.status !== 'downloading') {
return '--';
}
return `${formatBytes(model.bytes)} / ${formatBytes(model.total_bytes)}`;
}, [model.bytes, model.total_bytes, model.status]);
return (
<Flex gap="2" w="full" alignItems="center" textAlign="center">
<Text w="20%" whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis">
{modelName}
</Text>
<Progress
value={progressValue}
isIndeterminate={progressValue === undefined}
aria-label={t('accessibility.invokeProgressBar')}
h={2}
w="50%"
/>
<Text minW="20%" fontSize="xs" w="20%">
{progressString}
</Text>
<Text w="15%">{model.status[0].toUpperCase() + model.status.slice(1)}</Text>
<Box w="10%">
{(model.status === 'downloading' || model.status === 'waiting') && (
<IconButton
isRound={true}
size="xs"
tooltip={t('modelManager.cancel')}
aria-label={t('modelManager.cancel')}
icon={<PiXBold />}
onClick={handleDeleteModelImport}
/>
)}
</Box>
</Flex>
);
};

View File

@ -0,0 +1,55 @@
import { Flex, Text, Box, Button, IconButton, Tooltip } from '@invoke-ai/ui-library';
import { useTranslation } from 'react-i18next';
import { IoAdd } from 'react-icons/io5';
import { useAppDispatch } from '../../../../../app/store/storeHooks';
import { useImportMainModelsMutation } from '../../../../../services/api/endpoints/models';
import { addToast } from '../../../../system/store/systemSlice';
import { makeToast } from '../../../../system/util/makeToast';
export const ScanModelResultItem = ({ result }: { result: string }) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
const handleQuickAdd = () => {
importMainModel({ source: result, config: undefined })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
};
return (
<Flex justifyContent={'space-between'}>
<Flex fontSize="sm" flexDir="column">
<Text fontWeight="semibold">{result.split('\\').slice(-1)[0]}</Text>
<Text variant="subtext">{result}</Text>
</Flex>
<Box>
<Tooltip label={t('modelManager.quickAdd')}>
<IconButton aria-label={t('modelManager.quickAdd')} icon={<IoAdd />} onClick={handleQuickAdd} />
</Tooltip>
</Box>
</Flex>
);
};

View File

@ -1,18 +1,10 @@
import { import { Divider, Flex, Heading, IconButton, Input, InputGroup, InputRightElement, Text } from '@invoke-ai/ui-library';
Divider,
Flex,
Heading,
IconButton,
Input,
InputGroup,
InputRightElement,
Text,
} from '@invoke-ai/ui-library';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { t } from 'i18next'; import { t } from 'i18next';
import type { ChangeEventHandler} from 'react'; import type { ChangeEventHandler } from 'react';
import { useCallback, useState } from 'react'; import { useCallback, useState } from 'react';
import { PiXBold } from 'react-icons/pi'; import { PiXBold } from 'react-icons/pi';
import { ScanModelResultItem } from './ScanModelResultItem';
export const ScanModelsResults = ({ results }: { results: string[] }) => { export const ScanModelsResults = ({ results }: { results: string[] }) => {
const [searchTerm, setSearchTerm] = useState(''); const [searchTerm, setSearchTerm] = useState('');
@ -67,12 +59,11 @@ export const ScanModelsResults = ({ results }: { results: string[] }) => {
</Flex> </Flex>
<Flex height="100%" layerStyle="third" borderRadius="base" p={4} mt={4} mb={4}> <Flex height="100%" layerStyle="third" borderRadius="base" p={4} mt={4} mb={4}>
<ScrollableContent> <ScrollableContent>
{filteredResults.map((result) => ( <Flex flexDir="column" gap={3}>
<Flex key={result} fontSize="sm" flexDir="column"> {filteredResults.map((result) => (
<Text fontWeight="semibold">{result.split('\\').slice(-1)[0]}</Text> <ScanModelResultItem key={result} result={result} />
<Text variant="subtext">{result}</Text> ))}
</Flex> </Flex>
))}
</ScrollableContent> </ScrollableContent>
</Flex> </Flex>
</Flex> </Flex>

View File

@ -1,9 +1,10 @@
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library'; import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { AdvancedImport } from './AddModelPanel/AdvancedImport'; import { AdvancedImport } from './AddModelPanel/AdvancedImport';
import { ImportQueue } from './AddModelPanel/ImportQueue'; import { ImportQueue } from './AddModelPanel/ImportQueue/ImportQueue';
import { ScanModels } from './AddModelPanel/ScanModels/ScanModels'; import { ScanModels } from './AddModelPanel/ScanModels/ScanModels';
import { SimpleImport } from './AddModelPanel/SimpleImport'; import { SimpleImport } from './AddModelPanel/SimpleImport';
import { ScanModelsForm } from './AddModelPanel/ScanModels/ScanModelsForm';
export const ImportModels = () => { export const ImportModels = () => {
return ( return (
@ -23,10 +24,10 @@ export const ImportModels = () => {
<SimpleImport /> <SimpleImport />
</TabPanel> </TabPanel>
<TabPanel height="100%"> <TabPanel height="100%">
<AdvancedImport /> <AdvancedImport />
</TabPanel> </TabPanel>
<TabPanel height="100%"> <TabPanel height="100%">
<ScanModels /> <ScanModelsForm />
</TabPanel> </TabPanel>
</TabPanels> </TabPanels>
</Tabs> </Tabs>

View File

@ -127,25 +127,25 @@ export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined
const buildProvidesTags = const buildProvidesTags =
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) => <TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
(result: EntityState<TEntity, string> | undefined) => { (result: EntityState<TEntity, string> | undefined) => {
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model']; const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
if (result) { if (result) {
tags.push( tags.push(
...result.ids.map((id) => ({ ...result.ids.map((id) => ({
type: tagType, type: tagType,
id, id,
})) }))
); );
} }
return tags; return tags;
}; };
const buildTransformResponse = const buildTransformResponse =
<T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) => <T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) =>
(response: { models: T[] }) => { (response: { models: T[] }) => {
return adapter.setAll(adapter.getInitialState(), response.models); return adapter.setAll(adapter.getInitialState(), response.models);
}; };
/** /**
* Builds an endpoint URL for the models router * Builds an endpoint URL for the models router
@ -305,10 +305,10 @@ export const modelsApi = api.injectEndpoints({
}, },
providesTags: ['ModelImports'], providesTags: ['ModelImports'],
}), }),
deleteModelImport: build.mutation<DeleteImportModelsResponse, DeleteMainModelArg>({ deleteModelImport: build.mutation<DeleteImportModelsResponse, number>({
query: ({ key }) => { query: (id) => {
return { return {
url: buildModelsUrl(`import/${key}`), url: buildModelsUrl(`import/${id}`),
method: 'DELETE', method: 'DELETE',
}; };
}, },

View File

@ -117,6 +117,13 @@ export const isRefinerMainModelModelConfig = (config: AnyModelConfig): config is
export type MergeModelConfig = S['Body_merge']; export type MergeModelConfig = S['Body_merge'];
export type ImportModelConfig = S['Body_import_model']; export type ImportModelConfig = S['Body_import_model'];
export type ModelInstallJob = S['ModelInstallJob']
export type ModelInstallStatus = S["InstallStatus"]
export type HFModelSource = S['HFModelSource'];
export type CivitaiModelSource = S['CivitaiModelSource'];
export type LocalModelSource = S['LocalModelSource'];
export type URLModelSource = S['URLModelSource'];
// Graphs // Graphs
export type Graph = S['Graph']; export type Graph = S['Graph'];