feat(ui): refactor model manager ui

- simplify UI logic in `ModelManagerPanel` components
- fix up the types a bit to make it easier to select models
- remove `openModel` state, just make it a useState since it is very local to model manager
This commit is contained in:
psychedelicious 2023-07-14 19:22:37 +10:00
parent f2af82bf73
commit 1e5ae9d986
11 changed files with 384 additions and 532 deletions

View File

@ -13,7 +13,6 @@ export const systemPersistDenylist: (keyof SystemState)[] = [
'isProcessing', 'isProcessing',
'totalIterations', 'totalIterations',
'totalSteps', 'totalSteps',
'openModel',
'isCancelScheduled', 'isCancelScheduled',
'progressImage', 'progressImage',
'wereModelsReceived', 'wereModelsReceived',

View File

@ -46,7 +46,6 @@ export interface SystemState {
toastQueue: UseToastOptions[]; toastQueue: UseToastOptions[];
searchFolder: string | null; searchFolder: string | null;
foundModels: InvokeAI.FoundModel[] | null; foundModels: InvokeAI.FoundModel[] | null;
openModel: string | null;
/** /**
* The current progress image * The current progress image
*/ */
@ -109,7 +108,6 @@ export const initialSystemState: SystemState = {
toastQueue: [], toastQueue: [],
searchFolder: null, searchFolder: null,
foundModels: null, foundModels: null,
openModel: null,
progressImage: null, progressImage: null,
shouldAntialiasProgressImage: false, shouldAntialiasProgressImage: false,
sessionId: null, sessionId: null,
@ -164,9 +162,6 @@ export const systemSlice = createSlice({
) => { ) => {
state.foundModels = action.payload; state.foundModels = action.payload;
}, },
setOpenModel: (state, action: PayloadAction<string | null>) => {
state.openModel = action.payload;
},
/** /**
* A cancel was scheduled * A cancel was scheduled
*/ */
@ -433,7 +428,6 @@ export const {
clearToastQueue, clearToastQueue,
setSearchFolder, setSearchFolder,
setFoundModels, setFoundModels,
setOpenModel,
cancelScheduled, cancelScheduled,
scheduledCancelAborted, scheduledCancelAborted,
cancelTypeChanged, cancelTypeChanged,

View File

@ -1,14 +1,6 @@
import { import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@chakra-ui/react';
Tab,
TabList,
TabPanel,
TabPanels,
Tabs,
useColorMode,
} from '@chakra-ui/react';
import i18n from 'i18n'; import i18n from 'i18n';
import { ReactNode, memo } from 'react'; import { ReactNode, memo } from 'react';
import { mode } from 'theme/util/mode';
import AddModelsPanel from './subpanels/AddModelsPanel'; import AddModelsPanel from './subpanels/AddModelsPanel';
import MergeModelsPanel from './subpanels/MergeModelsPanel'; import MergeModelsPanel from './subpanels/MergeModelsPanel';
import ModelManagerPanel from './subpanels/ModelManagerPanel'; import ModelManagerPanel from './subpanels/ModelManagerPanel';
@ -21,7 +13,7 @@ type ModelManagerTabInfo = {
content: ReactNode; content: ReactNode;
}; };
const modelManagerTabs: ModelManagerTabInfo[] = [ const tabs: ModelManagerTabInfo[] = [
{ {
id: 'modelManager', id: 'modelManager',
label: i18n.t('modelManager.modelManager'), label: i18n.t('modelManager.modelManager'),
@ -40,50 +32,25 @@ const modelManagerTabs: ModelManagerTabInfo[] = [
]; ];
const ModelManagerTab = () => { const ModelManagerTab = () => {
const { colorMode } = useColorMode();
const renderTabsList = () => {
const modelManagerTabListsToRender: ReactNode[] = [];
modelManagerTabs.forEach((modelManagerTab) => {
modelManagerTabListsToRender.push(
<Tab key={modelManagerTab.id}>{modelManagerTab.label}</Tab>
);
});
return (
<TabList
sx={{
w: '100%',
color: mode('base.900', 'base.400')(colorMode),
flexDirection: 'row',
borderBottomWidth: 2,
borderColor: mode('accent.300', 'accent.600')(colorMode),
}}
>
{modelManagerTabListsToRender}
</TabList>
);
};
const renderTabPanels = () => {
const modelManagerTabPanelsToRender: ReactNode[] = [];
modelManagerTabs.forEach((modelManagerTab) => {
modelManagerTabPanelsToRender.push(
<TabPanel key={modelManagerTab.id}>{modelManagerTab.content}</TabPanel>
);
});
return <TabPanels sx={{ p: 2 }}>{modelManagerTabPanelsToRender}</TabPanels>;
};
return ( return (
<Tabs <Tabs
isLazy isLazy
variant="invokeAI" variant="line"
sx={{ w: 'full', h: 'full', p: 2, gap: 4, flexDirection: 'column' }} layerStyle="first"
sx={{ w: 'full', h: 'full', p: 4, gap: 4, borderRadius: 'base' }}
> >
{renderTabsList()} <TabList>
{renderTabPanels()} {tabs.map((tab) => (
<Tab sx={{ borderTopRadius: 'base' }} key={tab.id}>
{tab.label}
</Tab>
))}
</TabList>
<TabPanels sx={{ p: 4 }}>
{tabs.map((tab) => (
<TabPanel key={tab.id}>{tab.content}</TabPanel>
))}
</TabPanels>
</Tabs> </Tabs>
); );
}; };

View File

@ -1,48 +1,59 @@
import { Flex } from '@chakra-ui/react'; import { Flex, Text } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import { useState } from 'react';
import {
MainModelConfigEntity,
useGetMainModelsQuery,
} from 'services/api/endpoints/models';
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit'; import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit'; import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import ModelList from './ModelManagerPanel/ModelList'; import ModelList from './ModelManagerPanel/ModelList';
export default function ModelManagerPanel() { export default function ModelManagerPanel() {
const { data: mainModels } = useGetMainModelsQuery(); const [selectedModelId, setSelectedModelId] = useState<string>();
const { model } = useGetMainModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
model: selectedModelId ? data?.entities[selectedModelId] : undefined,
}),
});
const openModel = useAppSelector(
(state: RootState) => state.system.openModel
);
const renderModelEditTabs = () => {
if (!openModel || !mainModels) return;
const openedModelData = mainModels['entities'][openModel];
if (openedModelData && openedModelData.model_format === 'diffusers') {
return (
<DiffusersModelEdit
modelToEdit={openModel}
retrievedModel={openedModelData}
key={openModel}
/>
);
}
if (openedModelData && openedModelData.model_format === 'checkpoint') {
return (
<CheckpointModelEdit
modelToEdit={openModel}
retrievedModel={openedModelData}
key={openModel}
/>
);
}
};
return ( return (
<Flex width="100%" columnGap={8}> <Flex width="100%" columnGap={8}>
<ModelList /> <ModelList
{renderModelEditTabs()} selectedModelId={selectedModelId}
setSelectedModelId={setSelectedModelId}
/>
<ModelEdit model={model} />
</Flex> </Flex>
); );
} }
type ModelEditProps = {
model: MainModelConfigEntity | undefined;
};
const ModelEdit = (props: ModelEditProps) => {
const { model } = props;
if (model?.model_format === 'checkpoint') {
return <CheckpointModelEdit model={model} />;
}
if (model?.model_format === 'diffusers') {
return <DiffusersModelEdit model={model} />;
}
return (
<Flex
sx={{
width: '100%',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.900',
}}
>
<Text fontWeight={500}>Pick A Model To Edit</Text>
</Flex>
);
};

View File

@ -1,21 +1,20 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Divider, Flex, Text } from '@chakra-ui/react'; import { Divider, Flex, Text } from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { useForm } from '@mantine/form'; import { useForm } from '@mantine/form';
import { useTranslation } from 'react-i18next';
import IAIButton from 'common/components/IAIButton';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { makeToast } from 'app/components/Toaster'; import { makeToast } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput'; import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { selectIsBusy } from 'features/system/store/systemSelectors'; import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { useUpdateMainModelsMutation } from 'services/api/endpoints/models'; import { useCallback } from 'react';
import { components } from 'services/api/schema'; import { useTranslation } from 'react-i18next';
import {
CheckpointModelConfigEntity,
useUpdateMainModelsMutation,
} from 'services/api/endpoints/models';
import { CheckpointModelConfig } from 'services/api/types';
import ModelConvert from './ModelConvert'; import ModelConvert from './ModelConvert';
const baseModelSelectData = [ const baseModelSelectData = [
@ -29,36 +28,31 @@ const variantSelectData = [
{ value: 'depth', label: 'Depth' }, { value: 'depth', label: 'Depth' },
]; ];
export type CheckpointModelConfig =
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
| components['schemas']['StableDiffusion2ModelCheckpointConfig'];
type CheckpointModelEditProps = { type CheckpointModelEditProps = {
modelToEdit: string; model: CheckpointModelConfigEntity;
retrievedModel: CheckpointModelConfig;
}; };
export default function CheckpointModelEdit(props: CheckpointModelEditProps) { export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
const isBusy = useAppSelector(selectIsBusy); const isBusy = useAppSelector(selectIsBusy);
const { modelToEdit, retrievedModel } = props; const { model } = props;
const [updateMainModel, { error, isLoading }] = useUpdateMainModelsMutation(); const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const checkpointEditForm = useForm<CheckpointModelConfig>({ const checkpointEditForm = useForm<CheckpointModelConfig>({
initialValues: { initialValues: {
model_name: retrievedModel.model_name ? retrievedModel.model_name : '', model_name: model.model_name ? model.model_name : '',
base_model: retrievedModel.base_model, base_model: model.base_model,
model_type: 'main', model_type: 'main',
path: retrievedModel.path ? retrievedModel.path : '', path: model.path ? model.path : '',
description: retrievedModel.description ? retrievedModel.description : '', description: model.description ? model.description : '',
model_format: 'checkpoint', model_format: 'checkpoint',
vae: retrievedModel.vae ? retrievedModel.vae : '', vae: model.vae ? model.vae : '',
config: retrievedModel.config ? retrievedModel.config : '', config: model.config ? model.config : '',
variant: retrievedModel.variant, variant: model.variant,
}, },
validate: { validate: {
path: (value) => path: (value) =>
@ -66,10 +60,11 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
}, },
}); });
const editModelFormSubmitHandler = (values: CheckpointModelConfig) => { const editModelFormSubmitHandler = useCallback(
(values: CheckpointModelConfig) => {
const responseBody = { const responseBody = {
base_model: retrievedModel.base_model, base_model: model.base_model,
model_name: retrievedModel.model_name, model_name: model.model_name,
body: values, body: values,
}; };
updateMainModel(responseBody) updateMainModel(responseBody)
@ -96,20 +91,29 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
) )
); );
}); });
}; },
[
checkpointEditForm,
dispatch,
model.base_model,
model.model_name,
t,
updateMainModel,
]
);
return modelToEdit ? ( return (
<Flex flexDirection="column" rowGap={4} width="100%"> <Flex flexDirection="column" rowGap={4} width="100%">
<Flex justifyContent="space-between" alignItems="center"> <Flex justifyContent="space-between" alignItems="center">
<Flex flexDirection="column"> <Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold"> <Text fontSize="lg" fontWeight="bold">
{retrievedModel.model_name} {model.model_name}
</Text> </Text>
<Text fontSize="sm" color="base.400"> <Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[retrievedModel.base_model]} Model {MODEL_TYPE_MAP[model.base_model]} Model
</Text> </Text>
</Flex> </Flex>
<ModelConvert model={retrievedModel} /> <ModelConvert model={model} />
</Flex> </Flex>
<Divider /> <Divider />
@ -161,17 +165,5 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
</form> </form>
</Flex> </Flex>
</Flex> </Flex>
) : (
<Flex
sx={{
width: '100%',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.900',
}}
>
<Text fontWeight={500}>Pick A Model To Edit</Text>
</Flex>
); );
} }

View File

@ -1,29 +1,23 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Divider, Flex, Text } from '@chakra-ui/react'; import { Divider, Flex, Text } from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { useTranslation } from 'react-i18next';
import { useForm } from '@mantine/form'; import { useForm } from '@mantine/form';
import { makeToast } from 'app/components/Toaster'; import { makeToast } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput'; import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { selectIsBusy } from 'features/system/store/systemSelectors'; import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { useUpdateMainModelsMutation } from 'services/api/endpoints/models'; import { useCallback } from 'react';
import { components } from 'services/api/schema'; import { useTranslation } from 'react-i18next';
import {
export type DiffusersModelConfig = DiffusersModelConfigEntity,
| components['schemas']['StableDiffusion1ModelDiffusersConfig'] useUpdateMainModelsMutation,
| components['schemas']['StableDiffusion2ModelDiffusersConfig']; } from 'services/api/endpoints/models';
import { DiffusersModelConfig } from 'services/api/types';
type DiffusersModelEditProps = { type DiffusersModelEditProps = {
modelToEdit: string; model: DiffusersModelConfigEntity;
retrievedModel: DiffusersModelConfig;
}; };
const baseModelSelectData = [ const baseModelSelectData = [
@ -40,23 +34,23 @@ const variantSelectData = [
export default function DiffusersModelEdit(props: DiffusersModelEditProps) { export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
const isBusy = useAppSelector(selectIsBusy); const isBusy = useAppSelector(selectIsBusy);
const { retrievedModel, modelToEdit } = props; const { model } = props;
const [updateMainModel, { isLoading, error }] = useUpdateMainModelsMutation(); const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const diffusersEditForm = useForm<DiffusersModelConfig>({ const diffusersEditForm = useForm<DiffusersModelConfig>({
initialValues: { initialValues: {
model_name: retrievedModel.model_name ? retrievedModel.model_name : '', model_name: model.model_name ? model.model_name : '',
base_model: retrievedModel.base_model, base_model: model.base_model,
model_type: 'main', model_type: 'main',
path: retrievedModel.path ? retrievedModel.path : '', path: model.path ? model.path : '',
description: retrievedModel.description ? retrievedModel.description : '', description: model.description ? model.description : '',
model_format: 'diffusers', model_format: 'diffusers',
vae: retrievedModel.vae ? retrievedModel.vae : '', vae: model.vae ? model.vae : '',
variant: retrievedModel.variant, variant: model.variant,
}, },
validate: { validate: {
path: (value) => path: (value) =>
@ -64,10 +58,11 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
}, },
}); });
const editModelFormSubmitHandler = (values: DiffusersModelConfig) => { const editModelFormSubmitHandler = useCallback(
(values: DiffusersModelConfig) => {
const responseBody = { const responseBody = {
base_model: retrievedModel.base_model, base_model: model.base_model,
model_name: retrievedModel.model_name, model_name: model.model_name,
body: values, body: values,
}; };
updateMainModel(responseBody) updateMainModel(responseBody)
@ -94,16 +89,25 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
) )
); );
}); });
}; },
[
diffusersEditForm,
dispatch,
model.base_model,
model.model_name,
t,
updateMainModel,
]
);
return modelToEdit ? ( return (
<Flex flexDirection="column" rowGap={4} width="100%"> <Flex flexDirection="column" rowGap={4} width="100%">
<Flex flexDirection="column"> <Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold"> <Text fontSize="lg" fontWeight="bold">
{retrievedModel.model_name} {model.model_name}
</Text> </Text>
<Text fontSize="sm" color="base.400"> <Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[retrievedModel.base_model]} Model {MODEL_TYPE_MAP[model.base_model]} Model
</Text> </Text>
</Flex> </Flex>
<Divider /> <Divider />
@ -146,17 +150,5 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
</Flex> </Flex>
</form> </form>
</Flex> </Flex>
) : (
<Flex
sx={{
width: '100%',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.900',
}}
>
<Text fontWeight={'500'}>Pick A Model To Edit</Text>
</Flex>
); );
} }

View File

@ -1,187 +1,46 @@
import { Box, Flex, Spinner, Text, useColorMode } from '@chakra-ui/react'; import { ButtonGroup, Flex, Text } from '@chakra-ui/react';
import { EntityState } from '@reduxjs/toolkit';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import IAIInput from 'common/components/IAIInput';
import { forEach } from 'lodash-es';
import type { ChangeEvent } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import {
MainModelConfigEntity,
useGetMainModelsQuery,
} from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem'; import ModelListItem from './ModelListItem';
import { useTranslation } from 'react-i18next'; type ModelListProps = {
selectedModelId: string | undefined;
import type { ChangeEvent, ReactNode } from 'react'; setSelectedModelId: (name: string | undefined) => void;
import React, { useMemo, useState, useTransition } from 'react';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { mode } from 'theme/util/mode';
function ModelFilterButton({
label,
isActive,
onClick,
}: {
label: string;
isActive: boolean;
onClick: () => void;
}) {
return (
<IAIButton onClick={onClick} isChecked={isActive} size="sm">
{label}
</IAIButton>
);
}
const ModelList = () => {
const { data: mainModels } = useGetMainModelsQuery();
const { colorMode } = useColorMode();
const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
React.useEffect(() => {
const timer = setTimeout(() => {
setRenderModelList(true);
}, 200);
return () => clearTimeout(timer);
}, []);
const [searchText, setSearchText] = useState<string>('');
const [isSelectedFilter, setIsSelectedFilter] = useState<
'all' | 'checkpoint' | 'diffusers'
>('all');
const [_, startTransition] = useTransition();
const { t } = useTranslation();
const handleSearchFilter = (e: ChangeEvent<HTMLInputElement>) => {
startTransition(() => {
setSearchText(e.target.value);
});
}; };
const renderModelListItems = useMemo(() => { type ModelFormat = 'all' | 'checkpoint' | 'diffusers';
const ckptModelListItemsToRender: ReactNode[] = [];
const diffusersModelListItemsToRender: ReactNode[] = [];
const filteredModelListItemsToRender: ReactNode[] = [];
const localFilteredModelListItemsToRender: ReactNode[] = [];
if (!mainModels) return; const ModelList = (props: ModelListProps) => {
const { selectedModelId, setSelectedModelId } = props;
const { t } = useTranslation();
const [nameFilter, setNameFilter] = useState<string>('');
const [modelFormatFilter, setModelFormatFilter] =
useState<ModelFormat>('all');
const modelList = mainModels.entities; const { filteredDiffusersModels } = useGetMainModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
Object.keys(modelList).forEach((model, i) => { filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter),
const modelInfo = modelList[model]; }),
// If no model info found for a model, ignore it
if (!modelInfo) return;
if (
modelInfo.model_name.toLowerCase().includes(searchText.toLowerCase())
) {
filteredModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelInfo.model_name}
description={modelInfo.description}
/>
);
if (modelInfo?.model_format === isSelectedFilter) {
localFilteredModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelInfo.model_name}
description={modelInfo.description}
/>
);
}
}
if (modelInfo?.model_format !== 'diffusers') {
ckptModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelInfo.model_name}
description={modelInfo.description}
/>
);
} else {
diffusersModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelInfo.model_name}
description={modelInfo.description}
/>
);
}
}); });
return searchText !== '' ? ( const { filteredCheckpointModels } = useGetMainModelsQuery(undefined, {
isSelectedFilter === 'all' ? ( selectFromResult: ({ data }) => ({
<Box marginTop={4}>{filteredModelListItemsToRender}</Box> filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter),
) : ( }),
<Box marginTop={4}>{localFilteredModelListItemsToRender}</Box> });
)
) : (
<Flex flexDirection="column" rowGap={6}>
{isSelectedFilter === 'all' && (
<>
{diffusersModelListItemsToRender.length > 0 && (
<Box>
<Text
sx={{
fontWeight: '500',
py: 2,
px: 4,
mb: 4,
borderRadius: 'base',
width: 'max-content',
fontSize: 'sm',
bg: mode('base.100', 'base.800')(colorMode),
}}
>
{t('modelManager.diffusersModels')}
</Text>
{diffusersModelListItemsToRender}
</Box>
)}
{ckptModelListItemsToRender.length > 0 && ( const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
<Box> setNameFilter(e.target.value);
<Text }, []);
sx={{
fontWeight: '500',
py: 2,
px: 4,
my: 4,
mx: 0,
borderRadius: 'base',
width: 'max-content',
fontSize: 'sm',
bg: mode('base.150', 'base.750')(colorMode),
}}
>
{t('modelManager.checkpointModels')}
</Text>
{ckptModelListItemsToRender}
</Box>
)}
</>
)}
{isSelectedFilter === 'diffusers' && (
<Flex flexDirection="column" marginTop={4}>
{diffusersModelListItemsToRender}
</Flex>
)}
{isSelectedFilter === 'checkpoint' && (
<Flex flexDirection="column" marginTop={4}>
{ckptModelListItemsToRender}
</Flex>
)}
</Flex>
);
}, [mainModels, searchText, t, isSelectedFilter, colorMode]);
return ( return (
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%"> <Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
@ -189,7 +48,6 @@ const ModelList = () => {
onChange={handleSearchFilter} onChange={handleSearchFilter}
label={t('modelManager.search')} label={t('modelManager.search')}
/> />
<Flex <Flex
flexDirection="column" flexDirection="column"
gap={4} gap={4}
@ -197,34 +55,58 @@ const ModelList = () => {
overflow="scroll" overflow="scroll"
paddingInlineEnd={4} paddingInlineEnd={4}
> >
<Flex columnGap={2}> <ButtonGroup isAttached>
<ModelFilterButton <IAIButton
label={t('modelManager.allModels')} onClick={() => setModelFormatFilter('all')}
onClick={() => setIsSelectedFilter('all')} isChecked={modelFormatFilter === 'all'}
isActive={isSelectedFilter === 'all'} size="sm"
/>
<ModelFilterButton
label={t('modelManager.diffusersModels')}
onClick={() => setIsSelectedFilter('diffusers')}
isActive={isSelectedFilter === 'diffusers'}
/>
<ModelFilterButton
label={t('modelManager.checkpointModels')}
onClick={() => setIsSelectedFilter('checkpoint')}
isActive={isSelectedFilter === 'checkpoint'}
/>
</Flex>
{renderModelList ? (
renderModelListItems
) : (
<Flex
width="100%"
minHeight={96}
justifyContent="center"
alignItems="center"
> >
<Spinner /> {t('modelManager.allModels')}
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('diffusers')}
isChecked={modelFormatFilter === 'diffusers'}
>
{t('modelManager.diffusersModels')}
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('checkpoint')}
isChecked={modelFormatFilter === 'checkpoint'}
>
{t('modelManager.checkpointModels')}
</IAIButton>
</ButtonGroup>
{['all', 'diffusers'].includes(modelFormatFilter) && (
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" size="sm">
Diffusers
</Text>
{filteredDiffusersModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
)}
{['all', 'checkpoint'].includes(modelFormatFilter) && (
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" size="sm">
Checkpoint
</Text>
{filteredCheckpointModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex> </Flex>
)} )}
</Flex> </Flex>
@ -233,3 +115,27 @@ const ModelList = () => {
}; };
export default ModelList; export default ModelList;
const modelsFilter = (
data: EntityState<MainModelConfigEntity> | undefined,
model_format: ModelFormat,
nameFilter: string
) => {
const filteredModels: MainModelConfigEntity[] = [];
forEach(data?.entities, (model) => {
if (!model) {
return;
}
const matchesFilter = model.model_name
.toLowerCase()
.includes(nameFilter.toLowerCase());
const matchesFormat = model.model_format === model_format;
if (matchesFilter && matchesFormat) {
filteredModels.push(model);
}
});
return filteredModels;
};

View File

@ -1,95 +1,78 @@
import { DeleteIcon, EditIcon } from '@chakra-ui/icons'; import { DeleteIcon } from '@chakra-ui/icons';
import { import { Box, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react';
Box, import { useAppSelector } from 'app/store/storeHooks';
Flex,
Spacer,
Text,
Tooltip,
useColorMode,
} from '@chakra-ui/react';
// import { deleteModel, requestModelChange } from 'app/socketio/actions';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog'; import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import { selectIsBusy } from 'features/system/store/systemSelectors'; import { selectIsBusy } from 'features/system/store/systemSelectors';
import { setOpenModel } from 'features/system/store/systemSlice'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useDeleteMainModelsMutation } from 'services/api/endpoints/models'; import { FaEdit } from 'react-icons/fa';
import { BaseModelType } from 'services/api/types'; import {
import { mode } from 'theme/util/mode'; MainModelConfigEntity,
useDeleteMainModelsMutation,
} from 'services/api/endpoints/models';
type ModelListItemProps = { type ModelListItemProps = {
modelKey: string; model: MainModelConfigEntity;
name: string; isSelected: boolean;
description: string | undefined; setSelectedModelId: (v: string | undefined) => void;
}; };
export default function ModelListItem(props: ModelListItemProps) { export default function ModelListItem(props: ModelListItemProps) {
const isBusy = useAppSelector(selectIsBusy); const isBusy = useAppSelector(selectIsBusy);
const { t } = useTranslation();
const { colorMode } = useColorMode();
const openModel = useAppSelector(
(state: RootState) => state.system.openModel
);
const [deleteMainModel] = useDeleteMainModelsMutation(); const [deleteMainModel] = useDeleteMainModelsMutation();
const { t } = useTranslation(); const { model, isSelected, setSelectedModelId } = props;
const dispatch = useAppDispatch(); const handleSelectModel = useCallback(() => {
setSelectedModelId(model.id);
}, [model.id, setSelectedModelId]);
const { modelKey, name, description } = props; const handleModelDelete = useCallback(() => {
deleteMainModel(model);
const openModelHandler = () => { setSelectedModelId(undefined);
dispatch(setOpenModel(modelKey)); }, [deleteMainModel, model, setSelectedModelId]);
};
const handleModelDelete = () => {
const [base_model, _, model_name] = modelKey.split('/');
deleteMainModel({
base_model: base_model as BaseModelType,
model_name: model_name,
});
dispatch(setOpenModel(null));
};
return ( return (
<Flex sx={{ gap: 2, alignItems: 'center', w: 'full' }}>
<Flex <Flex
alignItems="center" as={IAIButton}
p={2} isChecked={isSelected}
borderRadius="base" sx={{
sx={ p: 2,
modelKey === openModel borderRadius: 'base',
? { w: 'full',
bg: mode('accent.200', 'accent.600')(colorMode), alignItems: 'center',
bg: isSelected ? 'accent.200' : 'base.100',
_hover: { _hover: {
bg: mode('accent.200', 'accent.600')(colorMode), bg: isSelected ? 'accent.250' : 'base.150',
}, },
} _dark: {
: { bg: isSelected ? 'accent.600' : 'base.850',
_hover: { _hover: {
bg: mode('base.100', 'base.800')(colorMode), bg: isSelected ? 'accent.550' : 'base.800',
}, },
} },
} }}
onClick={handleSelectModel}
> >
<Box onClick={openModelHandler} cursor="pointer"> <Box cursor="pointer">
<Tooltip label={description} hasArrow placement="bottom"> <Tooltip label={model.description} hasArrow placement="bottom">
<Text fontWeight="600">{name}</Text> <Text fontWeight="600">{model.model_name}</Text>
</Tooltip> </Tooltip>
</Box> </Box>
<Spacer onClick={openModelHandler} cursor="pointer" /> <Spacer onClick={handleSelectModel} cursor="pointer" />
<Flex gap={2} alignItems="center">
<IAIIconButton <IAIIconButton
icon={<EditIcon />} icon={<FaEdit />}
size="sm" size="sm"
onClick={openModelHandler} onClick={handleSelectModel}
aria-label={t('accessibility.modifyConfig')} aria-label={t('accessibility.modifyConfig')}
isDisabled={isBusy} isDisabled={isBusy}
variant="link"
/> />
</Flex>
<IAIAlertDialog <IAIAlertDialog
title={t('modelManager.deleteModel')} title={t('modelManager.deleteModel')}
acceptCallback={handleModelDelete} acceptCallback={handleModelDelete}
@ -97,7 +80,6 @@ export default function ModelListItem(props: ModelListItemProps) {
triggerComponent={ triggerComponent={
<IAIIconButton <IAIIconButton
icon={<DeleteIcon />} icon={<DeleteIcon />}
size="sm"
aria-label={t('modelManager.deleteConfig')} aria-label={t('modelManager.deleteConfig')}
isDisabled={isBusy} isDisabled={isBusy}
colorScheme="error" colorScheme="error"
@ -110,6 +92,5 @@ export default function ModelListItem(props: ModelListItemProps) {
</Flex> </Flex>
</IAIAlertDialog> </IAIAlertDialog>
</Flex> </Flex>
</Flex>
); );
} }

View File

@ -3,7 +3,9 @@ import { cloneDeep } from 'lodash-es';
import { import {
AnyModelConfig, AnyModelConfig,
BaseModelType, BaseModelType,
CheckpointModelConfig,
ControlNetModelConfig, ControlNetModelConfig,
DiffusersModelConfig,
LoRAModelConfig, LoRAModelConfig,
MainModelConfig, MainModelConfig,
MergeModelConfig, MergeModelConfig,
@ -14,7 +16,13 @@ import {
import { ApiFullTagDescription, LIST_TAG, api } from '..'; import { ApiFullTagDescription, LIST_TAG, api } from '..';
import { paths } from '../schema'; import { paths } from '../schema';
export type MainModelConfigEntity = MainModelConfig & { id: string }; export type DiffusersModelConfigEntity = DiffusersModelConfig & { id: string };
export type CheckpointModelConfigEntity = CheckpointModelConfig & {
id: string;
};
export type MainModelConfigEntity =
| DiffusersModelConfigEntity
| CheckpointModelConfigEntity;
export type LoRAModelConfigEntity = LoRAModelConfig & { id: string }; export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };

View File

@ -42,11 +42,13 @@ export type ControlNetModelConfig =
components['schemas']['ControlNetModelConfig']; components['schemas']['ControlNetModelConfig'];
export type TextualInversionModelConfig = export type TextualInversionModelConfig =
components['schemas']['TextualInversionModelConfig']; components['schemas']['TextualInversionModelConfig'];
export type MainModelConfig = export type DiffusersModelConfig =
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
| components['schemas']['StableDiffusion1ModelDiffusersConfig'] | components['schemas']['StableDiffusion1ModelDiffusersConfig']
| components['schemas']['StableDiffusion2ModelCheckpointConfig']
| components['schemas']['StableDiffusion2ModelDiffusersConfig']; | components['schemas']['StableDiffusion2ModelDiffusersConfig'];
export type CheckpointModelConfig =
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
| components['schemas']['StableDiffusion2ModelCheckpointConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type AnyModelConfig = export type AnyModelConfig =
| LoRAModelConfig | LoRAModelConfig
| VaeModelConfig | VaeModelConfig

View File

@ -2,7 +2,7 @@ import { defineStyle, defineStyleConfig } from '@chakra-ui/react';
import { mode } from '@chakra-ui/theme-tools'; import { mode } from '@chakra-ui/theme-tools';
const subtext = defineStyle((props) => ({ const subtext = defineStyle((props) => ({
color: mode('colors.base.500', 'colors.base.400')(props), color: mode('base.500', 'base.400')(props),
})); }));
export const textTheme = defineStyleConfig({ export const textTheme = defineStyleConfig({