mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): refactor model manager ui
- simplify UI logic in `ModelManagerPanel` components - fix up the types a bit to make it easier to select models - remove `openModel` state, just make it a useState since it is very local to model manager
This commit is contained in:
parent
f2af82bf73
commit
1e5ae9d986
@ -13,7 +13,6 @@ export const systemPersistDenylist: (keyof SystemState)[] = [
|
||||
'isProcessing',
|
||||
'totalIterations',
|
||||
'totalSteps',
|
||||
'openModel',
|
||||
'isCancelScheduled',
|
||||
'progressImage',
|
||||
'wereModelsReceived',
|
||||
|
@ -46,7 +46,6 @@ export interface SystemState {
|
||||
toastQueue: UseToastOptions[];
|
||||
searchFolder: string | null;
|
||||
foundModels: InvokeAI.FoundModel[] | null;
|
||||
openModel: string | null;
|
||||
/**
|
||||
* The current progress image
|
||||
*/
|
||||
@ -109,7 +108,6 @@ export const initialSystemState: SystemState = {
|
||||
toastQueue: [],
|
||||
searchFolder: null,
|
||||
foundModels: null,
|
||||
openModel: null,
|
||||
progressImage: null,
|
||||
shouldAntialiasProgressImage: false,
|
||||
sessionId: null,
|
||||
@ -164,9 +162,6 @@ export const systemSlice = createSlice({
|
||||
) => {
|
||||
state.foundModels = action.payload;
|
||||
},
|
||||
setOpenModel: (state, action: PayloadAction<string | null>) => {
|
||||
state.openModel = action.payload;
|
||||
},
|
||||
/**
|
||||
* A cancel was scheduled
|
||||
*/
|
||||
@ -433,7 +428,6 @@ export const {
|
||||
clearToastQueue,
|
||||
setSearchFolder,
|
||||
setFoundModels,
|
||||
setOpenModel,
|
||||
cancelScheduled,
|
||||
scheduledCancelAborted,
|
||||
cancelTypeChanged,
|
||||
|
@ -1,14 +1,6 @@
|
||||
import {
|
||||
Tab,
|
||||
TabList,
|
||||
TabPanel,
|
||||
TabPanels,
|
||||
Tabs,
|
||||
useColorMode,
|
||||
} from '@chakra-ui/react';
|
||||
import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@chakra-ui/react';
|
||||
import i18n from 'i18n';
|
||||
import { ReactNode, memo } from 'react';
|
||||
import { mode } from 'theme/util/mode';
|
||||
import AddModelsPanel from './subpanels/AddModelsPanel';
|
||||
import MergeModelsPanel from './subpanels/MergeModelsPanel';
|
||||
import ModelManagerPanel from './subpanels/ModelManagerPanel';
|
||||
@ -21,7 +13,7 @@ type ModelManagerTabInfo = {
|
||||
content: ReactNode;
|
||||
};
|
||||
|
||||
const modelManagerTabs: ModelManagerTabInfo[] = [
|
||||
const tabs: ModelManagerTabInfo[] = [
|
||||
{
|
||||
id: 'modelManager',
|
||||
label: i18n.t('modelManager.modelManager'),
|
||||
@ -40,50 +32,25 @@ const modelManagerTabs: ModelManagerTabInfo[] = [
|
||||
];
|
||||
|
||||
const ModelManagerTab = () => {
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const renderTabsList = () => {
|
||||
const modelManagerTabListsToRender: ReactNode[] = [];
|
||||
|
||||
modelManagerTabs.forEach((modelManagerTab) => {
|
||||
modelManagerTabListsToRender.push(
|
||||
<Tab key={modelManagerTab.id}>{modelManagerTab.label}</Tab>
|
||||
);
|
||||
});
|
||||
|
||||
return (
|
||||
<TabList
|
||||
sx={{
|
||||
w: '100%',
|
||||
color: mode('base.900', 'base.400')(colorMode),
|
||||
flexDirection: 'row',
|
||||
borderBottomWidth: 2,
|
||||
borderColor: mode('accent.300', 'accent.600')(colorMode),
|
||||
}}
|
||||
>
|
||||
{modelManagerTabListsToRender}
|
||||
</TabList>
|
||||
);
|
||||
};
|
||||
|
||||
const renderTabPanels = () => {
|
||||
const modelManagerTabPanelsToRender: ReactNode[] = [];
|
||||
modelManagerTabs.forEach((modelManagerTab) => {
|
||||
modelManagerTabPanelsToRender.push(
|
||||
<TabPanel key={modelManagerTab.id}>{modelManagerTab.content}</TabPanel>
|
||||
);
|
||||
});
|
||||
|
||||
return <TabPanels sx={{ p: 2 }}>{modelManagerTabPanelsToRender}</TabPanels>;
|
||||
};
|
||||
return (
|
||||
<Tabs
|
||||
isLazy
|
||||
variant="invokeAI"
|
||||
sx={{ w: 'full', h: 'full', p: 2, gap: 4, flexDirection: 'column' }}
|
||||
variant="line"
|
||||
layerStyle="first"
|
||||
sx={{ w: 'full', h: 'full', p: 4, gap: 4, borderRadius: 'base' }}
|
||||
>
|
||||
{renderTabsList()}
|
||||
{renderTabPanels()}
|
||||
<TabList>
|
||||
{tabs.map((tab) => (
|
||||
<Tab sx={{ borderTopRadius: 'base' }} key={tab.id}>
|
||||
{tab.label}
|
||||
</Tab>
|
||||
))}
|
||||
</TabList>
|
||||
<TabPanels sx={{ p: 4 }}>
|
||||
{tabs.map((tab) => (
|
||||
<TabPanel key={tab.id}>{tab.content}</TabPanel>
|
||||
))}
|
||||
</TabPanels>
|
||||
</Tabs>
|
||||
);
|
||||
};
|
||||
|
@ -1,48 +1,59 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { Flex, Text } from '@chakra-ui/react';
|
||||
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useState } from 'react';
|
||||
import {
|
||||
MainModelConfigEntity,
|
||||
useGetMainModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
||||
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
|
||||
import ModelList from './ModelManagerPanel/ModelList';
|
||||
|
||||
export default function ModelManagerPanel() {
|
||||
const { data: mainModels } = useGetMainModelsQuery();
|
||||
const [selectedModelId, setSelectedModelId] = useState<string>();
|
||||
const { model } = useGetMainModelsQuery(undefined, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
model: selectedModelId ? data?.entities[selectedModelId] : undefined,
|
||||
}),
|
||||
});
|
||||
|
||||
const openModel = useAppSelector(
|
||||
(state: RootState) => state.system.openModel
|
||||
);
|
||||
|
||||
const renderModelEditTabs = () => {
|
||||
if (!openModel || !mainModels) return;
|
||||
|
||||
const openedModelData = mainModels['entities'][openModel];
|
||||
|
||||
if (openedModelData && openedModelData.model_format === 'diffusers') {
|
||||
return (
|
||||
<DiffusersModelEdit
|
||||
modelToEdit={openModel}
|
||||
retrievedModel={openedModelData}
|
||||
key={openModel}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (openedModelData && openedModelData.model_format === 'checkpoint') {
|
||||
return (
|
||||
<CheckpointModelEdit
|
||||
modelToEdit={openModel}
|
||||
retrievedModel={openedModelData}
|
||||
key={openModel}
|
||||
/>
|
||||
);
|
||||
}
|
||||
};
|
||||
return (
|
||||
<Flex width="100%" columnGap={8}>
|
||||
<ModelList />
|
||||
{renderModelEditTabs()}
|
||||
<ModelList
|
||||
selectedModelId={selectedModelId}
|
||||
setSelectedModelId={setSelectedModelId}
|
||||
/>
|
||||
<ModelEdit model={model} />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
type ModelEditProps = {
|
||||
model: MainModelConfigEntity | undefined;
|
||||
};
|
||||
|
||||
const ModelEdit = (props: ModelEditProps) => {
|
||||
const { model } = props;
|
||||
|
||||
if (model?.model_format === 'checkpoint') {
|
||||
return <CheckpointModelEdit model={model} />;
|
||||
}
|
||||
|
||||
if (model?.model_format === 'diffusers') {
|
||||
return <DiffusersModelEdit model={model} />;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
width: '100%',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
borderRadius: 'base',
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
<Text fontWeight={500}>Pick A Model To Edit</Text>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -1,21 +1,20 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
|
||||
import { Divider, Flex, Text } from '@chakra-ui/react';
|
||||
|
||||
// import { addNewModel } from 'app/socketio/actions';
|
||||
import { useForm } from '@mantine/form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
|
||||
import { makeToast } from 'app/components/Toaster';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIMantineTextInput from 'common/components/IAIMantineInput';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { useUpdateMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import { components } from 'services/api/schema';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
CheckpointModelConfigEntity,
|
||||
useUpdateMainModelsMutation,
|
||||
} from 'services/api/endpoints/models';
|
||||
import { CheckpointModelConfig } from 'services/api/types';
|
||||
import ModelConvert from './ModelConvert';
|
||||
|
||||
const baseModelSelectData = [
|
||||
@ -29,36 +28,31 @@ const variantSelectData = [
|
||||
{ value: 'depth', label: 'Depth' },
|
||||
];
|
||||
|
||||
export type CheckpointModelConfig =
|
||||
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
|
||||
| components['schemas']['StableDiffusion2ModelCheckpointConfig'];
|
||||
|
||||
type CheckpointModelEditProps = {
|
||||
modelToEdit: string;
|
||||
retrievedModel: CheckpointModelConfig;
|
||||
model: CheckpointModelConfigEntity;
|
||||
};
|
||||
|
||||
export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const { modelToEdit, retrievedModel } = props;
|
||||
const { model } = props;
|
||||
|
||||
const [updateMainModel, { error, isLoading }] = useUpdateMainModelsMutation();
|
||||
const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const checkpointEditForm = useForm<CheckpointModelConfig>({
|
||||
initialValues: {
|
||||
model_name: retrievedModel.model_name ? retrievedModel.model_name : '',
|
||||
base_model: retrievedModel.base_model,
|
||||
model_name: model.model_name ? model.model_name : '',
|
||||
base_model: model.base_model,
|
||||
model_type: 'main',
|
||||
path: retrievedModel.path ? retrievedModel.path : '',
|
||||
description: retrievedModel.description ? retrievedModel.description : '',
|
||||
path: model.path ? model.path : '',
|
||||
description: model.description ? model.description : '',
|
||||
model_format: 'checkpoint',
|
||||
vae: retrievedModel.vae ? retrievedModel.vae : '',
|
||||
config: retrievedModel.config ? retrievedModel.config : '',
|
||||
variant: retrievedModel.variant,
|
||||
vae: model.vae ? model.vae : '',
|
||||
config: model.config ? model.config : '',
|
||||
variant: model.variant,
|
||||
},
|
||||
validate: {
|
||||
path: (value) =>
|
||||
@ -66,50 +60,60 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
|
||||
},
|
||||
});
|
||||
|
||||
const editModelFormSubmitHandler = (values: CheckpointModelConfig) => {
|
||||
const responseBody = {
|
||||
base_model: retrievedModel.base_model,
|
||||
model_name: retrievedModel.model_name,
|
||||
body: values,
|
||||
};
|
||||
updateMainModel(responseBody)
|
||||
.unwrap()
|
||||
.then((payload) => {
|
||||
checkpointEditForm.setValues(payload as CheckpointModelConfig);
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelUpdated'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((error) => {
|
||||
checkpointEditForm.reset();
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelUpdateFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
});
|
||||
};
|
||||
const editModelFormSubmitHandler = useCallback(
|
||||
(values: CheckpointModelConfig) => {
|
||||
const responseBody = {
|
||||
base_model: model.base_model,
|
||||
model_name: model.model_name,
|
||||
body: values,
|
||||
};
|
||||
updateMainModel(responseBody)
|
||||
.unwrap()
|
||||
.then((payload) => {
|
||||
checkpointEditForm.setValues(payload as CheckpointModelConfig);
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelUpdated'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((error) => {
|
||||
checkpointEditForm.reset();
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelUpdateFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
});
|
||||
},
|
||||
[
|
||||
checkpointEditForm,
|
||||
dispatch,
|
||||
model.base_model,
|
||||
model.model_name,
|
||||
t,
|
||||
updateMainModel,
|
||||
]
|
||||
);
|
||||
|
||||
return modelToEdit ? (
|
||||
return (
|
||||
<Flex flexDirection="column" rowGap={4} width="100%">
|
||||
<Flex justifyContent="space-between" alignItems="center">
|
||||
<Flex flexDirection="column">
|
||||
<Text fontSize="lg" fontWeight="bold">
|
||||
{retrievedModel.model_name}
|
||||
{model.model_name}
|
||||
</Text>
|
||||
<Text fontSize="sm" color="base.400">
|
||||
{MODEL_TYPE_MAP[retrievedModel.base_model]} Model
|
||||
{MODEL_TYPE_MAP[model.base_model]} Model
|
||||
</Text>
|
||||
</Flex>
|
||||
<ModelConvert model={retrievedModel} />
|
||||
<ModelConvert model={model} />
|
||||
</Flex>
|
||||
<Divider />
|
||||
|
||||
@ -161,17 +165,5 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
|
||||
</form>
|
||||
</Flex>
|
||||
</Flex>
|
||||
) : (
|
||||
<Flex
|
||||
sx={{
|
||||
width: '100%',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
borderRadius: 'base',
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
<Text fontWeight={500}>Pick A Model To Edit</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
@ -1,29 +1,23 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
|
||||
import { Divider, Flex, Text } from '@chakra-ui/react';
|
||||
|
||||
// import { addNewModel } from 'app/socketio/actions';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { useForm } from '@mantine/form';
|
||||
import { makeToast } from 'app/components/Toaster';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIMantineTextInput from 'common/components/IAIMantineInput';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { useUpdateMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import { components } from 'services/api/schema';
|
||||
|
||||
export type DiffusersModelConfig =
|
||||
| components['schemas']['StableDiffusion1ModelDiffusersConfig']
|
||||
| components['schemas']['StableDiffusion2ModelDiffusersConfig'];
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
DiffusersModelConfigEntity,
|
||||
useUpdateMainModelsMutation,
|
||||
} from 'services/api/endpoints/models';
|
||||
import { DiffusersModelConfig } from 'services/api/types';
|
||||
|
||||
type DiffusersModelEditProps = {
|
||||
modelToEdit: string;
|
||||
retrievedModel: DiffusersModelConfig;
|
||||
model: DiffusersModelConfigEntity;
|
||||
};
|
||||
|
||||
const baseModelSelectData = [
|
||||
@ -40,23 +34,23 @@ const variantSelectData = [
|
||||
export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const { retrievedModel, modelToEdit } = props;
|
||||
const { model } = props;
|
||||
|
||||
const [updateMainModel, { isLoading, error }] = useUpdateMainModelsMutation();
|
||||
const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const diffusersEditForm = useForm<DiffusersModelConfig>({
|
||||
initialValues: {
|
||||
model_name: retrievedModel.model_name ? retrievedModel.model_name : '',
|
||||
base_model: retrievedModel.base_model,
|
||||
model_name: model.model_name ? model.model_name : '',
|
||||
base_model: model.base_model,
|
||||
model_type: 'main',
|
||||
path: retrievedModel.path ? retrievedModel.path : '',
|
||||
description: retrievedModel.description ? retrievedModel.description : '',
|
||||
path: model.path ? model.path : '',
|
||||
description: model.description ? model.description : '',
|
||||
model_format: 'diffusers',
|
||||
vae: retrievedModel.vae ? retrievedModel.vae : '',
|
||||
variant: retrievedModel.variant,
|
||||
vae: model.vae ? model.vae : '',
|
||||
variant: model.variant,
|
||||
},
|
||||
validate: {
|
||||
path: (value) =>
|
||||
@ -64,46 +58,56 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
|
||||
},
|
||||
});
|
||||
|
||||
const editModelFormSubmitHandler = (values: DiffusersModelConfig) => {
|
||||
const responseBody = {
|
||||
base_model: retrievedModel.base_model,
|
||||
model_name: retrievedModel.model_name,
|
||||
body: values,
|
||||
};
|
||||
updateMainModel(responseBody)
|
||||
.unwrap()
|
||||
.then((payload) => {
|
||||
diffusersEditForm.setValues(payload as DiffusersModelConfig);
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelUpdated'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((error) => {
|
||||
diffusersEditForm.reset();
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelUpdateFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
});
|
||||
};
|
||||
const editModelFormSubmitHandler = useCallback(
|
||||
(values: DiffusersModelConfig) => {
|
||||
const responseBody = {
|
||||
base_model: model.base_model,
|
||||
model_name: model.model_name,
|
||||
body: values,
|
||||
};
|
||||
updateMainModel(responseBody)
|
||||
.unwrap()
|
||||
.then((payload) => {
|
||||
diffusersEditForm.setValues(payload as DiffusersModelConfig);
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelUpdated'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((error) => {
|
||||
diffusersEditForm.reset();
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelUpdateFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
});
|
||||
},
|
||||
[
|
||||
diffusersEditForm,
|
||||
dispatch,
|
||||
model.base_model,
|
||||
model.model_name,
|
||||
t,
|
||||
updateMainModel,
|
||||
]
|
||||
);
|
||||
|
||||
return modelToEdit ? (
|
||||
return (
|
||||
<Flex flexDirection="column" rowGap={4} width="100%">
|
||||
<Flex flexDirection="column">
|
||||
<Text fontSize="lg" fontWeight="bold">
|
||||
{retrievedModel.model_name}
|
||||
{model.model_name}
|
||||
</Text>
|
||||
<Text fontSize="sm" color="base.400">
|
||||
{MODEL_TYPE_MAP[retrievedModel.base_model]} Model
|
||||
{MODEL_TYPE_MAP[model.base_model]} Model
|
||||
</Text>
|
||||
</Flex>
|
||||
<Divider />
|
||||
@ -146,17 +150,5 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
|
||||
</Flex>
|
||||
</form>
|
||||
</Flex>
|
||||
) : (
|
||||
<Flex
|
||||
sx={{
|
||||
width: '100%',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
borderRadius: 'base',
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
<Text fontWeight={'500'}>Pick A Model To Edit</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
@ -1,187 +1,46 @@
|
||||
import { Box, Flex, Spinner, Text, useColorMode } from '@chakra-ui/react';
|
||||
import { ButtonGroup, Flex, Text } from '@chakra-ui/react';
|
||||
import { EntityState } from '@reduxjs/toolkit';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
|
||||
import { forEach } from 'lodash-es';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
MainModelConfigEntity,
|
||||
useGetMainModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
import ModelListItem from './ModelListItem';
|
||||
|
||||
import { useTranslation } from 'react-i18next';
|
||||
type ModelListProps = {
|
||||
selectedModelId: string | undefined;
|
||||
setSelectedModelId: (name: string | undefined) => void;
|
||||
};
|
||||
|
||||
import type { ChangeEvent, ReactNode } from 'react';
|
||||
import React, { useMemo, useState, useTransition } from 'react';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import { mode } from 'theme/util/mode';
|
||||
|
||||
function ModelFilterButton({
|
||||
label,
|
||||
isActive,
|
||||
onClick,
|
||||
}: {
|
||||
label: string;
|
||||
isActive: boolean;
|
||||
onClick: () => void;
|
||||
}) {
|
||||
return (
|
||||
<IAIButton onClick={onClick} isChecked={isActive} size="sm">
|
||||
{label}
|
||||
</IAIButton>
|
||||
);
|
||||
}
|
||||
|
||||
const ModelList = () => {
|
||||
const { data: mainModels } = useGetMainModelsQuery();
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
|
||||
|
||||
React.useEffect(() => {
|
||||
const timer = setTimeout(() => {
|
||||
setRenderModelList(true);
|
||||
}, 200);
|
||||
|
||||
return () => clearTimeout(timer);
|
||||
}, []);
|
||||
|
||||
const [searchText, setSearchText] = useState<string>('');
|
||||
const [isSelectedFilter, setIsSelectedFilter] = useState<
|
||||
'all' | 'checkpoint' | 'diffusers'
|
||||
>('all');
|
||||
const [_, startTransition] = useTransition();
|
||||
type ModelFormat = 'all' | 'checkpoint' | 'diffusers';
|
||||
|
||||
const ModelList = (props: ModelListProps) => {
|
||||
const { selectedModelId, setSelectedModelId } = props;
|
||||
const { t } = useTranslation();
|
||||
const [nameFilter, setNameFilter] = useState<string>('');
|
||||
const [modelFormatFilter, setModelFormatFilter] =
|
||||
useState<ModelFormat>('all');
|
||||
|
||||
const handleSearchFilter = (e: ChangeEvent<HTMLInputElement>) => {
|
||||
startTransition(() => {
|
||||
setSearchText(e.target.value);
|
||||
});
|
||||
};
|
||||
const { filteredDiffusersModels } = useGetMainModelsQuery(undefined, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter),
|
||||
}),
|
||||
});
|
||||
|
||||
const renderModelListItems = useMemo(() => {
|
||||
const ckptModelListItemsToRender: ReactNode[] = [];
|
||||
const diffusersModelListItemsToRender: ReactNode[] = [];
|
||||
const filteredModelListItemsToRender: ReactNode[] = [];
|
||||
const localFilteredModelListItemsToRender: ReactNode[] = [];
|
||||
const { filteredCheckpointModels } = useGetMainModelsQuery(undefined, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter),
|
||||
}),
|
||||
});
|
||||
|
||||
if (!mainModels) return;
|
||||
|
||||
const modelList = mainModels.entities;
|
||||
|
||||
Object.keys(modelList).forEach((model, i) => {
|
||||
const modelInfo = modelList[model];
|
||||
|
||||
// If no model info found for a model, ignore it
|
||||
if (!modelInfo) return;
|
||||
|
||||
if (
|
||||
modelInfo.model_name.toLowerCase().includes(searchText.toLowerCase())
|
||||
) {
|
||||
filteredModelListItemsToRender.push(
|
||||
<ModelListItem
|
||||
key={i}
|
||||
modelKey={model}
|
||||
name={modelInfo.model_name}
|
||||
description={modelInfo.description}
|
||||
/>
|
||||
);
|
||||
if (modelInfo?.model_format === isSelectedFilter) {
|
||||
localFilteredModelListItemsToRender.push(
|
||||
<ModelListItem
|
||||
key={i}
|
||||
modelKey={model}
|
||||
name={modelInfo.model_name}
|
||||
description={modelInfo.description}
|
||||
/>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (modelInfo?.model_format !== 'diffusers') {
|
||||
ckptModelListItemsToRender.push(
|
||||
<ModelListItem
|
||||
key={i}
|
||||
modelKey={model}
|
||||
name={modelInfo.model_name}
|
||||
description={modelInfo.description}
|
||||
/>
|
||||
);
|
||||
} else {
|
||||
diffusersModelListItemsToRender.push(
|
||||
<ModelListItem
|
||||
key={i}
|
||||
modelKey={model}
|
||||
name={modelInfo.model_name}
|
||||
description={modelInfo.description}
|
||||
/>
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
return searchText !== '' ? (
|
||||
isSelectedFilter === 'all' ? (
|
||||
<Box marginTop={4}>{filteredModelListItemsToRender}</Box>
|
||||
) : (
|
||||
<Box marginTop={4}>{localFilteredModelListItemsToRender}</Box>
|
||||
)
|
||||
) : (
|
||||
<Flex flexDirection="column" rowGap={6}>
|
||||
{isSelectedFilter === 'all' && (
|
||||
<>
|
||||
{diffusersModelListItemsToRender.length > 0 && (
|
||||
<Box>
|
||||
<Text
|
||||
sx={{
|
||||
fontWeight: '500',
|
||||
py: 2,
|
||||
px: 4,
|
||||
mb: 4,
|
||||
borderRadius: 'base',
|
||||
width: 'max-content',
|
||||
fontSize: 'sm',
|
||||
bg: mode('base.100', 'base.800')(colorMode),
|
||||
}}
|
||||
>
|
||||
{t('modelManager.diffusersModels')}
|
||||
</Text>
|
||||
{diffusersModelListItemsToRender}
|
||||
</Box>
|
||||
)}
|
||||
|
||||
{ckptModelListItemsToRender.length > 0 && (
|
||||
<Box>
|
||||
<Text
|
||||
sx={{
|
||||
fontWeight: '500',
|
||||
py: 2,
|
||||
px: 4,
|
||||
my: 4,
|
||||
mx: 0,
|
||||
borderRadius: 'base',
|
||||
width: 'max-content',
|
||||
fontSize: 'sm',
|
||||
bg: mode('base.150', 'base.750')(colorMode),
|
||||
}}
|
||||
>
|
||||
{t('modelManager.checkpointModels')}
|
||||
</Text>
|
||||
{ckptModelListItemsToRender}
|
||||
</Box>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
||||
{isSelectedFilter === 'diffusers' && (
|
||||
<Flex flexDirection="column" marginTop={4}>
|
||||
{diffusersModelListItemsToRender}
|
||||
</Flex>
|
||||
)}
|
||||
|
||||
{isSelectedFilter === 'checkpoint' && (
|
||||
<Flex flexDirection="column" marginTop={4}>
|
||||
{ckptModelListItemsToRender}
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}, [mainModels, searchText, t, isSelectedFilter, colorMode]);
|
||||
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setNameFilter(e.target.value);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
|
||||
@ -189,7 +48,6 @@ const ModelList = () => {
|
||||
onChange={handleSearchFilter}
|
||||
label={t('modelManager.search')}
|
||||
/>
|
||||
|
||||
<Flex
|
||||
flexDirection="column"
|
||||
gap={4}
|
||||
@ -197,34 +55,58 @@ const ModelList = () => {
|
||||
overflow="scroll"
|
||||
paddingInlineEnd={4}
|
||||
>
|
||||
<Flex columnGap={2}>
|
||||
<ModelFilterButton
|
||||
label={t('modelManager.allModels')}
|
||||
onClick={() => setIsSelectedFilter('all')}
|
||||
isActive={isSelectedFilter === 'all'}
|
||||
/>
|
||||
<ModelFilterButton
|
||||
label={t('modelManager.diffusersModels')}
|
||||
onClick={() => setIsSelectedFilter('diffusers')}
|
||||
isActive={isSelectedFilter === 'diffusers'}
|
||||
/>
|
||||
<ModelFilterButton
|
||||
label={t('modelManager.checkpointModels')}
|
||||
onClick={() => setIsSelectedFilter('checkpoint')}
|
||||
isActive={isSelectedFilter === 'checkpoint'}
|
||||
/>
|
||||
</Flex>
|
||||
|
||||
{renderModelList ? (
|
||||
renderModelListItems
|
||||
) : (
|
||||
<Flex
|
||||
width="100%"
|
||||
minHeight={96}
|
||||
justifyContent="center"
|
||||
alignItems="center"
|
||||
<ButtonGroup isAttached>
|
||||
<IAIButton
|
||||
onClick={() => setModelFormatFilter('all')}
|
||||
isChecked={modelFormatFilter === 'all'}
|
||||
size="sm"
|
||||
>
|
||||
<Spinner />
|
||||
{t('modelManager.allModels')}
|
||||
</IAIButton>
|
||||
<IAIButton
|
||||
size="sm"
|
||||
onClick={() => setModelFormatFilter('diffusers')}
|
||||
isChecked={modelFormatFilter === 'diffusers'}
|
||||
>
|
||||
{t('modelManager.diffusersModels')}
|
||||
</IAIButton>
|
||||
<IAIButton
|
||||
size="sm"
|
||||
onClick={() => setModelFormatFilter('checkpoint')}
|
||||
isChecked={modelFormatFilter === 'checkpoint'}
|
||||
>
|
||||
{t('modelManager.checkpointModels')}
|
||||
</IAIButton>
|
||||
</ButtonGroup>
|
||||
|
||||
{['all', 'diffusers'].includes(modelFormatFilter) && (
|
||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||
<Text variant="subtext" size="sm">
|
||||
Diffusers
|
||||
</Text>
|
||||
{filteredDiffusersModels.map((model) => (
|
||||
<ModelListItem
|
||||
key={model.id}
|
||||
model={model}
|
||||
isSelected={selectedModelId === model.id}
|
||||
setSelectedModelId={setSelectedModelId}
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
)}
|
||||
{['all', 'checkpoint'].includes(modelFormatFilter) && (
|
||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||
<Text variant="subtext" size="sm">
|
||||
Checkpoint
|
||||
</Text>
|
||||
{filteredCheckpointModels.map((model) => (
|
||||
<ModelListItem
|
||||
key={model.id}
|
||||
model={model}
|
||||
isSelected={selectedModelId === model.id}
|
||||
setSelectedModelId={setSelectedModelId}
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
@ -233,3 +115,27 @@ const ModelList = () => {
|
||||
};
|
||||
|
||||
export default ModelList;
|
||||
|
||||
const modelsFilter = (
|
||||
data: EntityState<MainModelConfigEntity> | undefined,
|
||||
model_format: ModelFormat,
|
||||
nameFilter: string
|
||||
) => {
|
||||
const filteredModels: MainModelConfigEntity[] = [];
|
||||
forEach(data?.entities, (model) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
const matchesFilter = model.model_name
|
||||
.toLowerCase()
|
||||
.includes(nameFilter.toLowerCase());
|
||||
|
||||
const matchesFormat = model.model_format === model_format;
|
||||
|
||||
if (matchesFilter && matchesFormat) {
|
||||
filteredModels.push(model);
|
||||
}
|
||||
});
|
||||
return filteredModels;
|
||||
};
|
||||
|
@ -1,115 +1,96 @@
|
||||
import { DeleteIcon, EditIcon } from '@chakra-ui/icons';
|
||||
import {
|
||||
Box,
|
||||
Flex,
|
||||
Spacer,
|
||||
Text,
|
||||
Tooltip,
|
||||
useColorMode,
|
||||
} from '@chakra-ui/react';
|
||||
|
||||
// import { deleteModel, requestModelChange } from 'app/socketio/actions';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { DeleteIcon } from '@chakra-ui/icons';
|
||||
import { Box, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIAlertDialog from 'common/components/IAIAlertDialog';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { setOpenModel } from 'features/system/store/systemSlice';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDeleteMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import { BaseModelType } from 'services/api/types';
|
||||
import { mode } from 'theme/util/mode';
|
||||
import { FaEdit } from 'react-icons/fa';
|
||||
import {
|
||||
MainModelConfigEntity,
|
||||
useDeleteMainModelsMutation,
|
||||
} from 'services/api/endpoints/models';
|
||||
|
||||
type ModelListItemProps = {
|
||||
modelKey: string;
|
||||
name: string;
|
||||
description: string | undefined;
|
||||
model: MainModelConfigEntity;
|
||||
isSelected: boolean;
|
||||
setSelectedModelId: (v: string | undefined) => void;
|
||||
};
|
||||
|
||||
export default function ModelListItem(props: ModelListItemProps) {
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const openModel = useAppSelector(
|
||||
(state: RootState) => state.system.openModel
|
||||
);
|
||||
|
||||
const { t } = useTranslation();
|
||||
const [deleteMainModel] = useDeleteMainModelsMutation();
|
||||
|
||||
const { t } = useTranslation();
|
||||
const { model, isSelected, setSelectedModelId } = props;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const handleSelectModel = useCallback(() => {
|
||||
setSelectedModelId(model.id);
|
||||
}, [model.id, setSelectedModelId]);
|
||||
|
||||
const { modelKey, name, description } = props;
|
||||
|
||||
const openModelHandler = () => {
|
||||
dispatch(setOpenModel(modelKey));
|
||||
};
|
||||
|
||||
const handleModelDelete = () => {
|
||||
const [base_model, _, model_name] = modelKey.split('/');
|
||||
deleteMainModel({
|
||||
base_model: base_model as BaseModelType,
|
||||
model_name: model_name,
|
||||
});
|
||||
dispatch(setOpenModel(null));
|
||||
};
|
||||
const handleModelDelete = useCallback(() => {
|
||||
deleteMainModel(model);
|
||||
setSelectedModelId(undefined);
|
||||
}, [deleteMainModel, model, setSelectedModelId]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
alignItems="center"
|
||||
p={2}
|
||||
borderRadius="base"
|
||||
sx={
|
||||
modelKey === openModel
|
||||
? {
|
||||
bg: mode('accent.200', 'accent.600')(colorMode),
|
||||
_hover: {
|
||||
bg: mode('accent.200', 'accent.600')(colorMode),
|
||||
},
|
||||
}
|
||||
: {
|
||||
_hover: {
|
||||
bg: mode('base.100', 'base.800')(colorMode),
|
||||
},
|
||||
}
|
||||
}
|
||||
>
|
||||
<Box onClick={openModelHandler} cursor="pointer">
|
||||
<Tooltip label={description} hasArrow placement="bottom">
|
||||
<Text fontWeight="600">{name}</Text>
|
||||
</Tooltip>
|
||||
</Box>
|
||||
<Spacer onClick={openModelHandler} cursor="pointer" />
|
||||
<Flex gap={2} alignItems="center">
|
||||
<Flex sx={{ gap: 2, alignItems: 'center', w: 'full' }}>
|
||||
<Flex
|
||||
as={IAIButton}
|
||||
isChecked={isSelected}
|
||||
sx={{
|
||||
p: 2,
|
||||
borderRadius: 'base',
|
||||
w: 'full',
|
||||
alignItems: 'center',
|
||||
bg: isSelected ? 'accent.200' : 'base.100',
|
||||
_hover: {
|
||||
bg: isSelected ? 'accent.250' : 'base.150',
|
||||
},
|
||||
_dark: {
|
||||
bg: isSelected ? 'accent.600' : 'base.850',
|
||||
_hover: {
|
||||
bg: isSelected ? 'accent.550' : 'base.800',
|
||||
},
|
||||
},
|
||||
}}
|
||||
onClick={handleSelectModel}
|
||||
>
|
||||
<Box cursor="pointer">
|
||||
<Tooltip label={model.description} hasArrow placement="bottom">
|
||||
<Text fontWeight="600">{model.model_name}</Text>
|
||||
</Tooltip>
|
||||
</Box>
|
||||
<Spacer onClick={handleSelectModel} cursor="pointer" />
|
||||
<IAIIconButton
|
||||
icon={<EditIcon />}
|
||||
icon={<FaEdit />}
|
||||
size="sm"
|
||||
onClick={openModelHandler}
|
||||
onClick={handleSelectModel}
|
||||
aria-label={t('accessibility.modifyConfig')}
|
||||
isDisabled={isBusy}
|
||||
variant="link"
|
||||
/>
|
||||
<IAIAlertDialog
|
||||
title={t('modelManager.deleteModel')}
|
||||
acceptCallback={handleModelDelete}
|
||||
acceptButtonText={t('modelManager.delete')}
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
icon={<DeleteIcon />}
|
||||
size="sm"
|
||||
aria-label={t('modelManager.deleteConfig')}
|
||||
isDisabled={isBusy}
|
||||
colorScheme="error"
|
||||
/>
|
||||
}
|
||||
>
|
||||
<Flex rowGap={4} flexDirection="column">
|
||||
<p style={{ fontWeight: 'bold' }}>{t('modelManager.deleteMsg1')}</p>
|
||||
<p>{t('modelManager.deleteMsg2')}</p>
|
||||
</Flex>
|
||||
</IAIAlertDialog>
|
||||
</Flex>
|
||||
<IAIAlertDialog
|
||||
title={t('modelManager.deleteModel')}
|
||||
acceptCallback={handleModelDelete}
|
||||
acceptButtonText={t('modelManager.delete')}
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
icon={<DeleteIcon />}
|
||||
aria-label={t('modelManager.deleteConfig')}
|
||||
isDisabled={isBusy}
|
||||
colorScheme="error"
|
||||
/>
|
||||
}
|
||||
>
|
||||
<Flex rowGap={4} flexDirection="column">
|
||||
<p style={{ fontWeight: 'bold' }}>{t('modelManager.deleteMsg1')}</p>
|
||||
<p>{t('modelManager.deleteMsg2')}</p>
|
||||
</Flex>
|
||||
</IAIAlertDialog>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
@ -3,7 +3,9 @@ import { cloneDeep } from 'lodash-es';
|
||||
import {
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
CheckpointModelConfig,
|
||||
ControlNetModelConfig,
|
||||
DiffusersModelConfig,
|
||||
LoRAModelConfig,
|
||||
MainModelConfig,
|
||||
MergeModelConfig,
|
||||
@ -14,7 +16,13 @@ import {
|
||||
import { ApiFullTagDescription, LIST_TAG, api } from '..';
|
||||
import { paths } from '../schema';
|
||||
|
||||
export type MainModelConfigEntity = MainModelConfig & { id: string };
|
||||
export type DiffusersModelConfigEntity = DiffusersModelConfig & { id: string };
|
||||
export type CheckpointModelConfigEntity = CheckpointModelConfig & {
|
||||
id: string;
|
||||
};
|
||||
export type MainModelConfigEntity =
|
||||
| DiffusersModelConfigEntity
|
||||
| CheckpointModelConfigEntity;
|
||||
|
||||
export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };
|
||||
|
||||
|
@ -42,11 +42,13 @@ export type ControlNetModelConfig =
|
||||
components['schemas']['ControlNetModelConfig'];
|
||||
export type TextualInversionModelConfig =
|
||||
components['schemas']['TextualInversionModelConfig'];
|
||||
export type MainModelConfig =
|
||||
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
|
||||
export type DiffusersModelConfig =
|
||||
| components['schemas']['StableDiffusion1ModelDiffusersConfig']
|
||||
| components['schemas']['StableDiffusion2ModelCheckpointConfig']
|
||||
| components['schemas']['StableDiffusion2ModelDiffusersConfig'];
|
||||
export type CheckpointModelConfig =
|
||||
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
|
||||
| components['schemas']['StableDiffusion2ModelCheckpointConfig'];
|
||||
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
|
||||
export type AnyModelConfig =
|
||||
| LoRAModelConfig
|
||||
| VaeModelConfig
|
||||
|
@ -2,7 +2,7 @@ import { defineStyle, defineStyleConfig } from '@chakra-ui/react';
|
||||
import { mode } from '@chakra-ui/theme-tools';
|
||||
|
||||
const subtext = defineStyle((props) => ({
|
||||
color: mode('colors.base.500', 'colors.base.400')(props),
|
||||
color: mode('base.500', 'base.400')(props),
|
||||
}));
|
||||
|
||||
export const textTheme = defineStyleConfig({
|
||||
|
Loading…
Reference in New Issue
Block a user