model list, filtering, searching

This commit is contained in:
Mary Hipp 2024-02-20 13:03:28 -05:00 committed by psychedelicious
parent 9068400433
commit c7d462b222
18 changed files with 517 additions and 35 deletions

View File

@ -16,6 +16,7 @@ import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/galle
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
import { modelManagerPersistConfig, modelManagerSlice } from 'features/modelManager/store/modelManagerSlice';
import { modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice';
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
@ -55,6 +56,7 @@ const allReducers = {
[changeBoardModalSlice.name]: changeBoardModalSlice.reducer,
[loraSlice.name]: loraSlice.reducer,
[modelManagerSlice.name]: modelManagerSlice.reducer,
[modelManagerV2Slice.name]: modelManagerV2Slice.reducer,
[sdxlSlice.name]: sdxlSlice.reducer,
[queueSlice.name]: queueSlice.reducer,
[workflowSlice.name]: workflowSlice.reducer,

View File

@ -1,10 +1,10 @@
import { Button, ButtonGroup, Flex, Text } from '@invoke-ai/ui-library';
import { memo, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetModelImportsQuery } from 'services/api/endpoints/models';
import AdvancedAddModels from './AdvancedAddModels';
import SimpleAddModels from './SimpleAddModels';
import { useGetModelImportsQuery } from '../../../../services/api/endpoints/models';
const AddModels = () => {
const { t } = useTranslation();

View File

@ -3,12 +3,12 @@ import { memo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
import type { DiffusersModelConfig, LoRAConfig, MainModelConfig } from 'services/api/types';
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import LoRAModelEdit from './ModelManagerPanel/LoRAModelEdit';
import ModelList from './ModelManagerPanel/ModelList';
import { DiffusersModelConfig, LoRAConfig, MainModelConfig } from '../../../services/api/types';
const ModelManagerPanel = () => {
const [selectedModelId, setSelectedModelId] = useState<string>();

View File

@ -22,8 +22,9 @@ import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useGetCheckpointConfigsQuery, useUpdateModelsMutation } from 'services/api/endpoints/models';
import type { CheckpointModelConfig } from 'services/api/types';
import ModelConvert from './ModelConvert';
import { CheckpointModelConfig } from '../../../../services/api/types';
type CheckpointModelEditProps = {
model: CheckpointModelConfig;

View File

@ -9,8 +9,8 @@ import { memo, useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useUpdateModelsMutation } from 'services/api/endpoints/models';
import type { DiffusersModelConfig } from 'services/api/types';
import { useUpdateModelsMutation } from '../../../../services/api/endpoints/models';
type DiffusersModelEditProps = {
model: DiffusersModelConfig;

View File

@ -8,6 +8,7 @@ import { memo, useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useUpdateModelsMutation } from 'services/api/endpoints/models';
import type { LoRAModelConfig } from 'services/api/types';
type LoRAModelEditProps = {

View File

@ -7,9 +7,9 @@ import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
// import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
import type { LoRAConfig, MainModelConfig } from 'services/api/types';
import ModelListItem from './ModelListItem';
import { LoRAConfig, MainModelConfig } from '../../../../services/api/types';
type ModelListProps = {
selectedModelId: string | undefined;

View File

@ -16,7 +16,7 @@ import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
import { LoRAConfig, MainModelConfig } from '../../../../services/api/types';
import type { LoRAConfig, MainModelConfig } from 'services/api/types';
type ModelListItemProps = {
model: MainModelConfig | LoRAConfig;

View File

@ -0,0 +1,54 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
type ModelManagerState = {
_version: 1;
selectedModelKey: string | null;
searchTerm: string;
filteredModelType: string | null;
};
export const initialModelManagerState: ModelManagerState = {
_version: 1,
selectedModelKey: null,
filteredModelType: null,
searchTerm: ""
};
export const modelManagerV2Slice = createSlice({
name: 'modelmanagerV2',
initialState: initialModelManagerState,
reducers: {
setSelectedModelKey: (state, action: PayloadAction<string | null>) => {
state.selectedModelKey = action.payload;
},
setSearchTerm: (state, action: PayloadAction<string>) => {
state.searchTerm = action.payload;
},
setFilteredModelType: (state, action: PayloadAction<string | null>) => {
state.filteredModelType = action.payload;
},
},
});
export const { setSelectedModelKey, setSearchTerm, setFilteredModelType } = modelManagerV2Slice.actions;
export const selectModelManagerSlice = (state: RootState) => state.modelmanager;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const migrateModelManagerState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const modelManagerPersistConfig: PersistConfig<ModelManagerState> = {
name: modelManagerV2Slice.name,
initialState: initialModelManagerState,
migrate: migrateModelManagerState,
persistDenylist: [],
};

View File

@ -1,17 +1,8 @@
import {
Box,
Button,
Flex,
Heading,
IconButton,
Input,
InputGroup,
InputRightElement,
Spacer,
} from '@invoke-ai/ui-library';
import { t } from 'i18next';
import { PiXBold } from 'react-icons/pi';
import { SyncModelsIconButton } from '../../modelManager/components/SyncModels/SyncModelsIconButton';
import { Box, Button, Flex, Heading } from '@invoke-ai/ui-library';
import { SyncModelsIconButton } from 'features/modelManager/components/SyncModels/SyncModelsIconButton';
import ModelList from './ModelManagerPanel/ModelList';
import { ModelListNavigation } from './ModelManagerPanel/ModelListNavigation';
export const ModelManager = () => {
return (
@ -27,17 +18,8 @@ export const ModelManager = () => {
</Flex>
</Flex>
<Box layerStyle="second" p={3} borderRadius="base" w="full" h="full">
<Flex gap={2} alignItems="center" justifyContent="space-between">
<Button>All Models</Button>
<Spacer />
<InputGroup>
<Input placeholder={t('boards.searchBoard')} data-testid="board-search-input" />(
<InputRightElement h="full" pe={2}>
<IconButton size="sm" variant="link" aria-label={t('boards.clearSearch')} icon={<PiXBold />} />
</InputRightElement>
)
</InputGroup>
</Flex>
<ModelListNavigation />
<ModelList />
</Box>
</Box>
);

View File

@ -0,0 +1,160 @@
import { Flex, Spinner, Text } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { forEach } from 'lodash-es';
import { memo } from 'react';
import { ALL_BASE_MODELS } from 'services/api/constants';
import {
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetLoRAModelsQuery,
useGetMainModelsQuery,
useGetT2IAdapterModelsQuery,
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,
} from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import { ModelListWrapper } from './ModelListWrapper';
const ModelList = () => {
const { searchTerm, filteredModelType } = useAppSelector((s) => s.modelmanagerV2);
const { filteredMainModels, isLoadingMainModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data, isLoading }) => ({
filteredMainModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingMainModels: isLoading,
}),
});
const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(undefined, {
selectFromResult: ({ data, isLoading }) => ({
filteredLoraModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingLoraModels: isLoading,
}),
});
const { filteredTextualInversionModels, isLoadingTextualInversionModels } = useGetTextualInversionModelsQuery(
undefined,
{
selectFromResult: ({ data, isLoading }) => ({
filteredTextualInversionModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingTextualInversionModels: isLoading,
}),
}
);
const { filteredControlnetModels, isLoadingControlnetModels } = useGetControlNetModelsQuery(undefined, {
selectFromResult: ({ data, isLoading }) => ({
filteredControlnetModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingControlnetModels: isLoading,
}),
});
const { filteredT2iAdapterModels, isLoadingT2IAdapterModels } = useGetT2IAdapterModelsQuery(undefined, {
selectFromResult: ({ data, isLoading }) => ({
filteredT2iAdapterModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingT2IAdapterModels: isLoading,
}),
});
const { filteredIpAdapterModels, isLoadingIpAdapterModels } = useGetIPAdapterModelsQuery(undefined, {
selectFromResult: ({ data, isLoading }) => ({
filteredIpAdapterModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingIpAdapterModels: isLoading,
}),
});
const { filteredVaeModels, isLoadingVaeModels } = useGetVaeModelsQuery(undefined, {
selectFromResult: ({ data, isLoading }) => ({
filteredVaeModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingVaeModels: isLoading,
}),
});
return (
<Flex flexDirection="column" p={4}>
<Flex flexDirection="column" maxHeight={window.innerHeight - 130} overflow="scroll">
{/* Main Model List */}
{isLoadingMainModels && <FetchingModelsLoader loadingMessage="Loading Main..." />}
{!isLoadingMainModels && filteredMainModels.length > 0 && (
<ModelListWrapper title="Main" modelList={filteredMainModels} key="main" />
)}
{/* LoRAs List */}
{isLoadingLoraModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
{!isLoadingLoraModels && filteredLoraModels.length > 0 && (
<ModelListWrapper title="LoRAs" modelList={filteredLoraModels} key="loras" />
)}
{/* TI List */}
{isLoadingTextualInversionModels && <FetchingModelsLoader loadingMessage="Loading Textual Inversions..." />}
{!isLoadingTextualInversionModels && filteredTextualInversionModels.length > 0 && (
<ModelListWrapper
title="Textual Inversions"
modelList={filteredTextualInversionModels}
key="textual-inversions"
/>
)}
{/* VAE List */}
{isLoadingVaeModels && <FetchingModelsLoader loadingMessage="Loading VAEs..." />}
{!isLoadingVaeModels && filteredVaeModels.length > 0 && (
<ModelListWrapper title="VAEs" modelList={filteredVaeModels} key="vae" />
)}
{/* Controlnet List */}
{isLoadingControlnetModels && <FetchingModelsLoader loadingMessage="Loading Controlnets..." />}
{!isLoadingControlnetModels && filteredControlnetModels.length > 0 && (
<ModelListWrapper title="Controlnets" modelList={filteredControlnetModels} key="controlnets" />
)}
{/* IP Adapter List */}
{isLoadingIpAdapterModels && <FetchingModelsLoader loadingMessage="Loading IP Adapters..." />}
{!isLoadingIpAdapterModels && filteredIpAdapterModels.length > 0 && (
<ModelListWrapper title="IP Adapters" modelList={filteredIpAdapterModels} key="ip-adapters" />
)}
{/* T2I Adapters List */}
{isLoadingT2IAdapterModels && <FetchingModelsLoader loadingMessage="Loading T2I Adapters..." />}
{!isLoadingT2IAdapterModels && filteredT2iAdapterModels.length > 0 && (
<ModelListWrapper title="T2I Adapters" modelList={filteredT2iAdapterModels} key="t2i-adapters" />
)}
</Flex>
</Flex>
);
};
export default memo(ModelList);
const modelsFilter = <T extends AnyModelConfig>(
data: EntityState<T, string> | undefined,
nameFilter: string,
filteredModelType: string | null
): T[] => {
const filteredModels: T[] = [];
forEach(data?.entities, (model) => {
if (!model) {
return;
}
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
const matchesType = filteredModelType ? model.type === filteredModelType : true;
if (matchesFilter && matchesType) {
filteredModels.push(model);
}
});
return filteredModels;
};
const FetchingModelsLoader = memo(({ loadingMessage }: { loadingMessage?: string }) => {
return (
<Flex flexDirection="column" gap={4} borderRadius={4} p={4} bg="base.800">
<Flex justifyContent="center" alignItems="center" flexDirection="column" p={4} gap={8}>
<Spinner />
<Text variant="subtext">{loadingMessage ? loadingMessage : 'Fetching...'}</Text>
</Flex>
</Flex>
);
});
FetchingModelsLoader.displayName = 'FetchingModelsLoader';

View File

@ -0,0 +1,23 @@
import { Box, Divider, Text } from '@invoke-ai/ui-library';
export const ModelListHeader = ({ title }: { title: string }) => {
return (
<Box position="relative" padding="10px 0">
<Divider sx={{ backgroundColor: 'base.400' }} />
<Box
sx={{
position: 'absolute',
top: '50%',
left: 0,
transform: 'translate(0, -50%)',
backgroundColor: 'base.800',
padding: '10px',
}}
>
<Text variant="subtext" fontSize="sm">
{title}
</Text>
</Box>
</Box>
);
};

View File

@ -0,0 +1,115 @@
import {
Badge,
Button,
ConfirmationAlertDialog,
Flex,
IconButton,
Text,
Tooltip,
useDisclosure,
} from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
type ModelListItemProps = {
model: AnyModelConfig;
};
const ModelListItem = (props: ModelListItemProps) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const [deleteModel] = useDeleteModelsMutation();
const { isOpen, onOpen, onClose } = useDisclosure();
const { model } = props;
const handleSelectModel = useCallback(() => {
dispatch(setSelectedModelKey(model.key));
}, [model.key, dispatch]);
const isSelected = useMemo(() => {
return selectedModelKey === model.key;
}, [selectedModelKey, model.key]);
const handleModelDelete = useCallback(() => {
deleteModel({ key: model.key })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelDeleted')}: ${model.name}`,
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelDeleteFailed')}: ${model.name}`,
status: 'error',
})
)
);
}
});
dispatch(setSelectedModelKey(null));
}, [deleteModel, model, dispatch, t]);
return (
<Flex gap={2} alignItems="center" w="full">
<Flex
as={Button}
isChecked={isSelected}
variant={isSelected ? 'solid' : 'ghost'}
justifyContent="start"
p={2}
borderRadius="base"
w="full"
alignItems="center"
onClick={handleSelectModel}
>
<Flex gap={4} alignItems="center">
<Badge minWidth={14} p={0.5} fontSize="sm" variant="solid">
{MODEL_TYPE_SHORT_MAP[model.base as keyof typeof MODEL_TYPE_SHORT_MAP]}
</Badge>
<Tooltip label={model.description} placement="bottom">
<Text>{model.name}</Text>
</Tooltip>
</Flex>
</Flex>
<IconButton
onClick={onOpen}
icon={<PiTrashSimpleBold />}
aria-label={t('modelManager.deleteConfig')}
colorScheme="error"
/>
<ConfirmationAlertDialog
isOpen={isOpen}
onClose={onClose}
title={t('modelManager.deleteModel')}
acceptCallback={handleModelDelete}
acceptButtonText={t('modelManager.delete')}
>
<Flex rowGap={4} flexDirection="column">
<Text fontWeight="bold">{t('modelManager.deleteMsg1')}</Text>
<Text>{t('modelManager.deleteMsg2')}</Text>
</Flex>
</ConfirmationAlertDialog>
</Flex>
);
};
export default memo(ModelListItem);

View File

@ -0,0 +1,52 @@
import { Flex, IconButton,Input, InputGroup, InputRightElement, Spacer } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setSearchTerm } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { t } from 'i18next';
import type { ChangeEventHandler} from 'react';
import { useCallback } from 'react';
import { PiXBold } from 'react-icons/pi';
import { ModelTypeFilter } from './ModelTypeFilter';
export const ModelListNavigation = () => {
const dispatch = useAppDispatch();
const searchTerm = useAppSelector((s) => s.modelmanagerV2.searchTerm);
const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback(
(event) => {
dispatch(setSearchTerm(event.target.value));
},
[dispatch]
);
const clearSearch = useCallback(() => {
dispatch(setSearchTerm(''));
}, [dispatch]);
return (
<Flex gap={2} alignItems="center" justifyContent="space-between">
<ModelTypeFilter />
<Spacer />
<InputGroup maxW="400px">
<Input
placeholder={t('modelManager.search')}
value={searchTerm || ''}
data-testid="board-search-input"
onChange={handleSearch}
/>
{!!searchTerm?.length && (
<InputRightElement h="full" pe={2}>
<IconButton
size="sm"
variant="link"
aria-label={t('boards.clearSearch')}
icon={<PiXBold />}
onClick={clearSearch}
/>
</InputRightElement>
)}
</InputGroup>
</Flex>
);
};

View File

@ -0,0 +1,25 @@
import { Flex } from '@invoke-ai/ui-library';
import type { AnyModelConfig } from 'services/api/types';
import { ModelListHeader } from './ModelListHeader';
import ModelListItem from './ModelListItem';
type ModelListWrapperProps = {
title: string;
modelList: AnyModelConfig[];
};
export const ModelListWrapper = (props: ModelListWrapperProps) => {
const { title, modelList } = props;
return (
<Flex flexDirection="column" p="10px 0">
<Flex gap={2} flexDir="column">
<ModelListHeader title={title} />
{modelList.map((model) => (
<ModelListItem key={model.key} model={model} />
))}
</Flex>
</Flex>
);
};

View File

@ -0,0 +1,54 @@
import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { useCallback } from 'react';
import { IoFilter } from 'react-icons/io5';
export const MODEL_TYPE_LABELS: { [key: string]: string } = {
main: 'Main',
lora: 'LoRA',
embedding: 'Textual Inversion',
controlnet: 'ControlNet',
vae: 'VAE',
t2i_adapter: 'T2I Adapter',
ip_adapter: 'IP Adapter',
clip_vision: 'Clip Vision',
onnx: 'Onnx',
};
export const ModelTypeFilter = () => {
const dispatch = useAppDispatch();
const filteredModelType = useAppSelector((s) => s.modelmanagerV2.filteredModelType);
const selectModelType = useCallback(
(option: string) => {
dispatch(setFilteredModelType(option));
},
[dispatch]
);
const clearModelType = useCallback(() => {
dispatch(setFilteredModelType(null));
}, [dispatch]);
return (
<Menu>
<MenuButton as={Button} leftIcon={<IoFilter />}>
{filteredModelType ? MODEL_TYPE_LABELS[filteredModelType] : 'All Models'}
</MenuButton>
<MenuList>
<MenuItem onClick={clearModelType}>All Models</MenuItem>
{Object.keys(MODEL_TYPE_LABELS).map((option) => (
<MenuItem
sx={{
backgroundColor: filteredModelType === option ? 'base.700' : 'transparent',
}}
onClick={selectModelType.bind(null, option)}
>
{MODEL_TYPE_LABELS[option]}
</MenuItem>
))}
</MenuList>
</Menu>
);
};

View File

@ -0,0 +1,14 @@
import { z } from "zod";
export const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
export const zModelType = z.enum([
'main',
'vae',
'lora',
'controlnet',
'embedding',
'ip_adapter',
'clip_vision',
't2i_adapter',
'onnx', // TODO(psyche): Remove this when removed from backend
]);

View File

@ -1,8 +1,7 @@
import { Flex, Box } from '@invoke-ai/ui-library';
import { Box,Flex } from '@invoke-ai/ui-library';
import { ImportModels } from 'features/modelManagerV2/subpanels/ImportModels';
import { ModelManager } from 'features/modelManagerV2/subpanels/ModelManager';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { ImportModels } from '../../../modelManagerV2/subpanels/ImportModels';
import { ModelManager } from '../../../modelManagerV2/subpanels/ModelManager';
const ModelManagerTab = () => {
return (