fix types for ImportQueue, add QuickAdd for scan models

This commit is contained in:
Mary Hipp 2024-02-23 10:31:27 -05:00 committed by psychedelicious
parent b3beaefa04
commit 7785e8ff79
9 changed files with 261 additions and 154 deletions

View File

@ -4,10 +4,9 @@ import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { useCallback, useMemo } from 'react';
import { RiSparklingFill } from 'react-icons/ri';
import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models';
import { ImportQueueModel } from './ImportQueueModel';
import { ImportQueueItem } from './ImportQueueItem';
export const ImportQueue = () => {
const dispatch = useAppDispatch();
@ -51,20 +50,15 @@ export const ImportQueue = () => {
return (
<Flex flexDir="column" p={3} h="full">
<Flex justifyContent="space-between">
<Flex justifyContent="space-between" alignItems="center">
<Text>{t('modelManager.importQueue')}</Text>
<Button
isDisabled={!pruneAvailable}
onClick={pruneQueue}
tooltip={t('modelManager.pruneTooltip')}
rightIcon={<RiSparklingFill />}
>
<Button size="sm" isDisabled={!pruneAvailable} onClick={pruneQueue} tooltip={t('modelManager.pruneTooltip')}>
{t('modelManager.prune')}
</Button>
</Flex>
<Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full">
<Flex direction="column" gap="2">
{data?.map((model) => <ImportQueueModel key={model.id} model={model} />)}
<Flex flexDir="column-reverse" gap="2">
{data?.map((model) => <ImportQueueItem key={model.id} model={model} />)}
</Flex>
</Box>
</Flex>

View File

@ -0,0 +1,28 @@
import { Badge, Tooltip } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { ModelInstallStatus } from '../../../../../services/api/types';
const STATUSES = {
waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' },
downloading: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
running: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
completed: { colorScheme: 'green', translationKey: 'queue.completed' },
error: { colorScheme: 'red', translationKey: 'queue.failed' },
cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' },
};
const ImportQueueBadge = ({ status, detail }: { status?: ModelInstallStatus; detail?: string }) => {
const { t } = useTranslation();
if (!status) {
return <></>;
}
return (
<Tooltip label={detail}>
<Badge colorScheme={STATUSES[status].colorScheme}>{t(STATUSES[status].translationKey)}</Badge>
</Tooltip>
);
};
export default memo(ImportQueueBadge);

View File

@ -0,0 +1,136 @@
import { Box, Flex, IconButton, Progress, Tag, Text, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { useCallback, useMemo } from 'react';
import { PiXBold } from 'react-icons/pi';
import { useDeleteModelImportMutation } from 'services/api/endpoints/models';
import type { ModelInstallJob, HFModelSource, LocalModelSource, URLModelSource } from 'services/api/types';
import ImportQueueBadge from './ImportQueueBadge';
type ModelListItemProps = {
model: ModelInstallJob;
};
const formatBytes = (bytes: number) => {
const units = ['b', 'kb', 'mb', 'gb', 'tb'];
let i = 0;
for (i; bytes >= 1024 && i < 4; i++) {
bytes /= 1024;
}
return `${bytes.toFixed(2)} ${units[i]}`;
};
export const ImportQueueItem = (props: ModelListItemProps) => {
const { model } = props;
const dispatch = useAppDispatch();
const [deleteImportModel] = useDeleteModelImportMutation();
const source = useMemo(() => {
if (model.source.type === 'hf') {
return model.source as HFModelSource;
} else if (model.source.type === 'local') {
return model.source as LocalModelSource;
} else if (model.source.type === 'url') {
return model.source as URLModelSource;
} else {
return model.source as LocalModelSource;
}
}, [model.source]);
const handleDeleteModelImport = useCallback(() => {
deleteImportModel(model.id)
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelImportCanceled'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
}, [deleteImportModel, model, dispatch]);
const modelName = useMemo(() => {
switch (source.type) {
case 'hf':
return source.repo_id;
case 'url':
return source.url;
case 'local':
return source.path.substring(source.path.lastIndexOf('/') + 1);
default:
return '';
}
}, [source]);
const progressValue = useMemo(() => {
if (model.bytes === undefined || model.total_bytes === undefined) {
return 0;
}
return (model.bytes / model.total_bytes) * 100;
}, [model.bytes, model.total_bytes]);
const progressString = useMemo(() => {
if (model.status !== 'downloading' || model.bytes === undefined || model.total_bytes === undefined) {
return '';
}
return `${formatBytes(model.bytes)} / ${formatBytes(model.total_bytes)}`;
}, [model.bytes, model.total_bytes, model.status]);
return (
<Flex gap="2" w="full" alignItems="center" textAlign="center">
<Tooltip label={modelName}>
<Text w="30%" whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis">
{modelName}
</Text>
</Tooltip>
<Flex flexDir="column" w="50%">
<Tooltip label={progressString}>
<Progress
value={progressValue}
isIndeterminate={progressValue === undefined}
aria-label={t('accessibility.invokeProgressBar')}
h={2}
/>
</Tooltip>
</Flex>
<Box w="15%">
<ImportQueueBadge status={model.status} />
</Box>
<Box w="10%">
{(model.status === 'downloading' || model.status === 'waiting') && (
<IconButton
isRound={true}
size="xs"
tooltip={t('modelManager.cancel')}
aria-label={t('modelManager.cancel')}
icon={<PiXBold />}
onClick={handleDeleteModelImport}
/>
)}
</Box>
</Flex>
);
};

View File

@ -1,105 +0,0 @@
import { Box, Flex, IconButton, Progress, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { useCallback, useMemo } from 'react';
import { PiXBold } from 'react-icons/pi';
import { useDeleteModelImportMutation } from 'services/api/endpoints/models';
import type { ImportModelConfig } from 'services/api/types';
type ModelListItemProps = {
model: ImportModelConfig;
};
export const ImportQueueModel = (props: ModelListItemProps) => {
const { model } = props;
const dispatch = useAppDispatch();
const [deleteImportModel] = useDeleteModelImportMutation();
const handleDeleteModelImport = useCallback(() => {
deleteImportModel({ key: model.id })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelImportCanceled'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
}, [deleteImportModel, model, dispatch]);
const formatBytes = (bytes: number) => {
const units = ['b', 'kb', 'mb', 'gb', 'tb'];
let i = 0;
for (i; bytes >= 1024 && i < 4; i++) {
bytes /= 1024;
}
return `${bytes.toFixed(2)} ${units[i]}`;
};
const modelName = useMemo(() => {
return model.source.repo_id || model.source.url || model.source.path.substring(model.source.path.lastIndexOf('/') + 1);
}, [model.source]);
const progressValue = useMemo(() => {
return (model.bytes / model.total_bytes) * 100;
}, [model.bytes, model.total_bytes]);
const progressString = useMemo(() => {
if (model.status !== 'downloading') {
return '--';
}
return `${formatBytes(model.bytes)} / ${formatBytes(model.total_bytes)}`;
}, [model.bytes, model.total_bytes, model.status]);
return (
<Flex gap="2" w="full" alignItems="center" textAlign="center">
<Text w="20%" whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis">
{modelName}
</Text>
<Progress
value={progressValue}
isIndeterminate={progressValue === undefined}
aria-label={t('accessibility.invokeProgressBar')}
h={2}
w="50%"
/>
<Text minW="20%" fontSize="xs" w="20%">
{progressString}
</Text>
<Text w="15%">{model.status[0].toUpperCase() + model.status.slice(1)}</Text>
<Box w="10%">
{(model.status === 'downloading' || model.status === 'waiting') && (
<IconButton
isRound={true}
size="xs"
tooltip={t('modelManager.cancel')}
aria-label={t('modelManager.cancel')}
icon={<PiXBold />}
onClick={handleDeleteModelImport}
/>
)}
</Box>
</Flex>
);
};

View File

@ -0,0 +1,55 @@
import { Flex, Text, Box, Button, IconButton, Tooltip } from '@invoke-ai/ui-library';
import { useTranslation } from 'react-i18next';
import { IoAdd } from 'react-icons/io5';
import { useAppDispatch } from '../../../../../app/store/storeHooks';
import { useImportMainModelsMutation } from '../../../../../services/api/endpoints/models';
import { addToast } from '../../../../system/store/systemSlice';
import { makeToast } from '../../../../system/util/makeToast';
export const ScanModelResultItem = ({ result }: { result: string }) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
const handleQuickAdd = () => {
importMainModel({ source: result, config: undefined })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
};
return (
<Flex justifyContent={'space-between'}>
<Flex fontSize="sm" flexDir="column">
<Text fontWeight="semibold">{result.split('\\').slice(-1)[0]}</Text>
<Text variant="subtext">{result}</Text>
</Flex>
<Box>
<Tooltip label={t('modelManager.quickAdd')}>
<IconButton aria-label={t('modelManager.quickAdd')} icon={<IoAdd />} onClick={handleQuickAdd} />
</Tooltip>
</Box>
</Flex>
);
};

View File

@ -1,18 +1,10 @@
import {
Divider,
Flex,
Heading,
IconButton,
Input,
InputGroup,
InputRightElement,
Text,
} from '@invoke-ai/ui-library';
import { Divider, Flex, Heading, IconButton, Input, InputGroup, InputRightElement, Text } from '@invoke-ai/ui-library';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { t } from 'i18next';
import type { ChangeEventHandler} from 'react';
import type { ChangeEventHandler } from 'react';
import { useCallback, useState } from 'react';
import { PiXBold } from 'react-icons/pi';
import { ScanModelResultItem } from './ScanModelResultItem';
export const ScanModelsResults = ({ results }: { results: string[] }) => {
const [searchTerm, setSearchTerm] = useState('');
@ -67,12 +59,11 @@ export const ScanModelsResults = ({ results }: { results: string[] }) => {
</Flex>
<Flex height="100%" layerStyle="third" borderRadius="base" p={4} mt={4} mb={4}>
<ScrollableContent>
{filteredResults.map((result) => (
<Flex key={result} fontSize="sm" flexDir="column">
<Text fontWeight="semibold">{result.split('\\').slice(-1)[0]}</Text>
<Text variant="subtext">{result}</Text>
</Flex>
))}
<Flex flexDir="column" gap={3}>
{filteredResults.map((result) => (
<ScanModelResultItem key={result} result={result} />
))}
</Flex>
</ScrollableContent>
</Flex>
</Flex>

View File

@ -1,9 +1,10 @@
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { AdvancedImport } from './AddModelPanel/AdvancedImport';
import { ImportQueue } from './AddModelPanel/ImportQueue';
import { ImportQueue } from './AddModelPanel/ImportQueue/ImportQueue';
import { ScanModels } from './AddModelPanel/ScanModels/ScanModels';
import { SimpleImport } from './AddModelPanel/SimpleImport';
import { ScanModelsForm } from './AddModelPanel/ScanModels/ScanModelsForm';
export const ImportModels = () => {
return (
@ -23,10 +24,10 @@ export const ImportModels = () => {
<SimpleImport />
</TabPanel>
<TabPanel height="100%">
<AdvancedImport />
<AdvancedImport />
</TabPanel>
<TabPanel height="100%">
<ScanModels />
<ScanModelsForm />
</TabPanel>
</TabPanels>
</Tabs>

View File

@ -127,25 +127,25 @@ export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined
const buildProvidesTags =
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
(result: EntityState<TEntity, string> | undefined) => {
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: tagType,
id,
}))
);
}
(result: EntityState<TEntity, string> | undefined) => {
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: tagType,
id,
}))
);
}
return tags;
};
return tags;
};
const buildTransformResponse =
<T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) =>
(response: { models: T[] }) => {
return adapter.setAll(adapter.getInitialState(), response.models);
};
(response: { models: T[] }) => {
return adapter.setAll(adapter.getInitialState(), response.models);
};
/**
* Builds an endpoint URL for the models router
@ -305,10 +305,10 @@ export const modelsApi = api.injectEndpoints({
},
providesTags: ['ModelImports'],
}),
deleteModelImport: build.mutation<DeleteImportModelsResponse, DeleteMainModelArg>({
query: ({ key }) => {
deleteModelImport: build.mutation<DeleteImportModelsResponse, number>({
query: (id) => {
return {
url: buildModelsUrl(`import/${key}`),
url: buildModelsUrl(`import/${id}`),
method: 'DELETE',
};
},

View File

@ -117,6 +117,13 @@ export const isRefinerMainModelModelConfig = (config: AnyModelConfig): config is
export type MergeModelConfig = S['Body_merge'];
export type ImportModelConfig = S['Body_import_model'];
export type ModelInstallJob = S['ModelInstallJob']
export type ModelInstallStatus = S["InstallStatus"]
export type HFModelSource = S['HFModelSource'];
export type CivitaiModelSource = S['CivitaiModelSource'];
export type LocalModelSource = S['LocalModelSource'];
export type URLModelSource = S['URLModelSource'];
// Graphs
export type Graph = S['Graph'];