feat(ui): refactor model manager ui

- simplify UI logic in `ModelManagerPanel` components
- fix up the types a bit to make it easier to select models
- remove `openModel` state, just make it a useState since it is very local to model manager
This commit is contained in:
psychedelicious 2023-07-14 19:22:37 +10:00
parent f2af82bf73
commit 1e5ae9d986
11 changed files with 384 additions and 532 deletions

View File

@ -13,7 +13,6 @@ export const systemPersistDenylist: (keyof SystemState)[] = [
'isProcessing',
'totalIterations',
'totalSteps',
'openModel',
'isCancelScheduled',
'progressImage',
'wereModelsReceived',

View File

@ -46,7 +46,6 @@ export interface SystemState {
toastQueue: UseToastOptions[];
searchFolder: string | null;
foundModels: InvokeAI.FoundModel[] | null;
openModel: string | null;
/**
* The current progress image
*/
@ -109,7 +108,6 @@ export const initialSystemState: SystemState = {
toastQueue: [],
searchFolder: null,
foundModels: null,
openModel: null,
progressImage: null,
shouldAntialiasProgressImage: false,
sessionId: null,
@ -164,9 +162,6 @@ export const systemSlice = createSlice({
) => {
state.foundModels = action.payload;
},
setOpenModel: (state, action: PayloadAction<string | null>) => {
state.openModel = action.payload;
},
/**
* A cancel was scheduled
*/
@ -433,7 +428,6 @@ export const {
clearToastQueue,
setSearchFolder,
setFoundModels,
setOpenModel,
cancelScheduled,
scheduledCancelAborted,
cancelTypeChanged,

View File

@ -1,14 +1,6 @@
import {
Tab,
TabList,
TabPanel,
TabPanels,
Tabs,
useColorMode,
} from '@chakra-ui/react';
import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@chakra-ui/react';
import i18n from 'i18n';
import { ReactNode, memo } from 'react';
import { mode } from 'theme/util/mode';
import AddModelsPanel from './subpanels/AddModelsPanel';
import MergeModelsPanel from './subpanels/MergeModelsPanel';
import ModelManagerPanel from './subpanels/ModelManagerPanel';
@ -21,7 +13,7 @@ type ModelManagerTabInfo = {
content: ReactNode;
};
const modelManagerTabs: ModelManagerTabInfo[] = [
const tabs: ModelManagerTabInfo[] = [
{
id: 'modelManager',
label: i18n.t('modelManager.modelManager'),
@ -40,50 +32,25 @@ const modelManagerTabs: ModelManagerTabInfo[] = [
];
const ModelManagerTab = () => {
const { colorMode } = useColorMode();
const renderTabsList = () => {
const modelManagerTabListsToRender: ReactNode[] = [];
modelManagerTabs.forEach((modelManagerTab) => {
modelManagerTabListsToRender.push(
<Tab key={modelManagerTab.id}>{modelManagerTab.label}</Tab>
);
});
return (
<TabList
sx={{
w: '100%',
color: mode('base.900', 'base.400')(colorMode),
flexDirection: 'row',
borderBottomWidth: 2,
borderColor: mode('accent.300', 'accent.600')(colorMode),
}}
>
{modelManagerTabListsToRender}
</TabList>
);
};
const renderTabPanels = () => {
const modelManagerTabPanelsToRender: ReactNode[] = [];
modelManagerTabs.forEach((modelManagerTab) => {
modelManagerTabPanelsToRender.push(
<TabPanel key={modelManagerTab.id}>{modelManagerTab.content}</TabPanel>
);
});
return <TabPanels sx={{ p: 2 }}>{modelManagerTabPanelsToRender}</TabPanels>;
};
return (
<Tabs
isLazy
variant="invokeAI"
sx={{ w: 'full', h: 'full', p: 2, gap: 4, flexDirection: 'column' }}
variant="line"
layerStyle="first"
sx={{ w: 'full', h: 'full', p: 4, gap: 4, borderRadius: 'base' }}
>
{renderTabsList()}
{renderTabPanels()}
<TabList>
{tabs.map((tab) => (
<Tab sx={{ borderTopRadius: 'base' }} key={tab.id}>
{tab.label}
</Tab>
))}
</TabList>
<TabPanels sx={{ p: 4 }}>
{tabs.map((tab) => (
<TabPanel key={tab.id}>{tab.content}</TabPanel>
))}
</TabPanels>
</Tabs>
);
};

View File

@ -1,48 +1,59 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { Flex, Text } from '@chakra-ui/react';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { useState } from 'react';
import {
MainModelConfigEntity,
useGetMainModelsQuery,
} from 'services/api/endpoints/models';
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import ModelList from './ModelManagerPanel/ModelList';
export default function ModelManagerPanel() {
const { data: mainModels } = useGetMainModelsQuery();
const [selectedModelId, setSelectedModelId] = useState<string>();
const { model } = useGetMainModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
model: selectedModelId ? data?.entities[selectedModelId] : undefined,
}),
});
const openModel = useAppSelector(
(state: RootState) => state.system.openModel
);
const renderModelEditTabs = () => {
if (!openModel || !mainModels) return;
const openedModelData = mainModels['entities'][openModel];
if (openedModelData && openedModelData.model_format === 'diffusers') {
return (
<DiffusersModelEdit
modelToEdit={openModel}
retrievedModel={openedModelData}
key={openModel}
/>
);
}
if (openedModelData && openedModelData.model_format === 'checkpoint') {
return (
<CheckpointModelEdit
modelToEdit={openModel}
retrievedModel={openedModelData}
key={openModel}
/>
);
}
};
return (
<Flex width="100%" columnGap={8}>
<ModelList />
{renderModelEditTabs()}
<ModelList
selectedModelId={selectedModelId}
setSelectedModelId={setSelectedModelId}
/>
<ModelEdit model={model} />
</Flex>
);
}
type ModelEditProps = {
model: MainModelConfigEntity | undefined;
};
const ModelEdit = (props: ModelEditProps) => {
const { model } = props;
if (model?.model_format === 'checkpoint') {
return <CheckpointModelEdit model={model} />;
}
if (model?.model_format === 'diffusers') {
return <DiffusersModelEdit model={model} />;
}
return (
<Flex
sx={{
width: '100%',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.900',
}}
>
<Text fontWeight={500}>Pick A Model To Edit</Text>
</Flex>
);
};

View File

@ -1,21 +1,20 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Divider, Flex, Text } from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { useForm } from '@mantine/form';
import { useTranslation } from 'react-i18next';
import IAIButton from 'common/components/IAIButton';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
import { useUpdateMainModelsMutation } from 'services/api/endpoints/models';
import { components } from 'services/api/schema';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
CheckpointModelConfigEntity,
useUpdateMainModelsMutation,
} from 'services/api/endpoints/models';
import { CheckpointModelConfig } from 'services/api/types';
import ModelConvert from './ModelConvert';
const baseModelSelectData = [
@ -29,36 +28,31 @@ const variantSelectData = [
{ value: 'depth', label: 'Depth' },
];
export type CheckpointModelConfig =
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
| components['schemas']['StableDiffusion2ModelCheckpointConfig'];
type CheckpointModelEditProps = {
modelToEdit: string;
retrievedModel: CheckpointModelConfig;
model: CheckpointModelConfigEntity;
};
export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
const isBusy = useAppSelector(selectIsBusy);
const { modelToEdit, retrievedModel } = props;
const { model } = props;
const [updateMainModel, { error, isLoading }] = useUpdateMainModelsMutation();
const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const checkpointEditForm = useForm<CheckpointModelConfig>({
initialValues: {
model_name: retrievedModel.model_name ? retrievedModel.model_name : '',
base_model: retrievedModel.base_model,
model_name: model.model_name ? model.model_name : '',
base_model: model.base_model,
model_type: 'main',
path: retrievedModel.path ? retrievedModel.path : '',
description: retrievedModel.description ? retrievedModel.description : '',
path: model.path ? model.path : '',
description: model.description ? model.description : '',
model_format: 'checkpoint',
vae: retrievedModel.vae ? retrievedModel.vae : '',
config: retrievedModel.config ? retrievedModel.config : '',
variant: retrievedModel.variant,
vae: model.vae ? model.vae : '',
config: model.config ? model.config : '',
variant: model.variant,
},
validate: {
path: (value) =>
@ -66,50 +60,60 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
},
});
const editModelFormSubmitHandler = (values: CheckpointModelConfig) => {
const responseBody = {
base_model: retrievedModel.base_model,
model_name: retrievedModel.model_name,
body: values,
};
updateMainModel(responseBody)
.unwrap()
.then((payload) => {
checkpointEditForm.setValues(payload as CheckpointModelConfig);
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((error) => {
checkpointEditForm.reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
};
const editModelFormSubmitHandler = useCallback(
(values: CheckpointModelConfig) => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
body: values,
};
updateMainModel(responseBody)
.unwrap()
.then((payload) => {
checkpointEditForm.setValues(payload as CheckpointModelConfig);
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((error) => {
checkpointEditForm.reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
},
[
checkpointEditForm,
dispatch,
model.base_model,
model.model_name,
t,
updateMainModel,
]
);
return modelToEdit ? (
return (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex justifyContent="space-between" alignItems="center">
<Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold">
{retrievedModel.model_name}
{model.model_name}
</Text>
<Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[retrievedModel.base_model]} Model
{MODEL_TYPE_MAP[model.base_model]} Model
</Text>
</Flex>
<ModelConvert model={retrievedModel} />
<ModelConvert model={model} />
</Flex>
<Divider />
@ -161,17 +165,5 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
</form>
</Flex>
</Flex>
) : (
<Flex
sx={{
width: '100%',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.900',
}}
>
<Text fontWeight={500}>Pick A Model To Edit</Text>
</Flex>
);
}

View File

@ -1,29 +1,23 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Divider, Flex, Text } from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { useTranslation } from 'react-i18next';
import { useForm } from '@mantine/form';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
import { useUpdateMainModelsMutation } from 'services/api/endpoints/models';
import { components } from 'services/api/schema';
export type DiffusersModelConfig =
| components['schemas']['StableDiffusion1ModelDiffusersConfig']
| components['schemas']['StableDiffusion2ModelDiffusersConfig'];
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
DiffusersModelConfigEntity,
useUpdateMainModelsMutation,
} from 'services/api/endpoints/models';
import { DiffusersModelConfig } from 'services/api/types';
type DiffusersModelEditProps = {
modelToEdit: string;
retrievedModel: DiffusersModelConfig;
model: DiffusersModelConfigEntity;
};
const baseModelSelectData = [
@ -40,23 +34,23 @@ const variantSelectData = [
export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
const isBusy = useAppSelector(selectIsBusy);
const { retrievedModel, modelToEdit } = props;
const { model } = props;
const [updateMainModel, { isLoading, error }] = useUpdateMainModelsMutation();
const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const diffusersEditForm = useForm<DiffusersModelConfig>({
initialValues: {
model_name: retrievedModel.model_name ? retrievedModel.model_name : '',
base_model: retrievedModel.base_model,
model_name: model.model_name ? model.model_name : '',
base_model: model.base_model,
model_type: 'main',
path: retrievedModel.path ? retrievedModel.path : '',
description: retrievedModel.description ? retrievedModel.description : '',
path: model.path ? model.path : '',
description: model.description ? model.description : '',
model_format: 'diffusers',
vae: retrievedModel.vae ? retrievedModel.vae : '',
variant: retrievedModel.variant,
vae: model.vae ? model.vae : '',
variant: model.variant,
},
validate: {
path: (value) =>
@ -64,46 +58,56 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
},
});
const editModelFormSubmitHandler = (values: DiffusersModelConfig) => {
const responseBody = {
base_model: retrievedModel.base_model,
model_name: retrievedModel.model_name,
body: values,
};
updateMainModel(responseBody)
.unwrap()
.then((payload) => {
diffusersEditForm.setValues(payload as DiffusersModelConfig);
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((error) => {
diffusersEditForm.reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
};
const editModelFormSubmitHandler = useCallback(
(values: DiffusersModelConfig) => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
body: values,
};
updateMainModel(responseBody)
.unwrap()
.then((payload) => {
diffusersEditForm.setValues(payload as DiffusersModelConfig);
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((error) => {
diffusersEditForm.reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
},
[
diffusersEditForm,
dispatch,
model.base_model,
model.model_name,
t,
updateMainModel,
]
);
return modelToEdit ? (
return (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold">
{retrievedModel.model_name}
{model.model_name}
</Text>
<Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[retrievedModel.base_model]} Model
{MODEL_TYPE_MAP[model.base_model]} Model
</Text>
</Flex>
<Divider />
@ -146,17 +150,5 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
</Flex>
</form>
</Flex>
) : (
<Flex
sx={{
width: '100%',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.900',
}}
>
<Text fontWeight={'500'}>Pick A Model To Edit</Text>
</Flex>
);
}

View File

@ -1,187 +1,46 @@
import { Box, Flex, Spinner, Text, useColorMode } from '@chakra-ui/react';
import { ButtonGroup, Flex, Text } from '@chakra-ui/react';
import { EntityState } from '@reduxjs/toolkit';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import { forEach } from 'lodash-es';
import type { ChangeEvent } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import {
MainModelConfigEntity,
useGetMainModelsQuery,
} from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem';
import { useTranslation } from 'react-i18next';
type ModelListProps = {
selectedModelId: string | undefined;
setSelectedModelId: (name: string | undefined) => void;
};
import type { ChangeEvent, ReactNode } from 'react';
import React, { useMemo, useState, useTransition } from 'react';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { mode } from 'theme/util/mode';
function ModelFilterButton({
label,
isActive,
onClick,
}: {
label: string;
isActive: boolean;
onClick: () => void;
}) {
return (
<IAIButton onClick={onClick} isChecked={isActive} size="sm">
{label}
</IAIButton>
);
}
const ModelList = () => {
const { data: mainModels } = useGetMainModelsQuery();
const { colorMode } = useColorMode();
const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
React.useEffect(() => {
const timer = setTimeout(() => {
setRenderModelList(true);
}, 200);
return () => clearTimeout(timer);
}, []);
const [searchText, setSearchText] = useState<string>('');
const [isSelectedFilter, setIsSelectedFilter] = useState<
'all' | 'checkpoint' | 'diffusers'
>('all');
const [_, startTransition] = useTransition();
type ModelFormat = 'all' | 'checkpoint' | 'diffusers';
const ModelList = (props: ModelListProps) => {
const { selectedModelId, setSelectedModelId } = props;
const { t } = useTranslation();
const [nameFilter, setNameFilter] = useState<string>('');
const [modelFormatFilter, setModelFormatFilter] =
useState<ModelFormat>('all');
const handleSearchFilter = (e: ChangeEvent<HTMLInputElement>) => {
startTransition(() => {
setSearchText(e.target.value);
});
};
const { filteredDiffusersModels } = useGetMainModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter),
}),
});
const renderModelListItems = useMemo(() => {
const ckptModelListItemsToRender: ReactNode[] = [];
const diffusersModelListItemsToRender: ReactNode[] = [];
const filteredModelListItemsToRender: ReactNode[] = [];
const localFilteredModelListItemsToRender: ReactNode[] = [];
const { filteredCheckpointModels } = useGetMainModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter),
}),
});
if (!mainModels) return;
const modelList = mainModels.entities;
Object.keys(modelList).forEach((model, i) => {
const modelInfo = modelList[model];
// If no model info found for a model, ignore it
if (!modelInfo) return;
if (
modelInfo.model_name.toLowerCase().includes(searchText.toLowerCase())
) {
filteredModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelInfo.model_name}
description={modelInfo.description}
/>
);
if (modelInfo?.model_format === isSelectedFilter) {
localFilteredModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelInfo.model_name}
description={modelInfo.description}
/>
);
}
}
if (modelInfo?.model_format !== 'diffusers') {
ckptModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelInfo.model_name}
description={modelInfo.description}
/>
);
} else {
diffusersModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelInfo.model_name}
description={modelInfo.description}
/>
);
}
});
return searchText !== '' ? (
isSelectedFilter === 'all' ? (
<Box marginTop={4}>{filteredModelListItemsToRender}</Box>
) : (
<Box marginTop={4}>{localFilteredModelListItemsToRender}</Box>
)
) : (
<Flex flexDirection="column" rowGap={6}>
{isSelectedFilter === 'all' && (
<>
{diffusersModelListItemsToRender.length > 0 && (
<Box>
<Text
sx={{
fontWeight: '500',
py: 2,
px: 4,
mb: 4,
borderRadius: 'base',
width: 'max-content',
fontSize: 'sm',
bg: mode('base.100', 'base.800')(colorMode),
}}
>
{t('modelManager.diffusersModels')}
</Text>
{diffusersModelListItemsToRender}
</Box>
)}
{ckptModelListItemsToRender.length > 0 && (
<Box>
<Text
sx={{
fontWeight: '500',
py: 2,
px: 4,
my: 4,
mx: 0,
borderRadius: 'base',
width: 'max-content',
fontSize: 'sm',
bg: mode('base.150', 'base.750')(colorMode),
}}
>
{t('modelManager.checkpointModels')}
</Text>
{ckptModelListItemsToRender}
</Box>
)}
</>
)}
{isSelectedFilter === 'diffusers' && (
<Flex flexDirection="column" marginTop={4}>
{diffusersModelListItemsToRender}
</Flex>
)}
{isSelectedFilter === 'checkpoint' && (
<Flex flexDirection="column" marginTop={4}>
{ckptModelListItemsToRender}
</Flex>
)}
</Flex>
);
}, [mainModels, searchText, t, isSelectedFilter, colorMode]);
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setNameFilter(e.target.value);
}, []);
return (
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
@ -189,7 +48,6 @@ const ModelList = () => {
onChange={handleSearchFilter}
label={t('modelManager.search')}
/>
<Flex
flexDirection="column"
gap={4}
@ -197,34 +55,58 @@ const ModelList = () => {
overflow="scroll"
paddingInlineEnd={4}
>
<Flex columnGap={2}>
<ModelFilterButton
label={t('modelManager.allModels')}
onClick={() => setIsSelectedFilter('all')}
isActive={isSelectedFilter === 'all'}
/>
<ModelFilterButton
label={t('modelManager.diffusersModels')}
onClick={() => setIsSelectedFilter('diffusers')}
isActive={isSelectedFilter === 'diffusers'}
/>
<ModelFilterButton
label={t('modelManager.checkpointModels')}
onClick={() => setIsSelectedFilter('checkpoint')}
isActive={isSelectedFilter === 'checkpoint'}
/>
</Flex>
{renderModelList ? (
renderModelListItems
) : (
<Flex
width="100%"
minHeight={96}
justifyContent="center"
alignItems="center"
<ButtonGroup isAttached>
<IAIButton
onClick={() => setModelFormatFilter('all')}
isChecked={modelFormatFilter === 'all'}
size="sm"
>
<Spinner />
{t('modelManager.allModels')}
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('diffusers')}
isChecked={modelFormatFilter === 'diffusers'}
>
{t('modelManager.diffusersModels')}
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('checkpoint')}
isChecked={modelFormatFilter === 'checkpoint'}
>
{t('modelManager.checkpointModels')}
</IAIButton>
</ButtonGroup>
{['all', 'diffusers'].includes(modelFormatFilter) && (
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" size="sm">
Diffusers
</Text>
{filteredDiffusersModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
)}
{['all', 'checkpoint'].includes(modelFormatFilter) && (
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" size="sm">
Checkpoint
</Text>
{filteredCheckpointModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
)}
</Flex>
@ -233,3 +115,27 @@ const ModelList = () => {
};
export default ModelList;
const modelsFilter = (
data: EntityState<MainModelConfigEntity> | undefined,
model_format: ModelFormat,
nameFilter: string
) => {
const filteredModels: MainModelConfigEntity[] = [];
forEach(data?.entities, (model) => {
if (!model) {
return;
}
const matchesFilter = model.model_name
.toLowerCase()
.includes(nameFilter.toLowerCase());
const matchesFormat = model.model_format === model_format;
if (matchesFilter && matchesFormat) {
filteredModels.push(model);
}
});
return filteredModels;
};

View File

@ -1,115 +1,96 @@
import { DeleteIcon, EditIcon } from '@chakra-ui/icons';
import {
Box,
Flex,
Spacer,
Text,
Tooltip,
useColorMode,
} from '@chakra-ui/react';
// import { deleteModel, requestModelChange } from 'app/socketio/actions';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { DeleteIcon } from '@chakra-ui/icons';
import { Box, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { setOpenModel } from 'features/system/store/systemSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useDeleteMainModelsMutation } from 'services/api/endpoints/models';
import { BaseModelType } from 'services/api/types';
import { mode } from 'theme/util/mode';
import { FaEdit } from 'react-icons/fa';
import {
MainModelConfigEntity,
useDeleteMainModelsMutation,
} from 'services/api/endpoints/models';
type ModelListItemProps = {
modelKey: string;
name: string;
description: string | undefined;
model: MainModelConfigEntity;
isSelected: boolean;
setSelectedModelId: (v: string | undefined) => void;
};
export default function ModelListItem(props: ModelListItemProps) {
const isBusy = useAppSelector(selectIsBusy);
const { colorMode } = useColorMode();
const openModel = useAppSelector(
(state: RootState) => state.system.openModel
);
const { t } = useTranslation();
const [deleteMainModel] = useDeleteMainModelsMutation();
const { t } = useTranslation();
const { model, isSelected, setSelectedModelId } = props;
const dispatch = useAppDispatch();
const handleSelectModel = useCallback(() => {
setSelectedModelId(model.id);
}, [model.id, setSelectedModelId]);
const { modelKey, name, description } = props;
const openModelHandler = () => {
dispatch(setOpenModel(modelKey));
};
const handleModelDelete = () => {
const [base_model, _, model_name] = modelKey.split('/');
deleteMainModel({
base_model: base_model as BaseModelType,
model_name: model_name,
});
dispatch(setOpenModel(null));
};
const handleModelDelete = useCallback(() => {
deleteMainModel(model);
setSelectedModelId(undefined);
}, [deleteMainModel, model, setSelectedModelId]);
return (
<Flex
alignItems="center"
p={2}
borderRadius="base"
sx={
modelKey === openModel
? {
bg: mode('accent.200', 'accent.600')(colorMode),
_hover: {
bg: mode('accent.200', 'accent.600')(colorMode),
},
}
: {
_hover: {
bg: mode('base.100', 'base.800')(colorMode),
},
}
}
>
<Box onClick={openModelHandler} cursor="pointer">
<Tooltip label={description} hasArrow placement="bottom">
<Text fontWeight="600">{name}</Text>
</Tooltip>
</Box>
<Spacer onClick={openModelHandler} cursor="pointer" />
<Flex gap={2} alignItems="center">
<Flex sx={{ gap: 2, alignItems: 'center', w: 'full' }}>
<Flex
as={IAIButton}
isChecked={isSelected}
sx={{
p: 2,
borderRadius: 'base',
w: 'full',
alignItems: 'center',
bg: isSelected ? 'accent.200' : 'base.100',
_hover: {
bg: isSelected ? 'accent.250' : 'base.150',
},
_dark: {
bg: isSelected ? 'accent.600' : 'base.850',
_hover: {
bg: isSelected ? 'accent.550' : 'base.800',
},
},
}}
onClick={handleSelectModel}
>
<Box cursor="pointer">
<Tooltip label={model.description} hasArrow placement="bottom">
<Text fontWeight="600">{model.model_name}</Text>
</Tooltip>
</Box>
<Spacer onClick={handleSelectModel} cursor="pointer" />
<IAIIconButton
icon={<EditIcon />}
icon={<FaEdit />}
size="sm"
onClick={openModelHandler}
onClick={handleSelectModel}
aria-label={t('accessibility.modifyConfig')}
isDisabled={isBusy}
variant="link"
/>
<IAIAlertDialog
title={t('modelManager.deleteModel')}
acceptCallback={handleModelDelete}
acceptButtonText={t('modelManager.delete')}
triggerComponent={
<IAIIconButton
icon={<DeleteIcon />}
size="sm"
aria-label={t('modelManager.deleteConfig')}
isDisabled={isBusy}
colorScheme="error"
/>
}
>
<Flex rowGap={4} flexDirection="column">
<p style={{ fontWeight: 'bold' }}>{t('modelManager.deleteMsg1')}</p>
<p>{t('modelManager.deleteMsg2')}</p>
</Flex>
</IAIAlertDialog>
</Flex>
<IAIAlertDialog
title={t('modelManager.deleteModel')}
acceptCallback={handleModelDelete}
acceptButtonText={t('modelManager.delete')}
triggerComponent={
<IAIIconButton
icon={<DeleteIcon />}
aria-label={t('modelManager.deleteConfig')}
isDisabled={isBusy}
colorScheme="error"
/>
}
>
<Flex rowGap={4} flexDirection="column">
<p style={{ fontWeight: 'bold' }}>{t('modelManager.deleteMsg1')}</p>
<p>{t('modelManager.deleteMsg2')}</p>
</Flex>
</IAIAlertDialog>
</Flex>
);
}

View File

@ -3,7 +3,9 @@ import { cloneDeep } from 'lodash-es';
import {
AnyModelConfig,
BaseModelType,
CheckpointModelConfig,
ControlNetModelConfig,
DiffusersModelConfig,
LoRAModelConfig,
MainModelConfig,
MergeModelConfig,
@ -14,7 +16,13 @@ import {
import { ApiFullTagDescription, LIST_TAG, api } from '..';
import { paths } from '../schema';
export type MainModelConfigEntity = MainModelConfig & { id: string };
export type DiffusersModelConfigEntity = DiffusersModelConfig & { id: string };
export type CheckpointModelConfigEntity = CheckpointModelConfig & {
id: string;
};
export type MainModelConfigEntity =
| DiffusersModelConfigEntity
| CheckpointModelConfigEntity;
export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };

View File

@ -42,11 +42,13 @@ export type ControlNetModelConfig =
components['schemas']['ControlNetModelConfig'];
export type TextualInversionModelConfig =
components['schemas']['TextualInversionModelConfig'];
export type MainModelConfig =
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
export type DiffusersModelConfig =
| components['schemas']['StableDiffusion1ModelDiffusersConfig']
| components['schemas']['StableDiffusion2ModelCheckpointConfig']
| components['schemas']['StableDiffusion2ModelDiffusersConfig'];
export type CheckpointModelConfig =
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
| components['schemas']['StableDiffusion2ModelCheckpointConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type AnyModelConfig =
| LoRAModelConfig
| VaeModelConfig

View File

@ -2,7 +2,7 @@ import { defineStyle, defineStyleConfig } from '@chakra-ui/react';
import { mode } from '@chakra-ui/theme-tools';
const subtext = defineStyle((props) => ({
color: mode('colors.base.500', 'colors.base.400')(props),
color: mode('base.500', 'base.400')(props),
}));
export const textTheme = defineStyleConfig({