fix: Restore Model display and select functionality

This commit is contained in:
blessedcoolant 2023-06-26 18:14:44 +12:00 committed by psychedelicious
parent b4b760d9e9
commit e73f774920
3 changed files with 75 additions and 84 deletions

View File

@ -1,36 +1,16 @@
import { Box, Flex, Heading, Spacer, Spinner, Text } from '@chakra-ui/react'; import { Box, Flex, Heading, Spacer, Spinner, Text } from '@chakra-ui/react';
import IAIInput from 'common/components/IAIInput';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import AddModel from './AddModel'; import AddModel from './AddModel';
import ModelListItem from './ModelListItem';
import MergeModels from './MergeModels'; import MergeModels from './MergeModels';
import ModelListItem from './ModelListItem';
import { useAppSelector } from 'app/store/storeHooks';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { createSelector } from '@reduxjs/toolkit';
import { systemSelector } from 'features/system/store/systemSelectors';
import type { SystemState } from 'features/system/store/systemSlice';
import { isEqual, map } from 'lodash-es';
import React, { useMemo, useState, useTransition } from 'react';
import type { ChangeEvent, ReactNode } from 'react'; import type { ChangeEvent, ReactNode } from 'react';
import React, { useMemo, useState, useTransition } from 'react';
const modelListSelector = createSelector( import { useListModelsQuery } from 'services/api/endpoints/models';
systemSelector,
(system: SystemState) => {
const models = map(system.model_list, (model, key) => {
return { name: key, ...model };
});
return models;
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
function ModelFilterButton({ function ModelFilterButton({
label, label,
@ -58,7 +38,9 @@ function ModelFilterButton({
} }
const ModelList = () => { const ModelList = () => {
const models = useAppSelector(modelListSelector); const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline',
});
const [renderModelList, setRenderModelList] = React.useState<boolean>(false); const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
@ -90,43 +72,49 @@ const ModelList = () => {
const filteredModelListItemsToRender: ReactNode[] = []; const filteredModelListItemsToRender: ReactNode[] = [];
const localFilteredModelListItemsToRender: ReactNode[] = []; const localFilteredModelListItemsToRender: ReactNode[] = [];
models.forEach((model, i) => { if (!pipelineModels) return;
if (model.name.toLowerCase().includes(searchText.toLowerCase())) {
const modelList = pipelineModels.entities;
Object.keys(modelList).forEach((model, i) => {
if (
modelList[model].name.toLowerCase().includes(searchText.toLowerCase())
) {
filteredModelListItemsToRender.push( filteredModelListItemsToRender.push(
<ModelListItem <ModelListItem
key={i} key={i}
name={model.name} modelKey={model}
status={model.status} name={modelList[model].name}
description={model.description} description={modelList[model].description}
/> />
); );
if (model.format === isSelectedFilter) { if (modelList[model]?.model_format === isSelectedFilter) {
localFilteredModelListItemsToRender.push( localFilteredModelListItemsToRender.push(
<ModelListItem <ModelListItem
key={i} key={i}
name={model.name} modelKey={model}
status={model.status} name={modelList[model].name}
description={model.description} description={modelList[model].description}
/> />
); );
} }
} }
if (model.format !== 'diffusers') { if (modelList[model]?.model_format !== 'diffusers') {
ckptModelListItemsToRender.push( ckptModelListItemsToRender.push(
<ModelListItem <ModelListItem
key={i} key={i}
name={model.name} modelKey={model}
status={model.status} name={modelList[model].name}
description={model.description} description={modelList[model].description}
/> />
); );
} else { } else {
diffusersModelListItemsToRender.push( diffusersModelListItemsToRender.push(
<ModelListItem <ModelListItem
key={i} key={i}
name={model.name} modelKey={model}
status={model.status} name={modelList[model].name}
description={model.description} description={modelList[model].description}
/> />
); );
} }
@ -142,6 +130,23 @@ const ModelList = () => {
<Flex flexDirection="column" rowGap={6}> <Flex flexDirection="column" rowGap={6}>
{isSelectedFilter === 'all' && ( {isSelectedFilter === 'all' && (
<> <>
<Box>
<Text
sx={{
fontWeight: '500',
py: 2,
px: 4,
mb: 4,
borderRadius: 'base',
width: 'max-content',
fontSize: 'sm',
bg: 'base.750',
}}
>
{t('modelManager.diffusersModels')}
</Text>
{diffusersModelListItemsToRender}
</Box>
<Box> <Box>
<Text <Text
sx={{ sx={{
@ -160,40 +165,23 @@ const ModelList = () => {
</Text> </Text>
{ckptModelListItemsToRender} {ckptModelListItemsToRender}
</Box> </Box>
<Box>
<Text
sx={{
fontWeight: '500',
py: 2,
px: 4,
mb: 4,
borderRadius: 'base',
width: 'max-content',
fontSize: 'sm',
bg: 'base.750',
}}
>
{t('modelManager.diffusersModels')}
</Text>
{diffusersModelListItemsToRender}
</Box>
</> </>
)} )}
{isSelectedFilter === 'ckpt' && (
<Flex flexDirection="column" marginTop={4}>
{ckptModelListItemsToRender}
</Flex>
)}
{isSelectedFilter === 'diffusers' && ( {isSelectedFilter === 'diffusers' && (
<Flex flexDirection="column" marginTop={4}> <Flex flexDirection="column" marginTop={4}>
{diffusersModelListItemsToRender} {diffusersModelListItemsToRender}
</Flex> </Flex>
)} )}
{isSelectedFilter === 'ckpt' && (
<Flex flexDirection="column" marginTop={4}>
{ckptModelListItemsToRender}
</Flex>
)}
</Flex> </Flex>
); );
}, [models, searchText, t, isSelectedFilter]); }, [pipelineModels, searchText, t, isSelectedFilter]);
return ( return (
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%"> <Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
@ -211,7 +199,7 @@ const ModelList = () => {
<Flex <Flex
flexDirection="column" flexDirection="column"
gap={1} gap={4}
maxHeight={window.innerHeight - 240} maxHeight={window.innerHeight - 240}
overflow="scroll" overflow="scroll"
paddingInlineEnd={4} paddingInlineEnd={4}
@ -222,16 +210,16 @@ const ModelList = () => {
onClick={() => setIsSelectedFilter('all')} onClick={() => setIsSelectedFilter('all')}
isActive={isSelectedFilter === 'all'} isActive={isSelectedFilter === 'all'}
/> />
<ModelFilterButton
label={t('modelManager.checkpointModels')}
onClick={() => setIsSelectedFilter('ckpt')}
isActive={isSelectedFilter === 'ckpt'}
/>
<ModelFilterButton <ModelFilterButton
label={t('modelManager.diffusersModels')} label={t('modelManager.diffusersModels')}
onClick={() => setIsSelectedFilter('diffusers')} onClick={() => setIsSelectedFilter('diffusers')}
isActive={isSelectedFilter === 'diffusers'} isActive={isSelectedFilter === 'diffusers'}
/> />
<ModelFilterButton
label={t('modelManager.checkpointModels')}
onClick={() => setIsSelectedFilter('ckpt')}
isActive={isSelectedFilter === 'ckpt'}
/>
</Flex> </Flex>
{renderModelList ? ( {renderModelList ? (

View File

@ -1,6 +1,6 @@
import { DeleteIcon, EditIcon } from '@chakra-ui/icons'; import { DeleteIcon, EditIcon } from '@chakra-ui/icons';
import { Box, Button, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react'; import { Box, Button, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react';
import { ModelStatus } from 'app/types/invokeai';
// import { deleteModel, requestModelChange } from 'app/socketio/actions'; // import { deleteModel, requestModelChange } from 'app/socketio/actions';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@ -10,9 +10,9 @@ import { setOpenModel } from 'features/system/store/systemSlice';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
type ModelListItemProps = { type ModelListItemProps = {
modelKey: string;
name: string; name: string;
status: ModelStatus; description: string | undefined;
description: string;
}; };
export default function ModelListItem(props: ModelListItemProps) { export default function ModelListItem(props: ModelListItemProps) {
@ -28,18 +28,18 @@ export default function ModelListItem(props: ModelListItemProps) {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { name, status, description } = props; const { modelKey, name, description } = props;
const handleChangeModel = () => { const handleChangeModel = () => {
dispatch(requestModelChange(name)); dispatch(requestModelChange(modelKey));
}; };
const openModelHandler = () => { const openModelHandler = () => {
dispatch(setOpenModel(name)); dispatch(setOpenModel(modelKey));
}; };
const handleModelDelete = () => { const handleModelDelete = () => {
dispatch(deleteModel(name)); dispatch(deleteModel(modelKey));
dispatch(setOpenModel(null)); dispatch(setOpenModel(null));
}; };
@ -60,7 +60,7 @@ export default function ModelListItem(props: ModelListItemProps) {
p={2} p={2}
borderRadius="base" borderRadius="base"
sx={ sx={
name === openModel modelKey === openModel
? { ? {
bg: 'accent.750', bg: 'accent.750',
_hover: { _hover: {
@ -81,7 +81,6 @@ export default function ModelListItem(props: ModelListItemProps) {
</Box> </Box>
<Spacer onClick={openModelHandler} cursor="pointer" /> <Spacer onClick={openModelHandler} cursor="pointer" />
<Flex gap={2} alignItems="center"> <Flex gap={2} alignItems="center">
<Text color={statusTextColor()}>{status}</Text>
<Button <Button
size="sm" size="sm"
onClick={handleChangeModel} onClick={handleChangeModel}

View File

@ -17,6 +17,7 @@ import { useTranslation } from 'react-i18next';
import type { ReactElement } from 'react'; import type { ReactElement } from 'react';
import { useListModelsQuery } from 'services/api/endpoints/models';
import CheckpointModelEdit from './CheckpointModelEdit'; import CheckpointModelEdit from './CheckpointModelEdit';
import DiffusersModelEdit from './DiffusersModelEdit'; import DiffusersModelEdit from './DiffusersModelEdit';
import ModelList from './ModelList'; import ModelList from './ModelList';
@ -34,9 +35,9 @@ export default function ModelManagerModal({
onClose: onModelManagerModalClose, onClose: onModelManagerModalClose,
} = useDisclosure(); } = useDisclosure();
const model_list = useAppSelector( const { data: pipelineModels } = useListModelsQuery({
(state: RootState) => state.system.model_list model_type: 'pipeline',
); });
const openModel = useAppSelector( const openModel = useAppSelector(
(state: RootState) => state.system.openModel (state: RootState) => state.system.openModel
@ -61,7 +62,10 @@ export default function ModelManagerModal({
<ModalBody> <ModalBody>
<Flex width="100%" columnGap={8}> <Flex width="100%" columnGap={8}>
<ModelList /> <ModelList />
{openModel && model_list[openModel]['format'] === 'diffusers' ? ( {openModel &&
pipelineModels &&
pipelineModels['entities'][openModel]['model_format'] ===
'diffusers' ? (
<DiffusersModelEdit /> <DiffusersModelEdit />
) : ( ) : (
<CheckpointModelEdit /> <CheckpointModelEdit />