mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
model list, filtering, searching
This commit is contained in:
parent
358cac9674
commit
3a8d5dc349
@ -16,6 +16,7 @@ import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/galle
|
||||
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
|
||||
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
|
||||
import { modelManagerPersistConfig, modelManagerSlice } from 'features/modelManager/store/modelManagerSlice';
|
||||
import { modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
|
||||
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
|
||||
@ -55,6 +56,7 @@ const allReducers = {
|
||||
[changeBoardModalSlice.name]: changeBoardModalSlice.reducer,
|
||||
[loraSlice.name]: loraSlice.reducer,
|
||||
[modelManagerSlice.name]: modelManagerSlice.reducer,
|
||||
[modelManagerV2Slice.name]: modelManagerV2Slice.reducer,
|
||||
[sdxlSlice.name]: sdxlSlice.reducer,
|
||||
[queueSlice.name]: queueSlice.reducer,
|
||||
[workflowSlice.name]: workflowSlice.reducer,
|
||||
|
@ -1,10 +1,10 @@
|
||||
import { Button, ButtonGroup, Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelImportsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import AdvancedAddModels from './AdvancedAddModels';
|
||||
import SimpleAddModels from './SimpleAddModels';
|
||||
import { useGetModelImportsQuery } from '../../../../services/api/endpoints/models';
|
||||
|
||||
const AddModels = () => {
|
||||
const { t } = useTranslation();
|
||||
|
@ -3,12 +3,12 @@ import { memo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { DiffusersModelConfig, LoRAConfig, MainModelConfig } from 'services/api/types';
|
||||
|
||||
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
||||
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
|
||||
import LoRAModelEdit from './ModelManagerPanel/LoRAModelEdit';
|
||||
import ModelList from './ModelManagerPanel/ModelList';
|
||||
import { DiffusersModelConfig, LoRAConfig, MainModelConfig } from '../../../services/api/types';
|
||||
|
||||
const ModelManagerPanel = () => {
|
||||
const [selectedModelId, setSelectedModelId] = useState<string>();
|
||||
|
@ -22,8 +22,9 @@ import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetCheckpointConfigsQuery, useUpdateModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { CheckpointModelConfig } from 'services/api/types';
|
||||
|
||||
import ModelConvert from './ModelConvert';
|
||||
import { CheckpointModelConfig } from '../../../../services/api/types';
|
||||
|
||||
type CheckpointModelEditProps = {
|
||||
model: CheckpointModelConfig;
|
||||
|
@ -9,8 +9,8 @@ import { memo, useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useUpdateModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { DiffusersModelConfig } from 'services/api/types';
|
||||
import { useUpdateModelsMutation } from '../../../../services/api/endpoints/models';
|
||||
|
||||
type DiffusersModelEditProps = {
|
||||
model: DiffusersModelConfig;
|
||||
|
@ -8,6 +8,7 @@ import { memo, useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useUpdateModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
|
||||
type LoRAModelEditProps = {
|
||||
|
@ -7,9 +7,9 @@ import { useTranslation } from 'react-i18next';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
// import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
|
||||
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { LoRAConfig, MainModelConfig } from 'services/api/types';
|
||||
|
||||
import ModelListItem from './ModelListItem';
|
||||
import { LoRAConfig, MainModelConfig } from '../../../../services/api/types';
|
||||
|
||||
type ModelListProps = {
|
||||
selectedModelId: string | undefined;
|
||||
|
@ -16,7 +16,7 @@ import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
|
||||
import { LoRAConfig, MainModelConfig } from '../../../../services/api/types';
|
||||
import type { LoRAConfig, MainModelConfig } from 'services/api/types';
|
||||
|
||||
type ModelListItemProps = {
|
||||
model: MainModelConfig | LoRAConfig;
|
||||
|
@ -0,0 +1,54 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
|
||||
|
||||
type ModelManagerState = {
|
||||
_version: 1;
|
||||
selectedModelKey: string | null;
|
||||
searchTerm: string;
|
||||
filteredModelType: string | null;
|
||||
};
|
||||
|
||||
export const initialModelManagerState: ModelManagerState = {
|
||||
_version: 1,
|
||||
selectedModelKey: null,
|
||||
filteredModelType: null,
|
||||
searchTerm: ""
|
||||
};
|
||||
|
||||
export const modelManagerV2Slice = createSlice({
|
||||
name: 'modelmanagerV2',
|
||||
initialState: initialModelManagerState,
|
||||
reducers: {
|
||||
setSelectedModelKey: (state, action: PayloadAction<string | null>) => {
|
||||
state.selectedModelKey = action.payload;
|
||||
},
|
||||
setSearchTerm: (state, action: PayloadAction<string>) => {
|
||||
state.searchTerm = action.payload;
|
||||
},
|
||||
|
||||
setFilteredModelType: (state, action: PayloadAction<string | null>) => {
|
||||
state.filteredModelType = action.payload;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const { setSelectedModelKey, setSearchTerm, setFilteredModelType } = modelManagerV2Slice.actions;
|
||||
|
||||
export const selectModelManagerSlice = (state: RootState) => state.modelmanager;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
export const migrateModelManagerState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
export const modelManagerPersistConfig: PersistConfig<ModelManagerState> = {
|
||||
name: modelManagerV2Slice.name,
|
||||
initialState: initialModelManagerState,
|
||||
migrate: migrateModelManagerState,
|
||||
persistDenylist: [],
|
||||
};
|
@ -1,17 +1,8 @@
|
||||
import {
|
||||
Box,
|
||||
Button,
|
||||
Flex,
|
||||
Heading,
|
||||
IconButton,
|
||||
Input,
|
||||
InputGroup,
|
||||
InputRightElement,
|
||||
Spacer,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { t } from 'i18next';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
import { SyncModelsIconButton } from '../../modelManager/components/SyncModels/SyncModelsIconButton';
|
||||
import { Box, Button, Flex, Heading } from '@invoke-ai/ui-library';
|
||||
import { SyncModelsIconButton } from 'features/modelManager/components/SyncModels/SyncModelsIconButton';
|
||||
|
||||
import ModelList from './ModelManagerPanel/ModelList';
|
||||
import { ModelListNavigation } from './ModelManagerPanel/ModelListNavigation';
|
||||
|
||||
export const ModelManager = () => {
|
||||
return (
|
||||
@ -27,17 +18,8 @@ export const ModelManager = () => {
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Box layerStyle="second" p={3} borderRadius="base" w="full" h="full">
|
||||
<Flex gap={2} alignItems="center" justifyContent="space-between">
|
||||
<Button>All Models</Button>
|
||||
<Spacer />
|
||||
<InputGroup>
|
||||
<Input placeholder={t('boards.searchBoard')} data-testid="board-search-input" />(
|
||||
<InputRightElement h="full" pe={2}>
|
||||
<IconButton size="sm" variant="link" aria-label={t('boards.clearSearch')} icon={<PiXBold />} />
|
||||
</InputRightElement>
|
||||
)
|
||||
</InputGroup>
|
||||
</Flex>
|
||||
<ModelListNavigation />
|
||||
<ModelList />
|
||||
</Box>
|
||||
</Box>
|
||||
);
|
||||
|
@ -0,0 +1,160 @@
|
||||
import { Flex, Spinner, Text } from '@invoke-ai/ui-library';
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo } from 'react';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
import {
|
||||
useGetControlNetModelsQuery,
|
||||
useGetIPAdapterModelsQuery,
|
||||
useGetLoRAModelsQuery,
|
||||
useGetMainModelsQuery,
|
||||
useGetT2IAdapterModelsQuery,
|
||||
useGetTextualInversionModelsQuery,
|
||||
useGetVaeModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
import { ModelListWrapper } from './ModelListWrapper';
|
||||
|
||||
const ModelList = () => {
|
||||
const { searchTerm, filteredModelType } = useAppSelector((s) => s.modelmanagerV2);
|
||||
|
||||
const { filteredMainModels, isLoadingMainModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredMainModels: modelsFilter(data, searchTerm, filteredModelType),
|
||||
isLoadingMainModels: isLoading,
|
||||
}),
|
||||
});
|
||||
|
||||
const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(undefined, {
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredLoraModels: modelsFilter(data, searchTerm, filteredModelType),
|
||||
isLoadingLoraModels: isLoading,
|
||||
}),
|
||||
});
|
||||
|
||||
const { filteredTextualInversionModels, isLoadingTextualInversionModels } = useGetTextualInversionModelsQuery(
|
||||
undefined,
|
||||
{
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredTextualInversionModels: modelsFilter(data, searchTerm, filteredModelType),
|
||||
isLoadingTextualInversionModels: isLoading,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
const { filteredControlnetModels, isLoadingControlnetModels } = useGetControlNetModelsQuery(undefined, {
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredControlnetModels: modelsFilter(data, searchTerm, filteredModelType),
|
||||
isLoadingControlnetModels: isLoading,
|
||||
}),
|
||||
});
|
||||
|
||||
const { filteredT2iAdapterModels, isLoadingT2IAdapterModels } = useGetT2IAdapterModelsQuery(undefined, {
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredT2iAdapterModels: modelsFilter(data, searchTerm, filteredModelType),
|
||||
isLoadingT2IAdapterModels: isLoading,
|
||||
}),
|
||||
});
|
||||
|
||||
const { filteredIpAdapterModels, isLoadingIpAdapterModels } = useGetIPAdapterModelsQuery(undefined, {
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredIpAdapterModels: modelsFilter(data, searchTerm, filteredModelType),
|
||||
isLoadingIpAdapterModels: isLoading,
|
||||
}),
|
||||
});
|
||||
|
||||
const { filteredVaeModels, isLoadingVaeModels } = useGetVaeModelsQuery(undefined, {
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredVaeModels: modelsFilter(data, searchTerm, filteredModelType),
|
||||
isLoadingVaeModels: isLoading,
|
||||
}),
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" p={4}>
|
||||
<Flex flexDirection="column" maxHeight={window.innerHeight - 130} overflow="scroll">
|
||||
{/* Main Model List */}
|
||||
{isLoadingMainModels && <FetchingModelsLoader loadingMessage="Loading Main..." />}
|
||||
{!isLoadingMainModels && filteredMainModels.length > 0 && (
|
||||
<ModelListWrapper title="Main" modelList={filteredMainModels} key="main" />
|
||||
)}
|
||||
{/* LoRAs List */}
|
||||
{isLoadingLoraModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
|
||||
{!isLoadingLoraModels && filteredLoraModels.length > 0 && (
|
||||
<ModelListWrapper title="LoRAs" modelList={filteredLoraModels} key="loras" />
|
||||
)}
|
||||
|
||||
{/* TI List */}
|
||||
{isLoadingTextualInversionModels && <FetchingModelsLoader loadingMessage="Loading Textual Inversions..." />}
|
||||
{!isLoadingTextualInversionModels && filteredTextualInversionModels.length > 0 && (
|
||||
<ModelListWrapper
|
||||
title="Textual Inversions"
|
||||
modelList={filteredTextualInversionModels}
|
||||
key="textual-inversions"
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* VAE List */}
|
||||
{isLoadingVaeModels && <FetchingModelsLoader loadingMessage="Loading VAEs..." />}
|
||||
{!isLoadingVaeModels && filteredVaeModels.length > 0 && (
|
||||
<ModelListWrapper title="VAEs" modelList={filteredVaeModels} key="vae" />
|
||||
)}
|
||||
|
||||
{/* Controlnet List */}
|
||||
{isLoadingControlnetModels && <FetchingModelsLoader loadingMessage="Loading Controlnets..." />}
|
||||
{!isLoadingControlnetModels && filteredControlnetModels.length > 0 && (
|
||||
<ModelListWrapper title="Controlnets" modelList={filteredControlnetModels} key="controlnets" />
|
||||
)}
|
||||
{/* IP Adapter List */}
|
||||
{isLoadingIpAdapterModels && <FetchingModelsLoader loadingMessage="Loading IP Adapters..." />}
|
||||
{!isLoadingIpAdapterModels && filteredIpAdapterModels.length > 0 && (
|
||||
<ModelListWrapper title="IP Adapters" modelList={filteredIpAdapterModels} key="ip-adapters" />
|
||||
)}
|
||||
{/* T2I Adapters List */}
|
||||
{isLoadingT2IAdapterModels && <FetchingModelsLoader loadingMessage="Loading T2I Adapters..." />}
|
||||
{!isLoadingT2IAdapterModels && filteredT2iAdapterModels.length > 0 && (
|
||||
<ModelListWrapper title="T2I Adapters" modelList={filteredT2iAdapterModels} key="t2i-adapters" />
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ModelList);
|
||||
|
||||
const modelsFilter = <T extends AnyModelConfig>(
|
||||
data: EntityState<T, string> | undefined,
|
||||
nameFilter: string,
|
||||
filteredModelType: string | null
|
||||
): T[] => {
|
||||
const filteredModels: T[] = [];
|
||||
|
||||
forEach(data?.entities, (model) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
|
||||
const matchesType = filteredModelType ? model.type === filteredModelType : true;
|
||||
|
||||
if (matchesFilter && matchesType) {
|
||||
filteredModels.push(model);
|
||||
}
|
||||
});
|
||||
return filteredModels;
|
||||
};
|
||||
|
||||
const FetchingModelsLoader = memo(({ loadingMessage }: { loadingMessage?: string }) => {
|
||||
return (
|
||||
<Flex flexDirection="column" gap={4} borderRadius={4} p={4} bg="base.800">
|
||||
<Flex justifyContent="center" alignItems="center" flexDirection="column" p={4} gap={8}>
|
||||
<Spinner />
|
||||
<Text variant="subtext">{loadingMessage ? loadingMessage : 'Fetching...'}</Text>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
FetchingModelsLoader.displayName = 'FetchingModelsLoader';
|
@ -0,0 +1,23 @@
|
||||
import { Box, Divider, Text } from '@invoke-ai/ui-library';
|
||||
|
||||
export const ModelListHeader = ({ title }: { title: string }) => {
|
||||
return (
|
||||
<Box position="relative" padding="10px 0">
|
||||
<Divider sx={{ backgroundColor: 'base.400' }} />
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: '50%',
|
||||
left: 0,
|
||||
transform: 'translate(0, -50%)',
|
||||
backgroundColor: 'base.800',
|
||||
padding: '10px',
|
||||
}}
|
||||
>
|
||||
<Text variant="subtext" fontSize="sm">
|
||||
{title}
|
||||
</Text>
|
||||
</Box>
|
||||
</Box>
|
||||
);
|
||||
};
|
@ -0,0 +1,115 @@
|
||||
import {
|
||||
Badge,
|
||||
Button,
|
||||
ConfirmationAlertDialog,
|
||||
Flex,
|
||||
IconButton,
|
||||
Text,
|
||||
Tooltip,
|
||||
useDisclosure,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
type ModelListItemProps = {
|
||||
model: AnyModelConfig;
|
||||
};
|
||||
|
||||
const ModelListItem = (props: ModelListItemProps) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const [deleteModel] = useDeleteModelsMutation();
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
|
||||
const { model } = props;
|
||||
|
||||
const handleSelectModel = useCallback(() => {
|
||||
dispatch(setSelectedModelKey(model.key));
|
||||
}, [model.key, dispatch]);
|
||||
|
||||
const isSelected = useMemo(() => {
|
||||
return selectedModelKey === model.key;
|
||||
}, [selectedModelKey, model.key]);
|
||||
|
||||
const handleModelDelete = useCallback(() => {
|
||||
deleteModel({ key: model.key })
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${t('modelManager.modelDeleted')}: ${model.name}`,
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${t('modelManager.modelDeleteFailed')}: ${model.name}`,
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
dispatch(setSelectedModelKey(null));
|
||||
}, [deleteModel, model, dispatch, t]);
|
||||
|
||||
return (
|
||||
<Flex gap={2} alignItems="center" w="full">
|
||||
<Flex
|
||||
as={Button}
|
||||
isChecked={isSelected}
|
||||
variant={isSelected ? 'solid' : 'ghost'}
|
||||
justifyContent="start"
|
||||
p={2}
|
||||
borderRadius="base"
|
||||
w="full"
|
||||
alignItems="center"
|
||||
onClick={handleSelectModel}
|
||||
>
|
||||
<Flex gap={4} alignItems="center">
|
||||
<Badge minWidth={14} p={0.5} fontSize="sm" variant="solid">
|
||||
{MODEL_TYPE_SHORT_MAP[model.base as keyof typeof MODEL_TYPE_SHORT_MAP]}
|
||||
</Badge>
|
||||
<Tooltip label={model.description} placement="bottom">
|
||||
<Text>{model.name}</Text>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
</Flex>
|
||||
<IconButton
|
||||
onClick={onOpen}
|
||||
icon={<PiTrashSimpleBold />}
|
||||
aria-label={t('modelManager.deleteConfig')}
|
||||
colorScheme="error"
|
||||
/>
|
||||
<ConfirmationAlertDialog
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
title={t('modelManager.deleteModel')}
|
||||
acceptCallback={handleModelDelete}
|
||||
acceptButtonText={t('modelManager.delete')}
|
||||
>
|
||||
<Flex rowGap={4} flexDirection="column">
|
||||
<Text fontWeight="bold">{t('modelManager.deleteMsg1')}</Text>
|
||||
<Text>{t('modelManager.deleteMsg2')}</Text>
|
||||
</Flex>
|
||||
</ConfirmationAlertDialog>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ModelListItem);
|
@ -0,0 +1,52 @@
|
||||
import { Flex, IconButton,Input, InputGroup, InputRightElement, Spacer } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setSearchTerm } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { t } from 'i18next';
|
||||
import type { ChangeEventHandler} from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
|
||||
import { ModelTypeFilter } from './ModelTypeFilter';
|
||||
|
||||
export const ModelListNavigation = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const searchTerm = useAppSelector((s) => s.modelmanagerV2.searchTerm);
|
||||
|
||||
const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback(
|
||||
(event) => {
|
||||
dispatch(setSearchTerm(event.target.value));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const clearSearch = useCallback(() => {
|
||||
dispatch(setSearchTerm(''));
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Flex gap={2} alignItems="center" justifyContent="space-between">
|
||||
<ModelTypeFilter />
|
||||
<Spacer />
|
||||
<InputGroup maxW="400px">
|
||||
<Input
|
||||
placeholder={t('modelManager.search')}
|
||||
value={searchTerm || ''}
|
||||
data-testid="board-search-input"
|
||||
onChange={handleSearch}
|
||||
/>
|
||||
|
||||
{!!searchTerm?.length && (
|
||||
<InputRightElement h="full" pe={2}>
|
||||
<IconButton
|
||||
size="sm"
|
||||
variant="link"
|
||||
aria-label={t('boards.clearSearch')}
|
||||
icon={<PiXBold />}
|
||||
onClick={clearSearch}
|
||||
/>
|
||||
</InputRightElement>
|
||||
)}
|
||||
</InputGroup>
|
||||
</Flex>
|
||||
);
|
||||
};
|
@ -0,0 +1,25 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
import { ModelListHeader } from './ModelListHeader';
|
||||
import ModelListItem from './ModelListItem';
|
||||
|
||||
type ModelListWrapperProps = {
|
||||
title: string;
|
||||
modelList: AnyModelConfig[];
|
||||
};
|
||||
|
||||
export const ModelListWrapper = (props: ModelListWrapperProps) => {
|
||||
const { title, modelList } = props;
|
||||
return (
|
||||
<Flex flexDirection="column" p="10px 0">
|
||||
<Flex gap={2} flexDir="column">
|
||||
<ModelListHeader title={title} />
|
||||
|
||||
{modelList.map((model) => (
|
||||
<ModelListItem key={model.key} model={model} />
|
||||
))}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
@ -0,0 +1,54 @@
|
||||
import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { useCallback } from 'react';
|
||||
import { IoFilter } from 'react-icons/io5';
|
||||
|
||||
export const MODEL_TYPE_LABELS: { [key: string]: string } = {
|
||||
main: 'Main',
|
||||
lora: 'LoRA',
|
||||
embedding: 'Textual Inversion',
|
||||
controlnet: 'ControlNet',
|
||||
vae: 'VAE',
|
||||
t2i_adapter: 'T2I Adapter',
|
||||
ip_adapter: 'IP Adapter',
|
||||
clip_vision: 'Clip Vision',
|
||||
onnx: 'Onnx',
|
||||
};
|
||||
|
||||
export const ModelTypeFilter = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const filteredModelType = useAppSelector((s) => s.modelmanagerV2.filteredModelType);
|
||||
|
||||
const selectModelType = useCallback(
|
||||
(option: string) => {
|
||||
dispatch(setFilteredModelType(option));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const clearModelType = useCallback(() => {
|
||||
dispatch(setFilteredModelType(null));
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Menu>
|
||||
<MenuButton as={Button} leftIcon={<IoFilter />}>
|
||||
{filteredModelType ? MODEL_TYPE_LABELS[filteredModelType] : 'All Models'}
|
||||
</MenuButton>
|
||||
<MenuList>
|
||||
<MenuItem onClick={clearModelType}>All Models</MenuItem>
|
||||
{Object.keys(MODEL_TYPE_LABELS).map((option) => (
|
||||
<MenuItem
|
||||
sx={{
|
||||
backgroundColor: filteredModelType === option ? 'base.700' : 'transparent',
|
||||
}}
|
||||
onClick={selectModelType.bind(null, option)}
|
||||
>
|
||||
{MODEL_TYPE_LABELS[option]}
|
||||
</MenuItem>
|
||||
))}
|
||||
</MenuList>
|
||||
</Menu>
|
||||
);
|
||||
};
|
14
invokeai/frontend/web/src/features/modelManagerV2/types.ts
Normal file
14
invokeai/frontend/web/src/features/modelManagerV2/types.ts
Normal file
@ -0,0 +1,14 @@
|
||||
import { z } from "zod";
|
||||
|
||||
export const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
||||
export const zModelType = z.enum([
|
||||
'main',
|
||||
'vae',
|
||||
'lora',
|
||||
'controlnet',
|
||||
'embedding',
|
||||
'ip_adapter',
|
||||
'clip_vision',
|
||||
't2i_adapter',
|
||||
'onnx', // TODO(psyche): Remove this when removed from backend
|
||||
]);
|
@ -1,8 +1,7 @@
|
||||
import { Flex, Box } from '@invoke-ai/ui-library';
|
||||
import { Box,Flex } from '@invoke-ai/ui-library';
|
||||
import { ImportModels } from 'features/modelManagerV2/subpanels/ImportModels';
|
||||
import { ModelManager } from 'features/modelManagerV2/subpanels/ModelManager';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ImportModels } from '../../../modelManagerV2/subpanels/ImportModels';
|
||||
import { ModelManager } from '../../../modelManagerV2/subpanels/ModelManager';
|
||||
|
||||
const ModelManagerTab = () => {
|
||||
return (
|
||||
|
Loading…
Reference in New Issue
Block a user