mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix types for ImportQueue, add QuickAdd for scan models
This commit is contained in:
parent
5496699d6c
commit
53f0090197
@ -4,10 +4,9 @@ import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { t } from 'i18next';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { RiSparklingFill } from 'react-icons/ri';
|
||||
import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models';
|
||||
|
||||
import { ImportQueueModel } from './ImportQueueModel';
|
||||
import { ImportQueueItem } from './ImportQueueItem';
|
||||
|
||||
export const ImportQueue = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
@ -51,20 +50,15 @@ export const ImportQueue = () => {
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" p={3} h="full">
|
||||
<Flex justifyContent="space-between">
|
||||
<Flex justifyContent="space-between" alignItems="center">
|
||||
<Text>{t('modelManager.importQueue')}</Text>
|
||||
<Button
|
||||
isDisabled={!pruneAvailable}
|
||||
onClick={pruneQueue}
|
||||
tooltip={t('modelManager.pruneTooltip')}
|
||||
rightIcon={<RiSparklingFill />}
|
||||
>
|
||||
<Button size="sm" isDisabled={!pruneAvailable} onClick={pruneQueue} tooltip={t('modelManager.pruneTooltip')}>
|
||||
{t('modelManager.prune')}
|
||||
</Button>
|
||||
</Flex>
|
||||
<Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full">
|
||||
<Flex direction="column" gap="2">
|
||||
{data?.map((model) => <ImportQueueModel key={model.id} model={model} />)}
|
||||
<Flex flexDir="column-reverse" gap="2">
|
||||
{data?.map((model) => <ImportQueueItem key={model.id} model={model} />)}
|
||||
</Flex>
|
||||
</Box>
|
||||
</Flex>
|
@ -0,0 +1,28 @@
|
||||
import { Badge, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ModelInstallStatus } from '../../../../../services/api/types';
|
||||
|
||||
const STATUSES = {
|
||||
waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' },
|
||||
downloading: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
|
||||
running: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
|
||||
completed: { colorScheme: 'green', translationKey: 'queue.completed' },
|
||||
error: { colorScheme: 'red', translationKey: 'queue.failed' },
|
||||
cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' },
|
||||
};
|
||||
|
||||
const ImportQueueBadge = ({ status, detail }: { status?: ModelInstallStatus; detail?: string }) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (!status) {
|
||||
return <></>;
|
||||
}
|
||||
|
||||
return (
|
||||
<Tooltip label={detail}>
|
||||
<Badge colorScheme={STATUSES[status].colorScheme}>{t(STATUSES[status].translationKey)}</Badge>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
export default memo(ImportQueueBadge);
|
@ -0,0 +1,136 @@
|
||||
import { Box, Flex, IconButton, Progress, Tag, Text, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { t } from 'i18next';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
import { useDeleteModelImportMutation } from 'services/api/endpoints/models';
|
||||
import type { ModelInstallJob, HFModelSource, LocalModelSource, URLModelSource } from 'services/api/types';
|
||||
import ImportQueueBadge from './ImportQueueBadge';
|
||||
|
||||
type ModelListItemProps = {
|
||||
model: ModelInstallJob;
|
||||
};
|
||||
|
||||
const formatBytes = (bytes: number) => {
|
||||
const units = ['b', 'kb', 'mb', 'gb', 'tb'];
|
||||
|
||||
let i = 0;
|
||||
|
||||
for (i; bytes >= 1024 && i < 4; i++) {
|
||||
bytes /= 1024;
|
||||
}
|
||||
|
||||
return `${bytes.toFixed(2)} ${units[i]}`;
|
||||
};
|
||||
|
||||
export const ImportQueueItem = (props: ModelListItemProps) => {
|
||||
const { model } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [deleteImportModel] = useDeleteModelImportMutation();
|
||||
|
||||
const source = useMemo(() => {
|
||||
if (model.source.type === 'hf') {
|
||||
return model.source as HFModelSource;
|
||||
} else if (model.source.type === 'local') {
|
||||
return model.source as LocalModelSource;
|
||||
} else if (model.source.type === 'url') {
|
||||
return model.source as URLModelSource;
|
||||
} else {
|
||||
return model.source as LocalModelSource;
|
||||
}
|
||||
}, [model.source]);
|
||||
|
||||
const handleDeleteModelImport = useCallback(() => {
|
||||
deleteImportModel(model.id)
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.modelImportCanceled'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${error.data.detail} `,
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
}, [deleteImportModel, model, dispatch]);
|
||||
|
||||
const modelName = useMemo(() => {
|
||||
switch (source.type) {
|
||||
case 'hf':
|
||||
return source.repo_id;
|
||||
case 'url':
|
||||
return source.url;
|
||||
case 'local':
|
||||
return source.path.substring(source.path.lastIndexOf('/') + 1);
|
||||
default:
|
||||
return '';
|
||||
}
|
||||
}, [source]);
|
||||
|
||||
const progressValue = useMemo(() => {
|
||||
if (model.bytes === undefined || model.total_bytes === undefined) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return (model.bytes / model.total_bytes) * 100;
|
||||
}, [model.bytes, model.total_bytes]);
|
||||
|
||||
const progressString = useMemo(() => {
|
||||
if (model.status !== 'downloading' || model.bytes === undefined || model.total_bytes === undefined) {
|
||||
return '';
|
||||
}
|
||||
return `${formatBytes(model.bytes)} / ${formatBytes(model.total_bytes)}`;
|
||||
}, [model.bytes, model.total_bytes, model.status]);
|
||||
|
||||
return (
|
||||
<Flex gap="2" w="full" alignItems="center" textAlign="center">
|
||||
<Tooltip label={modelName}>
|
||||
<Text w="30%" whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis">
|
||||
{modelName}
|
||||
</Text>
|
||||
</Tooltip>
|
||||
<Flex flexDir="column" w="50%">
|
||||
<Tooltip label={progressString}>
|
||||
<Progress
|
||||
value={progressValue}
|
||||
isIndeterminate={progressValue === undefined}
|
||||
aria-label={t('accessibility.invokeProgressBar')}
|
||||
h={2}
|
||||
/>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
<Box w="15%">
|
||||
<ImportQueueBadge status={model.status} />
|
||||
</Box>
|
||||
|
||||
<Box w="10%">
|
||||
{(model.status === 'downloading' || model.status === 'waiting') && (
|
||||
<IconButton
|
||||
isRound={true}
|
||||
size="xs"
|
||||
tooltip={t('modelManager.cancel')}
|
||||
aria-label={t('modelManager.cancel')}
|
||||
icon={<PiXBold />}
|
||||
onClick={handleDeleteModelImport}
|
||||
/>
|
||||
)}
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
};
|
@ -1,105 +0,0 @@
|
||||
import { Box, Flex, IconButton, Progress, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { t } from 'i18next';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
import { useDeleteModelImportMutation } from 'services/api/endpoints/models';
|
||||
import type { ImportModelConfig } from 'services/api/types';
|
||||
|
||||
type ModelListItemProps = {
|
||||
model: ImportModelConfig;
|
||||
};
|
||||
|
||||
export const ImportQueueModel = (props: ModelListItemProps) => {
|
||||
const { model } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [deleteImportModel] = useDeleteModelImportMutation();
|
||||
|
||||
const handleDeleteModelImport = useCallback(() => {
|
||||
deleteImportModel({ key: model.id })
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.modelImportCanceled'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${error.data.detail} `,
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
}, [deleteImportModel, model, dispatch]);
|
||||
|
||||
const formatBytes = (bytes: number) => {
|
||||
const units = ['b', 'kb', 'mb', 'gb', 'tb'];
|
||||
|
||||
let i = 0;
|
||||
|
||||
for (i; bytes >= 1024 && i < 4; i++) {
|
||||
bytes /= 1024;
|
||||
}
|
||||
|
||||
return `${bytes.toFixed(2)} ${units[i]}`;
|
||||
};
|
||||
|
||||
const modelName = useMemo(() => {
|
||||
return model.source.repo_id || model.source.url || model.source.path.substring(model.source.path.lastIndexOf('/') + 1);
|
||||
}, [model.source]);
|
||||
|
||||
const progressValue = useMemo(() => {
|
||||
return (model.bytes / model.total_bytes) * 100;
|
||||
}, [model.bytes, model.total_bytes]);
|
||||
|
||||
const progressString = useMemo(() => {
|
||||
if (model.status !== 'downloading') {
|
||||
return '--';
|
||||
}
|
||||
return `${formatBytes(model.bytes)} / ${formatBytes(model.total_bytes)}`;
|
||||
}, [model.bytes, model.total_bytes, model.status]);
|
||||
|
||||
return (
|
||||
<Flex gap="2" w="full" alignItems="center" textAlign="center">
|
||||
<Text w="20%" whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis">
|
||||
{modelName}
|
||||
</Text>
|
||||
<Progress
|
||||
value={progressValue}
|
||||
isIndeterminate={progressValue === undefined}
|
||||
aria-label={t('accessibility.invokeProgressBar')}
|
||||
h={2}
|
||||
w="50%"
|
||||
/>
|
||||
<Text minW="20%" fontSize="xs" w="20%">
|
||||
{progressString}
|
||||
</Text>
|
||||
<Text w="15%">{model.status[0].toUpperCase() + model.status.slice(1)}</Text>
|
||||
<Box w="10%">
|
||||
{(model.status === 'downloading' || model.status === 'waiting') && (
|
||||
<IconButton
|
||||
isRound={true}
|
||||
size="xs"
|
||||
tooltip={t('modelManager.cancel')}
|
||||
aria-label={t('modelManager.cancel')}
|
||||
icon={<PiXBold />}
|
||||
onClick={handleDeleteModelImport}
|
||||
/>
|
||||
)}
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
};
|
@ -0,0 +1,55 @@
|
||||
import { Flex, Text, Box, Button, IconButton, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { IoAdd } from 'react-icons/io5';
|
||||
import { useAppDispatch } from '../../../../../app/store/storeHooks';
|
||||
import { useImportMainModelsMutation } from '../../../../../services/api/endpoints/models';
|
||||
import { addToast } from '../../../../system/store/systemSlice';
|
||||
import { makeToast } from '../../../../system/util/makeToast';
|
||||
|
||||
export const ScanModelResultItem = ({ result }: { result: string }) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
|
||||
|
||||
const handleQuickAdd = () => {
|
||||
importMainModel({ source: result, config: undefined })
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.modelAddedSimple'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${error.data.detail} `,
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<Flex justifyContent={'space-between'}>
|
||||
<Flex fontSize="sm" flexDir="column">
|
||||
<Text fontWeight="semibold">{result.split('\\').slice(-1)[0]}</Text>
|
||||
<Text variant="subtext">{result}</Text>
|
||||
</Flex>
|
||||
<Box>
|
||||
<Tooltip label={t('modelManager.quickAdd')}>
|
||||
<IconButton aria-label={t('modelManager.quickAdd')} icon={<IoAdd />} onClick={handleQuickAdd} />
|
||||
</Tooltip>
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
};
|
@ -1,18 +1,10 @@
|
||||
import {
|
||||
Divider,
|
||||
Flex,
|
||||
Heading,
|
||||
IconButton,
|
||||
Input,
|
||||
InputGroup,
|
||||
InputRightElement,
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { Divider, Flex, Heading, IconButton, Input, InputGroup, InputRightElement, Text } from '@invoke-ai/ui-library';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { t } from 'i18next';
|
||||
import type { ChangeEventHandler} from 'react';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
import { ScanModelResultItem } from './ScanModelResultItem';
|
||||
|
||||
export const ScanModelsResults = ({ results }: { results: string[] }) => {
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
@ -67,12 +59,11 @@ export const ScanModelsResults = ({ results }: { results: string[] }) => {
|
||||
</Flex>
|
||||
<Flex height="100%" layerStyle="third" borderRadius="base" p={4} mt={4} mb={4}>
|
||||
<ScrollableContent>
|
||||
{filteredResults.map((result) => (
|
||||
<Flex key={result} fontSize="sm" flexDir="column">
|
||||
<Text fontWeight="semibold">{result.split('\\').slice(-1)[0]}</Text>
|
||||
<Text variant="subtext">{result}</Text>
|
||||
</Flex>
|
||||
))}
|
||||
<Flex flexDir="column" gap={3}>
|
||||
{filteredResults.map((result) => (
|
||||
<ScanModelResultItem key={result} result={result} />
|
||||
))}
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
</Flex>
|
||||
</Flex>
|
||||
|
@ -1,9 +1,10 @@
|
||||
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
||||
|
||||
import { AdvancedImport } from './AddModelPanel/AdvancedImport';
|
||||
import { ImportQueue } from './AddModelPanel/ImportQueue';
|
||||
import { ImportQueue } from './AddModelPanel/ImportQueue/ImportQueue';
|
||||
import { ScanModels } from './AddModelPanel/ScanModels/ScanModels';
|
||||
import { SimpleImport } from './AddModelPanel/SimpleImport';
|
||||
import { ScanModelsForm } from './AddModelPanel/ScanModels/ScanModelsForm';
|
||||
|
||||
export const ImportModels = () => {
|
||||
return (
|
||||
@ -23,10 +24,10 @@ export const ImportModels = () => {
|
||||
<SimpleImport />
|
||||
</TabPanel>
|
||||
<TabPanel height="100%">
|
||||
<AdvancedImport />
|
||||
<AdvancedImport />
|
||||
</TabPanel>
|
||||
<TabPanel height="100%">
|
||||
<ScanModels />
|
||||
<ScanModelsForm />
|
||||
</TabPanel>
|
||||
</TabPanels>
|
||||
</Tabs>
|
||||
|
@ -127,25 +127,25 @@ export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined
|
||||
|
||||
const buildProvidesTags =
|
||||
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
|
||||
(result: EntityState<TEntity, string> | undefined) => {
|
||||
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: tagType,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
(result: EntityState<TEntity, string> | undefined) => {
|
||||
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: tagType,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
};
|
||||
return tags;
|
||||
};
|
||||
|
||||
const buildTransformResponse =
|
||||
<T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) =>
|
||||
(response: { models: T[] }) => {
|
||||
return adapter.setAll(adapter.getInitialState(), response.models);
|
||||
};
|
||||
(response: { models: T[] }) => {
|
||||
return adapter.setAll(adapter.getInitialState(), response.models);
|
||||
};
|
||||
|
||||
/**
|
||||
* Builds an endpoint URL for the models router
|
||||
@ -305,10 +305,10 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
providesTags: ['ModelImports'],
|
||||
}),
|
||||
deleteModelImport: build.mutation<DeleteImportModelsResponse, DeleteMainModelArg>({
|
||||
query: ({ key }) => {
|
||||
deleteModelImport: build.mutation<DeleteImportModelsResponse, number>({
|
||||
query: (id) => {
|
||||
return {
|
||||
url: buildModelsUrl(`import/${key}`),
|
||||
url: buildModelsUrl(`import/${id}`),
|
||||
method: 'DELETE',
|
||||
};
|
||||
},
|
||||
|
@ -117,6 +117,13 @@ export const isRefinerMainModelModelConfig = (config: AnyModelConfig): config is
|
||||
|
||||
export type MergeModelConfig = S['Body_merge'];
|
||||
export type ImportModelConfig = S['Body_import_model'];
|
||||
export type ModelInstallJob = S['ModelInstallJob']
|
||||
export type ModelInstallStatus = S["InstallStatus"]
|
||||
|
||||
export type HFModelSource = S['HFModelSource'];
|
||||
export type CivitaiModelSource = S['CivitaiModelSource'];
|
||||
export type LocalModelSource = S['LocalModelSource'];
|
||||
export type URLModelSource = S['URLModelSource'];
|
||||
|
||||
// Graphs
|
||||
export type Graph = S['Graph'];
|
||||
|
Loading…
Reference in New Issue
Block a user