From 1e5ae9d986b11c15c2225ea3f55a4f2853e4be2f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 14 Jul 2023 19:22:37 +1000 Subject: [PATCH] feat(ui): refactor model manager ui - simplify UI logic in `ModelManagerPanel` components - fix up the types a bit to make it easier to select models - remove `openModel` state, just make it a useState since it is very local to model manager --- .../system/store/systemPersistDenylist.ts | 1 - .../src/features/system/store/systemSlice.ts | 6 - .../tabs/ModelManager/ModelManagerTab.tsx | 67 +--- .../subpanels/ModelManagerPanel.tsx | 83 +++-- .../ModelManagerPanel/CheckpointModelEdit.tsx | 138 ++++---- .../ModelManagerPanel/DiffusersModelEdit.tsx | 130 ++++---- .../subpanels/ModelManagerPanel/ModelList.tsx | 310 ++++++------------ .../ModelManagerPanel/ModelListItem.tsx | 161 ++++----- .../web/src/services/api/endpoints/models.ts | 10 +- .../frontend/web/src/services/api/types.d.ts | 8 +- .../frontend/web/src/theme/components/text.ts | 2 +- 11 files changed, 384 insertions(+), 532 deletions(-) diff --git a/invokeai/frontend/web/src/features/system/store/systemPersistDenylist.ts b/invokeai/frontend/web/src/features/system/store/systemPersistDenylist.ts index c2481c29df..bba279c4bc 100644 --- a/invokeai/frontend/web/src/features/system/store/systemPersistDenylist.ts +++ b/invokeai/frontend/web/src/features/system/store/systemPersistDenylist.ts @@ -13,7 +13,6 @@ export const systemPersistDenylist: (keyof SystemState)[] = [ 'isProcessing', 'totalIterations', 'totalSteps', - 'openModel', 'isCancelScheduled', 'progressImage', 'wereModelsReceived', diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts index 01c1344263..4d723378ba 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts @@ -46,7 +46,6 @@ export interface SystemState { toastQueue: UseToastOptions[]; searchFolder: string | null; foundModels: InvokeAI.FoundModel[] | null; - openModel: string | null; /** * The current progress image */ @@ -109,7 +108,6 @@ export const initialSystemState: SystemState = { toastQueue: [], searchFolder: null, foundModels: null, - openModel: null, progressImage: null, shouldAntialiasProgressImage: false, sessionId: null, @@ -164,9 +162,6 @@ export const systemSlice = createSlice({ ) => { state.foundModels = action.payload; }, - setOpenModel: (state, action: PayloadAction) => { - state.openModel = action.payload; - }, /** * A cancel was scheduled */ @@ -433,7 +428,6 @@ export const { clearToastQueue, setSearchFolder, setFoundModels, - setOpenModel, cancelScheduled, scheduledCancelAborted, cancelTypeChanged, diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/ModelManagerTab.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/ModelManagerTab.tsx index 7a56a3f651..70de375774 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/ModelManagerTab.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/ModelManagerTab.tsx @@ -1,14 +1,6 @@ -import { - Tab, - TabList, - TabPanel, - TabPanels, - Tabs, - useColorMode, -} from '@chakra-ui/react'; +import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@chakra-ui/react'; import i18n from 'i18n'; import { ReactNode, memo } from 'react'; -import { mode } from 'theme/util/mode'; import AddModelsPanel from './subpanels/AddModelsPanel'; import MergeModelsPanel from './subpanels/MergeModelsPanel'; import ModelManagerPanel from './subpanels/ModelManagerPanel'; @@ -21,7 +13,7 @@ type ModelManagerTabInfo = { content: ReactNode; }; -const modelManagerTabs: ModelManagerTabInfo[] = [ +const tabs: ModelManagerTabInfo[] = [ { id: 'modelManager', label: i18n.t('modelManager.modelManager'), @@ -40,50 +32,25 @@ const modelManagerTabs: ModelManagerTabInfo[] = [ ]; const ModelManagerTab = () => { - const { colorMode } = useColorMode(); - - const renderTabsList = () => { - const modelManagerTabListsToRender: ReactNode[] = []; - - modelManagerTabs.forEach((modelManagerTab) => { - modelManagerTabListsToRender.push( - {modelManagerTab.label} - ); - }); - - return ( - - {modelManagerTabListsToRender} - - ); - }; - - const renderTabPanels = () => { - const modelManagerTabPanelsToRender: ReactNode[] = []; - modelManagerTabs.forEach((modelManagerTab) => { - modelManagerTabPanelsToRender.push( - {modelManagerTab.content} - ); - }); - - return {modelManagerTabPanelsToRender}; - }; return ( - {renderTabsList()} - {renderTabPanels()} + + {tabs.map((tab) => ( + + {tab.label} + + ))} + + + {tabs.map((tab) => ( + {tab.content} + ))} + ); }; diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx index ddf0874b21..f681b79437 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx @@ -1,48 +1,59 @@ -import { Flex } from '@chakra-ui/react'; -import { RootState } from 'app/store/store'; -import { useAppSelector } from 'app/store/storeHooks'; +import { Flex, Text } from '@chakra-ui/react'; -import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import { useState } from 'react'; +import { + MainModelConfigEntity, + useGetMainModelsQuery, +} from 'services/api/endpoints/models'; import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit'; import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit'; import ModelList from './ModelManagerPanel/ModelList'; export default function ModelManagerPanel() { - const { data: mainModels } = useGetMainModelsQuery(); + const [selectedModelId, setSelectedModelId] = useState(); + const { model } = useGetMainModelsQuery(undefined, { + selectFromResult: ({ data }) => ({ + model: selectedModelId ? data?.entities[selectedModelId] : undefined, + }), + }); - const openModel = useAppSelector( - (state: RootState) => state.system.openModel - ); - - const renderModelEditTabs = () => { - if (!openModel || !mainModels) return; - - const openedModelData = mainModels['entities'][openModel]; - - if (openedModelData && openedModelData.model_format === 'diffusers') { - return ( - - ); - } - - if (openedModelData && openedModelData.model_format === 'checkpoint') { - return ( - - ); - } - }; return ( - - {renderModelEditTabs()} + + ); } + +type ModelEditProps = { + model: MainModelConfigEntity | undefined; +}; + +const ModelEdit = (props: ModelEditProps) => { + const { model } = props; + + if (model?.model_format === 'checkpoint') { + return ; + } + + if (model?.model_format === 'diffusers') { + return ; + } + + return ( + + Pick A Model To Edit + + ); +}; diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx index 91d668f1e2..2cdaea904f 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx @@ -1,21 +1,20 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; - import { Divider, Flex, Text } from '@chakra-ui/react'; - -// import { addNewModel } from 'app/socketio/actions'; import { useForm } from '@mantine/form'; -import { useTranslation } from 'react-i18next'; - -import IAIButton from 'common/components/IAIButton'; -import IAIMantineSelect from 'common/components/IAIMantineSelect'; - import { makeToast } from 'app/components/Toaster'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import IAIButton from 'common/components/IAIButton'; import IAIMantineTextInput from 'common/components/IAIMantineInput'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { selectIsBusy } from 'features/system/store/systemSelectors'; import { addToast } from 'features/system/store/systemSlice'; -import { useUpdateMainModelsMutation } from 'services/api/endpoints/models'; -import { components } from 'services/api/schema'; +import { useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + CheckpointModelConfigEntity, + useUpdateMainModelsMutation, +} from 'services/api/endpoints/models'; +import { CheckpointModelConfig } from 'services/api/types'; import ModelConvert from './ModelConvert'; const baseModelSelectData = [ @@ -29,36 +28,31 @@ const variantSelectData = [ { value: 'depth', label: 'Depth' }, ]; -export type CheckpointModelConfig = - | components['schemas']['StableDiffusion1ModelCheckpointConfig'] - | components['schemas']['StableDiffusion2ModelCheckpointConfig']; - type CheckpointModelEditProps = { - modelToEdit: string; - retrievedModel: CheckpointModelConfig; + model: CheckpointModelConfigEntity; }; export default function CheckpointModelEdit(props: CheckpointModelEditProps) { const isBusy = useAppSelector(selectIsBusy); - const { modelToEdit, retrievedModel } = props; + const { model } = props; - const [updateMainModel, { error, isLoading }] = useUpdateMainModelsMutation(); + const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation(); const dispatch = useAppDispatch(); const { t } = useTranslation(); const checkpointEditForm = useForm({ initialValues: { - model_name: retrievedModel.model_name ? retrievedModel.model_name : '', - base_model: retrievedModel.base_model, + model_name: model.model_name ? model.model_name : '', + base_model: model.base_model, model_type: 'main', - path: retrievedModel.path ? retrievedModel.path : '', - description: retrievedModel.description ? retrievedModel.description : '', + path: model.path ? model.path : '', + description: model.description ? model.description : '', model_format: 'checkpoint', - vae: retrievedModel.vae ? retrievedModel.vae : '', - config: retrievedModel.config ? retrievedModel.config : '', - variant: retrievedModel.variant, + vae: model.vae ? model.vae : '', + config: model.config ? model.config : '', + variant: model.variant, }, validate: { path: (value) => @@ -66,50 +60,60 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) { }, }); - const editModelFormSubmitHandler = (values: CheckpointModelConfig) => { - const responseBody = { - base_model: retrievedModel.base_model, - model_name: retrievedModel.model_name, - body: values, - }; - updateMainModel(responseBody) - .unwrap() - .then((payload) => { - checkpointEditForm.setValues(payload as CheckpointModelConfig); - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelUpdated'), - status: 'success', - }) - ) - ); - }) - .catch((error) => { - checkpointEditForm.reset(); - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelUpdateFailed'), - status: 'error', - }) - ) - ); - }); - }; + const editModelFormSubmitHandler = useCallback( + (values: CheckpointModelConfig) => { + const responseBody = { + base_model: model.base_model, + model_name: model.model_name, + body: values, + }; + updateMainModel(responseBody) + .unwrap() + .then((payload) => { + checkpointEditForm.setValues(payload as CheckpointModelConfig); + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelUpdated'), + status: 'success', + }) + ) + ); + }) + .catch((error) => { + checkpointEditForm.reset(); + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelUpdateFailed'), + status: 'error', + }) + ) + ); + }); + }, + [ + checkpointEditForm, + dispatch, + model.base_model, + model.model_name, + t, + updateMainModel, + ] + ); - return modelToEdit ? ( + return ( - {retrievedModel.model_name} + {model.model_name} - {MODEL_TYPE_MAP[retrievedModel.base_model]} Model + {MODEL_TYPE_MAP[model.base_model]} Model - + @@ -161,17 +165,5 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) { - ) : ( - - Pick A Model To Edit - ); } diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx index 65793c9d2c..5c1667b331 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx @@ -1,29 +1,23 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; - import { Divider, Flex, Text } from '@chakra-ui/react'; - -// import { addNewModel } from 'app/socketio/actions'; -import { useTranslation } from 'react-i18next'; - import { useForm } from '@mantine/form'; import { makeToast } from 'app/components/Toaster'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIButton from 'common/components/IAIButton'; import IAIMantineTextInput from 'common/components/IAIMantineInput'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; - import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { selectIsBusy } from 'features/system/store/systemSelectors'; import { addToast } from 'features/system/store/systemSlice'; -import { useUpdateMainModelsMutation } from 'services/api/endpoints/models'; -import { components } from 'services/api/schema'; - -export type DiffusersModelConfig = - | components['schemas']['StableDiffusion1ModelDiffusersConfig'] - | components['schemas']['StableDiffusion2ModelDiffusersConfig']; +import { useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + DiffusersModelConfigEntity, + useUpdateMainModelsMutation, +} from 'services/api/endpoints/models'; +import { DiffusersModelConfig } from 'services/api/types'; type DiffusersModelEditProps = { - modelToEdit: string; - retrievedModel: DiffusersModelConfig; + model: DiffusersModelConfigEntity; }; const baseModelSelectData = [ @@ -40,23 +34,23 @@ const variantSelectData = [ export default function DiffusersModelEdit(props: DiffusersModelEditProps) { const isBusy = useAppSelector(selectIsBusy); - const { retrievedModel, modelToEdit } = props; + const { model } = props; - const [updateMainModel, { isLoading, error }] = useUpdateMainModelsMutation(); + const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation(); const dispatch = useAppDispatch(); const { t } = useTranslation(); const diffusersEditForm = useForm({ initialValues: { - model_name: retrievedModel.model_name ? retrievedModel.model_name : '', - base_model: retrievedModel.base_model, + model_name: model.model_name ? model.model_name : '', + base_model: model.base_model, model_type: 'main', - path: retrievedModel.path ? retrievedModel.path : '', - description: retrievedModel.description ? retrievedModel.description : '', + path: model.path ? model.path : '', + description: model.description ? model.description : '', model_format: 'diffusers', - vae: retrievedModel.vae ? retrievedModel.vae : '', - variant: retrievedModel.variant, + vae: model.vae ? model.vae : '', + variant: model.variant, }, validate: { path: (value) => @@ -64,46 +58,56 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) { }, }); - const editModelFormSubmitHandler = (values: DiffusersModelConfig) => { - const responseBody = { - base_model: retrievedModel.base_model, - model_name: retrievedModel.model_name, - body: values, - }; - updateMainModel(responseBody) - .unwrap() - .then((payload) => { - diffusersEditForm.setValues(payload as DiffusersModelConfig); - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelUpdated'), - status: 'success', - }) - ) - ); - }) - .catch((error) => { - diffusersEditForm.reset(); - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelUpdateFailed'), - status: 'error', - }) - ) - ); - }); - }; + const editModelFormSubmitHandler = useCallback( + (values: DiffusersModelConfig) => { + const responseBody = { + base_model: model.base_model, + model_name: model.model_name, + body: values, + }; + updateMainModel(responseBody) + .unwrap() + .then((payload) => { + diffusersEditForm.setValues(payload as DiffusersModelConfig); + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelUpdated'), + status: 'success', + }) + ) + ); + }) + .catch((error) => { + diffusersEditForm.reset(); + dispatch( + addToast( + makeToast({ + title: t('modelManager.modelUpdateFailed'), + status: 'error', + }) + ) + ); + }); + }, + [ + diffusersEditForm, + dispatch, + model.base_model, + model.model_name, + t, + updateMainModel, + ] + ); - return modelToEdit ? ( + return ( - {retrievedModel.model_name} + {model.model_name} - {MODEL_TYPE_MAP[retrievedModel.base_model]} Model + {MODEL_TYPE_MAP[model.base_model]} Model @@ -146,17 +150,5 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) { - ) : ( - - Pick A Model To Edit - ); } diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx index 655b8c4de3..b0e44f7615 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx @@ -1,187 +1,46 @@ -import { Box, Flex, Spinner, Text, useColorMode } from '@chakra-ui/react'; +import { ButtonGroup, Flex, Text } from '@chakra-ui/react'; +import { EntityState } from '@reduxjs/toolkit'; import IAIButton from 'common/components/IAIButton'; import IAIInput from 'common/components/IAIInput'; - +import { forEach } from 'lodash-es'; +import type { ChangeEvent } from 'react'; +import { useCallback, useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + MainModelConfigEntity, + useGetMainModelsQuery, +} from 'services/api/endpoints/models'; import ModelListItem from './ModelListItem'; -import { useTranslation } from 'react-i18next'; +type ModelListProps = { + selectedModelId: string | undefined; + setSelectedModelId: (name: string | undefined) => void; +}; -import type { ChangeEvent, ReactNode } from 'react'; -import React, { useMemo, useState, useTransition } from 'react'; -import { useGetMainModelsQuery } from 'services/api/endpoints/models'; -import { mode } from 'theme/util/mode'; - -function ModelFilterButton({ - label, - isActive, - onClick, -}: { - label: string; - isActive: boolean; - onClick: () => void; -}) { - return ( - - {label} - - ); -} - -const ModelList = () => { - const { data: mainModels } = useGetMainModelsQuery(); - const { colorMode } = useColorMode(); - - const [renderModelList, setRenderModelList] = React.useState(false); - - React.useEffect(() => { - const timer = setTimeout(() => { - setRenderModelList(true); - }, 200); - - return () => clearTimeout(timer); - }, []); - - const [searchText, setSearchText] = useState(''); - const [isSelectedFilter, setIsSelectedFilter] = useState< - 'all' | 'checkpoint' | 'diffusers' - >('all'); - const [_, startTransition] = useTransition(); +type ModelFormat = 'all' | 'checkpoint' | 'diffusers'; +const ModelList = (props: ModelListProps) => { + const { selectedModelId, setSelectedModelId } = props; const { t } = useTranslation(); + const [nameFilter, setNameFilter] = useState(''); + const [modelFormatFilter, setModelFormatFilter] = + useState('all'); - const handleSearchFilter = (e: ChangeEvent) => { - startTransition(() => { - setSearchText(e.target.value); - }); - }; + const { filteredDiffusersModels } = useGetMainModelsQuery(undefined, { + selectFromResult: ({ data }) => ({ + filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter), + }), + }); - const renderModelListItems = useMemo(() => { - const ckptModelListItemsToRender: ReactNode[] = []; - const diffusersModelListItemsToRender: ReactNode[] = []; - const filteredModelListItemsToRender: ReactNode[] = []; - const localFilteredModelListItemsToRender: ReactNode[] = []; + const { filteredCheckpointModels } = useGetMainModelsQuery(undefined, { + selectFromResult: ({ data }) => ({ + filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter), + }), + }); - if (!mainModels) return; - - const modelList = mainModels.entities; - - Object.keys(modelList).forEach((model, i) => { - const modelInfo = modelList[model]; - - // If no model info found for a model, ignore it - if (!modelInfo) return; - - if ( - modelInfo.model_name.toLowerCase().includes(searchText.toLowerCase()) - ) { - filteredModelListItemsToRender.push( - - ); - if (modelInfo?.model_format === isSelectedFilter) { - localFilteredModelListItemsToRender.push( - - ); - } - } - - if (modelInfo?.model_format !== 'diffusers') { - ckptModelListItemsToRender.push( - - ); - } else { - diffusersModelListItemsToRender.push( - - ); - } - }); - - return searchText !== '' ? ( - isSelectedFilter === 'all' ? ( - {filteredModelListItemsToRender} - ) : ( - {localFilteredModelListItemsToRender} - ) - ) : ( - - {isSelectedFilter === 'all' && ( - <> - {diffusersModelListItemsToRender.length > 0 && ( - - - {t('modelManager.diffusersModels')} - - {diffusersModelListItemsToRender} - - )} - - {ckptModelListItemsToRender.length > 0 && ( - - - {t('modelManager.checkpointModels')} - - {ckptModelListItemsToRender} - - )} - - )} - - {isSelectedFilter === 'diffusers' && ( - - {diffusersModelListItemsToRender} - - )} - - {isSelectedFilter === 'checkpoint' && ( - - {ckptModelListItemsToRender} - - )} - - ); - }, [mainModels, searchText, t, isSelectedFilter, colorMode]); + const handleSearchFilter = useCallback((e: ChangeEvent) => { + setNameFilter(e.target.value); + }, []); return ( @@ -189,7 +48,6 @@ const ModelList = () => { onChange={handleSearchFilter} label={t('modelManager.search')} /> - { overflow="scroll" paddingInlineEnd={4} > - - setIsSelectedFilter('all')} - isActive={isSelectedFilter === 'all'} - /> - setIsSelectedFilter('diffusers')} - isActive={isSelectedFilter === 'diffusers'} - /> - setIsSelectedFilter('checkpoint')} - isActive={isSelectedFilter === 'checkpoint'} - /> - - - {renderModelList ? ( - renderModelListItems - ) : ( - + setModelFormatFilter('all')} + isChecked={modelFormatFilter === 'all'} + size="sm" > - + {t('modelManager.allModels')} + + setModelFormatFilter('diffusers')} + isChecked={modelFormatFilter === 'diffusers'} + > + {t('modelManager.diffusersModels')} + + setModelFormatFilter('checkpoint')} + isChecked={modelFormatFilter === 'checkpoint'} + > + {t('modelManager.checkpointModels')} + + + + {['all', 'diffusers'].includes(modelFormatFilter) && ( + + + Diffusers + + {filteredDiffusersModels.map((model) => ( + + ))} + + )} + {['all', 'checkpoint'].includes(modelFormatFilter) && ( + + + Checkpoint + + {filteredCheckpointModels.map((model) => ( + + ))} )} @@ -233,3 +115,27 @@ const ModelList = () => { }; export default ModelList; + +const modelsFilter = ( + data: EntityState | undefined, + model_format: ModelFormat, + nameFilter: string +) => { + const filteredModels: MainModelConfigEntity[] = []; + forEach(data?.entities, (model) => { + if (!model) { + return; + } + + const matchesFilter = model.model_name + .toLowerCase() + .includes(nameFilter.toLowerCase()); + + const matchesFormat = model.model_format === model_format; + + if (matchesFilter && matchesFormat) { + filteredModels.push(model); + } + }); + return filteredModels; +}; diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx index 5a90327aa3..84ed784d1e 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelListItem.tsx @@ -1,115 +1,96 @@ -import { DeleteIcon, EditIcon } from '@chakra-ui/icons'; -import { - Box, - Flex, - Spacer, - Text, - Tooltip, - useColorMode, -} from '@chakra-ui/react'; - -// import { deleteModel, requestModelChange } from 'app/socketio/actions'; -import { RootState } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { DeleteIcon } from '@chakra-ui/icons'; +import { Box, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react'; +import { useAppSelector } from 'app/store/storeHooks'; import IAIAlertDialog from 'common/components/IAIAlertDialog'; +import IAIButton from 'common/components/IAIButton'; import IAIIconButton from 'common/components/IAIIconButton'; import { selectIsBusy } from 'features/system/store/systemSelectors'; -import { setOpenModel } from 'features/system/store/systemSlice'; +import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import { useDeleteMainModelsMutation } from 'services/api/endpoints/models'; -import { BaseModelType } from 'services/api/types'; -import { mode } from 'theme/util/mode'; +import { FaEdit } from 'react-icons/fa'; +import { + MainModelConfigEntity, + useDeleteMainModelsMutation, +} from 'services/api/endpoints/models'; type ModelListItemProps = { - modelKey: string; - name: string; - description: string | undefined; + model: MainModelConfigEntity; + isSelected: boolean; + setSelectedModelId: (v: string | undefined) => void; }; export default function ModelListItem(props: ModelListItemProps) { const isBusy = useAppSelector(selectIsBusy); - - const { colorMode } = useColorMode(); - - const openModel = useAppSelector( - (state: RootState) => state.system.openModel - ); - + const { t } = useTranslation(); const [deleteMainModel] = useDeleteMainModelsMutation(); - const { t } = useTranslation(); + const { model, isSelected, setSelectedModelId } = props; - const dispatch = useAppDispatch(); + const handleSelectModel = useCallback(() => { + setSelectedModelId(model.id); + }, [model.id, setSelectedModelId]); - const { modelKey, name, description } = props; - - const openModelHandler = () => { - dispatch(setOpenModel(modelKey)); - }; - - const handleModelDelete = () => { - const [base_model, _, model_name] = modelKey.split('/'); - deleteMainModel({ - base_model: base_model as BaseModelType, - model_name: model_name, - }); - dispatch(setOpenModel(null)); - }; + const handleModelDelete = useCallback(() => { + deleteMainModel(model); + setSelectedModelId(undefined); + }, [deleteMainModel, model, setSelectedModelId]); return ( - - - - {name} - - - - + + + + + {model.model_name} + + + } + icon={} size="sm" - onClick={openModelHandler} + onClick={handleSelectModel} aria-label={t('accessibility.modifyConfig')} isDisabled={isBusy} + variant="link" /> - } - size="sm" - aria-label={t('modelManager.deleteConfig')} - isDisabled={isBusy} - colorScheme="error" - /> - } - > - -

{t('modelManager.deleteMsg1')}

-

{t('modelManager.deleteMsg2')}

-
-
+ } + aria-label={t('modelManager.deleteConfig')} + isDisabled={isBusy} + colorScheme="error" + /> + } + > + +

{t('modelManager.deleteMsg1')}

+

{t('modelManager.deleteMsg2')}

+
+
); } diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 1038a88c09..c86ad91100 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -3,7 +3,9 @@ import { cloneDeep } from 'lodash-es'; import { AnyModelConfig, BaseModelType, + CheckpointModelConfig, ControlNetModelConfig, + DiffusersModelConfig, LoRAModelConfig, MainModelConfig, MergeModelConfig, @@ -14,7 +16,13 @@ import { import { ApiFullTagDescription, LIST_TAG, api } from '..'; import { paths } from '../schema'; -export type MainModelConfigEntity = MainModelConfig & { id: string }; +export type DiffusersModelConfigEntity = DiffusersModelConfig & { id: string }; +export type CheckpointModelConfigEntity = CheckpointModelConfig & { + id: string; +}; +export type MainModelConfigEntity = + | DiffusersModelConfigEntity + | CheckpointModelConfigEntity; export type LoRAModelConfigEntity = LoRAModelConfig & { id: string }; diff --git a/invokeai/frontend/web/src/services/api/types.d.ts b/invokeai/frontend/web/src/services/api/types.d.ts index c2657701e7..fcbbd1a6a0 100644 --- a/invokeai/frontend/web/src/services/api/types.d.ts +++ b/invokeai/frontend/web/src/services/api/types.d.ts @@ -42,11 +42,13 @@ export type ControlNetModelConfig = components['schemas']['ControlNetModelConfig']; export type TextualInversionModelConfig = components['schemas']['TextualInversionModelConfig']; -export type MainModelConfig = - | components['schemas']['StableDiffusion1ModelCheckpointConfig'] +export type DiffusersModelConfig = | components['schemas']['StableDiffusion1ModelDiffusersConfig'] - | components['schemas']['StableDiffusion2ModelCheckpointConfig'] | components['schemas']['StableDiffusion2ModelDiffusersConfig']; +export type CheckpointModelConfig = + | components['schemas']['StableDiffusion1ModelCheckpointConfig'] + | components['schemas']['StableDiffusion2ModelCheckpointConfig']; +export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig; export type AnyModelConfig = | LoRAModelConfig | VaeModelConfig diff --git a/invokeai/frontend/web/src/theme/components/text.ts b/invokeai/frontend/web/src/theme/components/text.ts index 2404bf0594..cccbcf2391 100644 --- a/invokeai/frontend/web/src/theme/components/text.ts +++ b/invokeai/frontend/web/src/theme/components/text.ts @@ -2,7 +2,7 @@ import { defineStyle, defineStyleConfig } from '@chakra-ui/react'; import { mode } from '@chakra-ui/theme-tools'; const subtext = defineStyle((props) => ({ - color: mode('colors.base.500', 'colors.base.400')(props), + color: mode('base.500', 'base.400')(props), })); export const textTheme = defineStyleConfig({