mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
model list, filtering, searching
This commit is contained in:
parent
9068400433
commit
c7d462b222
@ -16,6 +16,7 @@ import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/galle
|
|||||||
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
|
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
|
||||||
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
|
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
|
||||||
import { modelManagerPersistConfig, modelManagerSlice } from 'features/modelManager/store/modelManagerSlice';
|
import { modelManagerPersistConfig, modelManagerSlice } from 'features/modelManager/store/modelManagerSlice';
|
||||||
|
import { modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice';
|
import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
|
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
|
||||||
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
|
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
|
||||||
@ -55,6 +56,7 @@ const allReducers = {
|
|||||||
[changeBoardModalSlice.name]: changeBoardModalSlice.reducer,
|
[changeBoardModalSlice.name]: changeBoardModalSlice.reducer,
|
||||||
[loraSlice.name]: loraSlice.reducer,
|
[loraSlice.name]: loraSlice.reducer,
|
||||||
[modelManagerSlice.name]: modelManagerSlice.reducer,
|
[modelManagerSlice.name]: modelManagerSlice.reducer,
|
||||||
|
[modelManagerV2Slice.name]: modelManagerV2Slice.reducer,
|
||||||
[sdxlSlice.name]: sdxlSlice.reducer,
|
[sdxlSlice.name]: sdxlSlice.reducer,
|
||||||
[queueSlice.name]: queueSlice.reducer,
|
[queueSlice.name]: queueSlice.reducer,
|
||||||
[workflowSlice.name]: workflowSlice.reducer,
|
[workflowSlice.name]: workflowSlice.reducer,
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import { Button, ButtonGroup, Flex, Text } from '@invoke-ai/ui-library';
|
import { Button, ButtonGroup, Flex, Text } from '@invoke-ai/ui-library';
|
||||||
import { memo, useCallback, useState } from 'react';
|
import { memo, useCallback, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useGetModelImportsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
import AdvancedAddModels from './AdvancedAddModels';
|
import AdvancedAddModels from './AdvancedAddModels';
|
||||||
import SimpleAddModels from './SimpleAddModels';
|
import SimpleAddModels from './SimpleAddModels';
|
||||||
import { useGetModelImportsQuery } from '../../../../services/api/endpoints/models';
|
|
||||||
|
|
||||||
const AddModels = () => {
|
const AddModels = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
@ -3,12 +3,12 @@ import { memo, useState } from 'react';
|
|||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||||
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
|
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
import type { DiffusersModelConfig, LoRAConfig, MainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
||||||
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
|
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
|
||||||
import LoRAModelEdit from './ModelManagerPanel/LoRAModelEdit';
|
import LoRAModelEdit from './ModelManagerPanel/LoRAModelEdit';
|
||||||
import ModelList from './ModelManagerPanel/ModelList';
|
import ModelList from './ModelManagerPanel/ModelList';
|
||||||
import { DiffusersModelConfig, LoRAConfig, MainModelConfig } from '../../../services/api/types';
|
|
||||||
|
|
||||||
const ModelManagerPanel = () => {
|
const ModelManagerPanel = () => {
|
||||||
const [selectedModelId, setSelectedModelId] = useState<string>();
|
const [selectedModelId, setSelectedModelId] = useState<string>();
|
||||||
|
@ -22,8 +22,9 @@ import type { SubmitHandler } from 'react-hook-form';
|
|||||||
import { useForm } from 'react-hook-form';
|
import { useForm } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useGetCheckpointConfigsQuery, useUpdateModelsMutation } from 'services/api/endpoints/models';
|
import { useGetCheckpointConfigsQuery, useUpdateModelsMutation } from 'services/api/endpoints/models';
|
||||||
|
import type { CheckpointModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import ModelConvert from './ModelConvert';
|
import ModelConvert from './ModelConvert';
|
||||||
import { CheckpointModelConfig } from '../../../../services/api/types';
|
|
||||||
|
|
||||||
type CheckpointModelEditProps = {
|
type CheckpointModelEditProps = {
|
||||||
model: CheckpointModelConfig;
|
model: CheckpointModelConfig;
|
||||||
|
@ -9,8 +9,8 @@ import { memo, useCallback } from 'react';
|
|||||||
import type { SubmitHandler } from 'react-hook-form';
|
import type { SubmitHandler } from 'react-hook-form';
|
||||||
import { useForm } from 'react-hook-form';
|
import { useForm } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useUpdateModelsMutation } from 'services/api/endpoints/models';
|
||||||
import type { DiffusersModelConfig } from 'services/api/types';
|
import type { DiffusersModelConfig } from 'services/api/types';
|
||||||
import { useUpdateModelsMutation } from '../../../../services/api/endpoints/models';
|
|
||||||
|
|
||||||
type DiffusersModelEditProps = {
|
type DiffusersModelEditProps = {
|
||||||
model: DiffusersModelConfig;
|
model: DiffusersModelConfig;
|
||||||
|
@ -8,6 +8,7 @@ import { memo, useCallback } from 'react';
|
|||||||
import type { SubmitHandler } from 'react-hook-form';
|
import type { SubmitHandler } from 'react-hook-form';
|
||||||
import { useForm } from 'react-hook-form';
|
import { useForm } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useUpdateModelsMutation } from 'services/api/endpoints/models';
|
||||||
import type { LoRAModelConfig } from 'services/api/types';
|
import type { LoRAModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
type LoRAModelEditProps = {
|
type LoRAModelEditProps = {
|
||||||
|
@ -7,9 +7,9 @@ import { useTranslation } from 'react-i18next';
|
|||||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||||
// import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
|
// import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
|
||||||
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
|
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
import type { LoRAConfig, MainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import ModelListItem from './ModelListItem';
|
import ModelListItem from './ModelListItem';
|
||||||
import { LoRAConfig, MainModelConfig } from '../../../../services/api/types';
|
|
||||||
|
|
||||||
type ModelListProps = {
|
type ModelListProps = {
|
||||||
selectedModelId: string | undefined;
|
selectedModelId: string | undefined;
|
||||||
|
@ -16,7 +16,7 @@ import { memo, useCallback } from 'react';
|
|||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { PiTrashSimpleBold } from 'react-icons/pi';
|
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||||
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
|
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
|
||||||
import { LoRAConfig, MainModelConfig } from '../../../../services/api/types';
|
import type { LoRAConfig, MainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
type ModelListItemProps = {
|
type ModelListItemProps = {
|
||||||
model: MainModelConfig | LoRAConfig;
|
model: MainModelConfig | LoRAConfig;
|
||||||
|
@ -0,0 +1,54 @@
|
|||||||
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
|
import type { PersistConfig, RootState } from 'app/store/store';
|
||||||
|
|
||||||
|
|
||||||
|
type ModelManagerState = {
|
||||||
|
_version: 1;
|
||||||
|
selectedModelKey: string | null;
|
||||||
|
searchTerm: string;
|
||||||
|
filteredModelType: string | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const initialModelManagerState: ModelManagerState = {
|
||||||
|
_version: 1,
|
||||||
|
selectedModelKey: null,
|
||||||
|
filteredModelType: null,
|
||||||
|
searchTerm: ""
|
||||||
|
};
|
||||||
|
|
||||||
|
export const modelManagerV2Slice = createSlice({
|
||||||
|
name: 'modelmanagerV2',
|
||||||
|
initialState: initialModelManagerState,
|
||||||
|
reducers: {
|
||||||
|
setSelectedModelKey: (state, action: PayloadAction<string | null>) => {
|
||||||
|
state.selectedModelKey = action.payload;
|
||||||
|
},
|
||||||
|
setSearchTerm: (state, action: PayloadAction<string>) => {
|
||||||
|
state.searchTerm = action.payload;
|
||||||
|
},
|
||||||
|
|
||||||
|
setFilteredModelType: (state, action: PayloadAction<string | null>) => {
|
||||||
|
state.filteredModelType = action.payload;
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
export const { setSelectedModelKey, setSearchTerm, setFilteredModelType } = modelManagerV2Slice.actions;
|
||||||
|
|
||||||
|
export const selectModelManagerSlice = (state: RootState) => state.modelmanager;
|
||||||
|
|
||||||
|
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||||
|
export const migrateModelManagerState = (state: any): any => {
|
||||||
|
if (!('_version' in state)) {
|
||||||
|
state._version = 1;
|
||||||
|
}
|
||||||
|
return state;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const modelManagerPersistConfig: PersistConfig<ModelManagerState> = {
|
||||||
|
name: modelManagerV2Slice.name,
|
||||||
|
initialState: initialModelManagerState,
|
||||||
|
migrate: migrateModelManagerState,
|
||||||
|
persistDenylist: [],
|
||||||
|
};
|
@ -1,17 +1,8 @@
|
|||||||
import {
|
import { Box, Button, Flex, Heading } from '@invoke-ai/ui-library';
|
||||||
Box,
|
import { SyncModelsIconButton } from 'features/modelManager/components/SyncModels/SyncModelsIconButton';
|
||||||
Button,
|
|
||||||
Flex,
|
import ModelList from './ModelManagerPanel/ModelList';
|
||||||
Heading,
|
import { ModelListNavigation } from './ModelManagerPanel/ModelListNavigation';
|
||||||
IconButton,
|
|
||||||
Input,
|
|
||||||
InputGroup,
|
|
||||||
InputRightElement,
|
|
||||||
Spacer,
|
|
||||||
} from '@invoke-ai/ui-library';
|
|
||||||
import { t } from 'i18next';
|
|
||||||
import { PiXBold } from 'react-icons/pi';
|
|
||||||
import { SyncModelsIconButton } from '../../modelManager/components/SyncModels/SyncModelsIconButton';
|
|
||||||
|
|
||||||
export const ModelManager = () => {
|
export const ModelManager = () => {
|
||||||
return (
|
return (
|
||||||
@ -27,17 +18,8 @@ export const ModelManager = () => {
|
|||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
<Box layerStyle="second" p={3} borderRadius="base" w="full" h="full">
|
<Box layerStyle="second" p={3} borderRadius="base" w="full" h="full">
|
||||||
<Flex gap={2} alignItems="center" justifyContent="space-between">
|
<ModelListNavigation />
|
||||||
<Button>All Models</Button>
|
<ModelList />
|
||||||
<Spacer />
|
|
||||||
<InputGroup>
|
|
||||||
<Input placeholder={t('boards.searchBoard')} data-testid="board-search-input" />(
|
|
||||||
<InputRightElement h="full" pe={2}>
|
|
||||||
<IconButton size="sm" variant="link" aria-label={t('boards.clearSearch')} icon={<PiXBold />} />
|
|
||||||
</InputRightElement>
|
|
||||||
)
|
|
||||||
</InputGroup>
|
|
||||||
</Flex>
|
|
||||||
</Box>
|
</Box>
|
||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
|
@ -0,0 +1,160 @@
|
|||||||
|
import { Flex, Spinner, Text } from '@invoke-ai/ui-library';
|
||||||
|
import type { EntityState } from '@reduxjs/toolkit';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||||
|
import {
|
||||||
|
useGetControlNetModelsQuery,
|
||||||
|
useGetIPAdapterModelsQuery,
|
||||||
|
useGetLoRAModelsQuery,
|
||||||
|
useGetMainModelsQuery,
|
||||||
|
useGetT2IAdapterModelsQuery,
|
||||||
|
useGetTextualInversionModelsQuery,
|
||||||
|
useGetVaeModelsQuery,
|
||||||
|
} from 'services/api/endpoints/models';
|
||||||
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
|
import { ModelListWrapper } from './ModelListWrapper';
|
||||||
|
|
||||||
|
const ModelList = () => {
|
||||||
|
const { searchTerm, filteredModelType } = useAppSelector((s) => s.modelmanagerV2);
|
||||||
|
|
||||||
|
const { filteredMainModels, isLoadingMainModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||||
|
selectFromResult: ({ data, isLoading }) => ({
|
||||||
|
filteredMainModels: modelsFilter(data, searchTerm, filteredModelType),
|
||||||
|
isLoadingMainModels: isLoading,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(undefined, {
|
||||||
|
selectFromResult: ({ data, isLoading }) => ({
|
||||||
|
filteredLoraModels: modelsFilter(data, searchTerm, filteredModelType),
|
||||||
|
isLoadingLoraModels: isLoading,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
const { filteredTextualInversionModels, isLoadingTextualInversionModels } = useGetTextualInversionModelsQuery(
|
||||||
|
undefined,
|
||||||
|
{
|
||||||
|
selectFromResult: ({ data, isLoading }) => ({
|
||||||
|
filteredTextualInversionModels: modelsFilter(data, searchTerm, filteredModelType),
|
||||||
|
isLoadingTextualInversionModels: isLoading,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const { filteredControlnetModels, isLoadingControlnetModels } = useGetControlNetModelsQuery(undefined, {
|
||||||
|
selectFromResult: ({ data, isLoading }) => ({
|
||||||
|
filteredControlnetModels: modelsFilter(data, searchTerm, filteredModelType),
|
||||||
|
isLoadingControlnetModels: isLoading,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
const { filteredT2iAdapterModels, isLoadingT2IAdapterModels } = useGetT2IAdapterModelsQuery(undefined, {
|
||||||
|
selectFromResult: ({ data, isLoading }) => ({
|
||||||
|
filteredT2iAdapterModels: modelsFilter(data, searchTerm, filteredModelType),
|
||||||
|
isLoadingT2IAdapterModels: isLoading,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
const { filteredIpAdapterModels, isLoadingIpAdapterModels } = useGetIPAdapterModelsQuery(undefined, {
|
||||||
|
selectFromResult: ({ data, isLoading }) => ({
|
||||||
|
filteredIpAdapterModels: modelsFilter(data, searchTerm, filteredModelType),
|
||||||
|
isLoadingIpAdapterModels: isLoading,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
const { filteredVaeModels, isLoadingVaeModels } = useGetVaeModelsQuery(undefined, {
|
||||||
|
selectFromResult: ({ data, isLoading }) => ({
|
||||||
|
filteredVaeModels: modelsFilter(data, searchTerm, filteredModelType),
|
||||||
|
isLoadingVaeModels: isLoading,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex flexDirection="column" p={4}>
|
||||||
|
<Flex flexDirection="column" maxHeight={window.innerHeight - 130} overflow="scroll">
|
||||||
|
{/* Main Model List */}
|
||||||
|
{isLoadingMainModels && <FetchingModelsLoader loadingMessage="Loading Main..." />}
|
||||||
|
{!isLoadingMainModels && filteredMainModels.length > 0 && (
|
||||||
|
<ModelListWrapper title="Main" modelList={filteredMainModels} key="main" />
|
||||||
|
)}
|
||||||
|
{/* LoRAs List */}
|
||||||
|
{isLoadingLoraModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
|
||||||
|
{!isLoadingLoraModels && filteredLoraModels.length > 0 && (
|
||||||
|
<ModelListWrapper title="LoRAs" modelList={filteredLoraModels} key="loras" />
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* TI List */}
|
||||||
|
{isLoadingTextualInversionModels && <FetchingModelsLoader loadingMessage="Loading Textual Inversions..." />}
|
||||||
|
{!isLoadingTextualInversionModels && filteredTextualInversionModels.length > 0 && (
|
||||||
|
<ModelListWrapper
|
||||||
|
title="Textual Inversions"
|
||||||
|
modelList={filteredTextualInversionModels}
|
||||||
|
key="textual-inversions"
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* VAE List */}
|
||||||
|
{isLoadingVaeModels && <FetchingModelsLoader loadingMessage="Loading VAEs..." />}
|
||||||
|
{!isLoadingVaeModels && filteredVaeModels.length > 0 && (
|
||||||
|
<ModelListWrapper title="VAEs" modelList={filteredVaeModels} key="vae" />
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Controlnet List */}
|
||||||
|
{isLoadingControlnetModels && <FetchingModelsLoader loadingMessage="Loading Controlnets..." />}
|
||||||
|
{!isLoadingControlnetModels && filteredControlnetModels.length > 0 && (
|
||||||
|
<ModelListWrapper title="Controlnets" modelList={filteredControlnetModels} key="controlnets" />
|
||||||
|
)}
|
||||||
|
{/* IP Adapter List */}
|
||||||
|
{isLoadingIpAdapterModels && <FetchingModelsLoader loadingMessage="Loading IP Adapters..." />}
|
||||||
|
{!isLoadingIpAdapterModels && filteredIpAdapterModels.length > 0 && (
|
||||||
|
<ModelListWrapper title="IP Adapters" modelList={filteredIpAdapterModels} key="ip-adapters" />
|
||||||
|
)}
|
||||||
|
{/* T2I Adapters List */}
|
||||||
|
{isLoadingT2IAdapterModels && <FetchingModelsLoader loadingMessage="Loading T2I Adapters..." />}
|
||||||
|
{!isLoadingT2IAdapterModels && filteredT2iAdapterModels.length > 0 && (
|
||||||
|
<ModelListWrapper title="T2I Adapters" modelList={filteredT2iAdapterModels} key="t2i-adapters" />
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ModelList);
|
||||||
|
|
||||||
|
const modelsFilter = <T extends AnyModelConfig>(
|
||||||
|
data: EntityState<T, string> | undefined,
|
||||||
|
nameFilter: string,
|
||||||
|
filteredModelType: string | null
|
||||||
|
): T[] => {
|
||||||
|
const filteredModels: T[] = [];
|
||||||
|
|
||||||
|
forEach(data?.entities, (model) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
|
||||||
|
const matchesType = filteredModelType ? model.type === filteredModelType : true;
|
||||||
|
|
||||||
|
if (matchesFilter && matchesType) {
|
||||||
|
filteredModels.push(model);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return filteredModels;
|
||||||
|
};
|
||||||
|
|
||||||
|
const FetchingModelsLoader = memo(({ loadingMessage }: { loadingMessage?: string }) => {
|
||||||
|
return (
|
||||||
|
<Flex flexDirection="column" gap={4} borderRadius={4} p={4} bg="base.800">
|
||||||
|
<Flex justifyContent="center" alignItems="center" flexDirection="column" p={4} gap={8}>
|
||||||
|
<Spinner />
|
||||||
|
<Text variant="subtext">{loadingMessage ? loadingMessage : 'Fetching...'}</Text>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
FetchingModelsLoader.displayName = 'FetchingModelsLoader';
|
@ -0,0 +1,23 @@
|
|||||||
|
import { Box, Divider, Text } from '@invoke-ai/ui-library';
|
||||||
|
|
||||||
|
export const ModelListHeader = ({ title }: { title: string }) => {
|
||||||
|
return (
|
||||||
|
<Box position="relative" padding="10px 0">
|
||||||
|
<Divider sx={{ backgroundColor: 'base.400' }} />
|
||||||
|
<Box
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
top: '50%',
|
||||||
|
left: 0,
|
||||||
|
transform: 'translate(0, -50%)',
|
||||||
|
backgroundColor: 'base.800',
|
||||||
|
padding: '10px',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Text variant="subtext" fontSize="sm">
|
||||||
|
{title}
|
||||||
|
</Text>
|
||||||
|
</Box>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
};
|
@ -0,0 +1,115 @@
|
|||||||
|
import {
|
||||||
|
Badge,
|
||||||
|
Button,
|
||||||
|
ConfirmationAlertDialog,
|
||||||
|
Flex,
|
||||||
|
IconButton,
|
||||||
|
Text,
|
||||||
|
Tooltip,
|
||||||
|
useDisclosure,
|
||||||
|
} from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
|
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||||
|
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
|
||||||
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
|
type ModelListItemProps = {
|
||||||
|
model: AnyModelConfig;
|
||||||
|
};
|
||||||
|
|
||||||
|
const ModelListItem = (props: ModelListItemProps) => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||||
|
const [deleteModel] = useDeleteModelsMutation();
|
||||||
|
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||||
|
|
||||||
|
const { model } = props;
|
||||||
|
|
||||||
|
const handleSelectModel = useCallback(() => {
|
||||||
|
dispatch(setSelectedModelKey(model.key));
|
||||||
|
}, [model.key, dispatch]);
|
||||||
|
|
||||||
|
const isSelected = useMemo(() => {
|
||||||
|
return selectedModelKey === model.key;
|
||||||
|
}, [selectedModelKey, model.key]);
|
||||||
|
|
||||||
|
const handleModelDelete = useCallback(() => {
|
||||||
|
deleteModel({ key: model.key })
|
||||||
|
.unwrap()
|
||||||
|
.then((_) => {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: `${t('modelManager.modelDeleted')}: ${model.name}`,
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
if (error) {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: `${t('modelManager.modelDeleteFailed')}: ${model.name}`,
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
dispatch(setSelectedModelKey(null));
|
||||||
|
}, [deleteModel, model, dispatch, t]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex gap={2} alignItems="center" w="full">
|
||||||
|
<Flex
|
||||||
|
as={Button}
|
||||||
|
isChecked={isSelected}
|
||||||
|
variant={isSelected ? 'solid' : 'ghost'}
|
||||||
|
justifyContent="start"
|
||||||
|
p={2}
|
||||||
|
borderRadius="base"
|
||||||
|
w="full"
|
||||||
|
alignItems="center"
|
||||||
|
onClick={handleSelectModel}
|
||||||
|
>
|
||||||
|
<Flex gap={4} alignItems="center">
|
||||||
|
<Badge minWidth={14} p={0.5} fontSize="sm" variant="solid">
|
||||||
|
{MODEL_TYPE_SHORT_MAP[model.base as keyof typeof MODEL_TYPE_SHORT_MAP]}
|
||||||
|
</Badge>
|
||||||
|
<Tooltip label={model.description} placement="bottom">
|
||||||
|
<Text>{model.name}</Text>
|
||||||
|
</Tooltip>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
<IconButton
|
||||||
|
onClick={onOpen}
|
||||||
|
icon={<PiTrashSimpleBold />}
|
||||||
|
aria-label={t('modelManager.deleteConfig')}
|
||||||
|
colorScheme="error"
|
||||||
|
/>
|
||||||
|
<ConfirmationAlertDialog
|
||||||
|
isOpen={isOpen}
|
||||||
|
onClose={onClose}
|
||||||
|
title={t('modelManager.deleteModel')}
|
||||||
|
acceptCallback={handleModelDelete}
|
||||||
|
acceptButtonText={t('modelManager.delete')}
|
||||||
|
>
|
||||||
|
<Flex rowGap={4} flexDirection="column">
|
||||||
|
<Text fontWeight="bold">{t('modelManager.deleteMsg1')}</Text>
|
||||||
|
<Text>{t('modelManager.deleteMsg2')}</Text>
|
||||||
|
</Flex>
|
||||||
|
</ConfirmationAlertDialog>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ModelListItem);
|
@ -0,0 +1,52 @@
|
|||||||
|
import { Flex, IconButton,Input, InputGroup, InputRightElement, Spacer } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { setSearchTerm } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
|
import { t } from 'i18next';
|
||||||
|
import type { ChangeEventHandler} from 'react';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
import { PiXBold } from 'react-icons/pi';
|
||||||
|
|
||||||
|
import { ModelTypeFilter } from './ModelTypeFilter';
|
||||||
|
|
||||||
|
export const ModelListNavigation = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const searchTerm = useAppSelector((s) => s.modelmanagerV2.searchTerm);
|
||||||
|
|
||||||
|
const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback(
|
||||||
|
(event) => {
|
||||||
|
dispatch(setSearchTerm(event.target.value));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const clearSearch = useCallback(() => {
|
||||||
|
dispatch(setSearchTerm(''));
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex gap={2} alignItems="center" justifyContent="space-between">
|
||||||
|
<ModelTypeFilter />
|
||||||
|
<Spacer />
|
||||||
|
<InputGroup maxW="400px">
|
||||||
|
<Input
|
||||||
|
placeholder={t('modelManager.search')}
|
||||||
|
value={searchTerm || ''}
|
||||||
|
data-testid="board-search-input"
|
||||||
|
onChange={handleSearch}
|
||||||
|
/>
|
||||||
|
|
||||||
|
{!!searchTerm?.length && (
|
||||||
|
<InputRightElement h="full" pe={2}>
|
||||||
|
<IconButton
|
||||||
|
size="sm"
|
||||||
|
variant="link"
|
||||||
|
aria-label={t('boards.clearSearch')}
|
||||||
|
icon={<PiXBold />}
|
||||||
|
onClick={clearSearch}
|
||||||
|
/>
|
||||||
|
</InputRightElement>
|
||||||
|
)}
|
||||||
|
</InputGroup>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
@ -0,0 +1,25 @@
|
|||||||
|
import { Flex } from '@invoke-ai/ui-library';
|
||||||
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
|
import { ModelListHeader } from './ModelListHeader';
|
||||||
|
import ModelListItem from './ModelListItem';
|
||||||
|
|
||||||
|
type ModelListWrapperProps = {
|
||||||
|
title: string;
|
||||||
|
modelList: AnyModelConfig[];
|
||||||
|
};
|
||||||
|
|
||||||
|
export const ModelListWrapper = (props: ModelListWrapperProps) => {
|
||||||
|
const { title, modelList } = props;
|
||||||
|
return (
|
||||||
|
<Flex flexDirection="column" p="10px 0">
|
||||||
|
<Flex gap={2} flexDir="column">
|
||||||
|
<ModelListHeader title={title} />
|
||||||
|
|
||||||
|
{modelList.map((model) => (
|
||||||
|
<ModelListItem key={model.key} model={model} />
|
||||||
|
))}
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
@ -0,0 +1,54 @@
|
|||||||
|
import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
import { IoFilter } from 'react-icons/io5';
|
||||||
|
|
||||||
|
export const MODEL_TYPE_LABELS: { [key: string]: string } = {
|
||||||
|
main: 'Main',
|
||||||
|
lora: 'LoRA',
|
||||||
|
embedding: 'Textual Inversion',
|
||||||
|
controlnet: 'ControlNet',
|
||||||
|
vae: 'VAE',
|
||||||
|
t2i_adapter: 'T2I Adapter',
|
||||||
|
ip_adapter: 'IP Adapter',
|
||||||
|
clip_vision: 'Clip Vision',
|
||||||
|
onnx: 'Onnx',
|
||||||
|
};
|
||||||
|
|
||||||
|
export const ModelTypeFilter = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const filteredModelType = useAppSelector((s) => s.modelmanagerV2.filteredModelType);
|
||||||
|
|
||||||
|
const selectModelType = useCallback(
|
||||||
|
(option: string) => {
|
||||||
|
dispatch(setFilteredModelType(option));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const clearModelType = useCallback(() => {
|
||||||
|
dispatch(setFilteredModelType(null));
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Menu>
|
||||||
|
<MenuButton as={Button} leftIcon={<IoFilter />}>
|
||||||
|
{filteredModelType ? MODEL_TYPE_LABELS[filteredModelType] : 'All Models'}
|
||||||
|
</MenuButton>
|
||||||
|
<MenuList>
|
||||||
|
<MenuItem onClick={clearModelType}>All Models</MenuItem>
|
||||||
|
{Object.keys(MODEL_TYPE_LABELS).map((option) => (
|
||||||
|
<MenuItem
|
||||||
|
sx={{
|
||||||
|
backgroundColor: filteredModelType === option ? 'base.700' : 'transparent',
|
||||||
|
}}
|
||||||
|
onClick={selectModelType.bind(null, option)}
|
||||||
|
>
|
||||||
|
{MODEL_TYPE_LABELS[option]}
|
||||||
|
</MenuItem>
|
||||||
|
))}
|
||||||
|
</MenuList>
|
||||||
|
</Menu>
|
||||||
|
);
|
||||||
|
};
|
14
invokeai/frontend/web/src/features/modelManagerV2/types.ts
Normal file
14
invokeai/frontend/web/src/features/modelManagerV2/types.ts
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
import { z } from "zod";
|
||||||
|
|
||||||
|
export const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
||||||
|
export const zModelType = z.enum([
|
||||||
|
'main',
|
||||||
|
'vae',
|
||||||
|
'lora',
|
||||||
|
'controlnet',
|
||||||
|
'embedding',
|
||||||
|
'ip_adapter',
|
||||||
|
'clip_vision',
|
||||||
|
't2i_adapter',
|
||||||
|
'onnx', // TODO(psyche): Remove this when removed from backend
|
||||||
|
]);
|
@ -1,8 +1,7 @@
|
|||||||
import { Flex, Box } from '@invoke-ai/ui-library';
|
import { Box,Flex } from '@invoke-ai/ui-library';
|
||||||
|
import { ImportModels } from 'features/modelManagerV2/subpanels/ImportModels';
|
||||||
|
import { ModelManager } from 'features/modelManagerV2/subpanels/ModelManager';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { ImportModels } from '../../../modelManagerV2/subpanels/ImportModels';
|
|
||||||
import { ModelManager } from '../../../modelManagerV2/subpanels/ModelManager';
|
|
||||||
|
|
||||||
const ModelManagerTab = () => {
|
const ModelManagerTab = () => {
|
||||||
return (
|
return (
|
||||||
|
Loading…
Reference in New Issue
Block a user