mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix: Restore Model display and select functionality
This commit is contained in:
parent
b4b760d9e9
commit
e73f774920
@ -1,36 +1,16 @@
|
|||||||
import { Box, Flex, Heading, Spacer, Spinner, Text } from '@chakra-ui/react';
|
import { Box, Flex, Heading, Spacer, Spinner, Text } from '@chakra-ui/react';
|
||||||
import IAIInput from 'common/components/IAIInput';
|
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
import IAIInput from 'common/components/IAIInput';
|
||||||
|
|
||||||
import AddModel from './AddModel';
|
import AddModel from './AddModel';
|
||||||
import ModelListItem from './ModelListItem';
|
|
||||||
import MergeModels from './MergeModels';
|
import MergeModels from './MergeModels';
|
||||||
|
import ModelListItem from './ModelListItem';
|
||||||
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
|
||||||
import type { SystemState } from 'features/system/store/systemSlice';
|
|
||||||
import { isEqual, map } from 'lodash-es';
|
|
||||||
|
|
||||||
import React, { useMemo, useState, useTransition } from 'react';
|
|
||||||
import type { ChangeEvent, ReactNode } from 'react';
|
import type { ChangeEvent, ReactNode } from 'react';
|
||||||
|
import React, { useMemo, useState, useTransition } from 'react';
|
||||||
const modelListSelector = createSelector(
|
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||||
systemSelector,
|
|
||||||
(system: SystemState) => {
|
|
||||||
const models = map(system.model_list, (model, key) => {
|
|
||||||
return { name: key, ...model };
|
|
||||||
});
|
|
||||||
return models;
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
function ModelFilterButton({
|
function ModelFilterButton({
|
||||||
label,
|
label,
|
||||||
@ -58,7 +38,9 @@ function ModelFilterButton({
|
|||||||
}
|
}
|
||||||
|
|
||||||
const ModelList = () => {
|
const ModelList = () => {
|
||||||
const models = useAppSelector(modelListSelector);
|
const { data: pipelineModels } = useListModelsQuery({
|
||||||
|
model_type: 'pipeline',
|
||||||
|
});
|
||||||
|
|
||||||
const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
|
const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
|
||||||
|
|
||||||
@ -90,43 +72,49 @@ const ModelList = () => {
|
|||||||
const filteredModelListItemsToRender: ReactNode[] = [];
|
const filteredModelListItemsToRender: ReactNode[] = [];
|
||||||
const localFilteredModelListItemsToRender: ReactNode[] = [];
|
const localFilteredModelListItemsToRender: ReactNode[] = [];
|
||||||
|
|
||||||
models.forEach((model, i) => {
|
if (!pipelineModels) return;
|
||||||
if (model.name.toLowerCase().includes(searchText.toLowerCase())) {
|
|
||||||
|
const modelList = pipelineModels.entities;
|
||||||
|
|
||||||
|
Object.keys(modelList).forEach((model, i) => {
|
||||||
|
if (
|
||||||
|
modelList[model].name.toLowerCase().includes(searchText.toLowerCase())
|
||||||
|
) {
|
||||||
filteredModelListItemsToRender.push(
|
filteredModelListItemsToRender.push(
|
||||||
<ModelListItem
|
<ModelListItem
|
||||||
key={i}
|
key={i}
|
||||||
name={model.name}
|
modelKey={model}
|
||||||
status={model.status}
|
name={modelList[model].name}
|
||||||
description={model.description}
|
description={modelList[model].description}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
if (model.format === isSelectedFilter) {
|
if (modelList[model]?.model_format === isSelectedFilter) {
|
||||||
localFilteredModelListItemsToRender.push(
|
localFilteredModelListItemsToRender.push(
|
||||||
<ModelListItem
|
<ModelListItem
|
||||||
key={i}
|
key={i}
|
||||||
name={model.name}
|
modelKey={model}
|
||||||
status={model.status}
|
name={modelList[model].name}
|
||||||
description={model.description}
|
description={modelList[model].description}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (model.format !== 'diffusers') {
|
if (modelList[model]?.model_format !== 'diffusers') {
|
||||||
ckptModelListItemsToRender.push(
|
ckptModelListItemsToRender.push(
|
||||||
<ModelListItem
|
<ModelListItem
|
||||||
key={i}
|
key={i}
|
||||||
name={model.name}
|
modelKey={model}
|
||||||
status={model.status}
|
name={modelList[model].name}
|
||||||
description={model.description}
|
description={modelList[model].description}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
diffusersModelListItemsToRender.push(
|
diffusersModelListItemsToRender.push(
|
||||||
<ModelListItem
|
<ModelListItem
|
||||||
key={i}
|
key={i}
|
||||||
name={model.name}
|
modelKey={model}
|
||||||
status={model.status}
|
name={modelList[model].name}
|
||||||
description={model.description}
|
description={modelList[model].description}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -142,6 +130,23 @@ const ModelList = () => {
|
|||||||
<Flex flexDirection="column" rowGap={6}>
|
<Flex flexDirection="column" rowGap={6}>
|
||||||
{isSelectedFilter === 'all' && (
|
{isSelectedFilter === 'all' && (
|
||||||
<>
|
<>
|
||||||
|
<Box>
|
||||||
|
<Text
|
||||||
|
sx={{
|
||||||
|
fontWeight: '500',
|
||||||
|
py: 2,
|
||||||
|
px: 4,
|
||||||
|
mb: 4,
|
||||||
|
borderRadius: 'base',
|
||||||
|
width: 'max-content',
|
||||||
|
fontSize: 'sm',
|
||||||
|
bg: 'base.750',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{t('modelManager.diffusersModels')}
|
||||||
|
</Text>
|
||||||
|
{diffusersModelListItemsToRender}
|
||||||
|
</Box>
|
||||||
<Box>
|
<Box>
|
||||||
<Text
|
<Text
|
||||||
sx={{
|
sx={{
|
||||||
@ -160,40 +165,23 @@ const ModelList = () => {
|
|||||||
</Text>
|
</Text>
|
||||||
{ckptModelListItemsToRender}
|
{ckptModelListItemsToRender}
|
||||||
</Box>
|
</Box>
|
||||||
<Box>
|
|
||||||
<Text
|
|
||||||
sx={{
|
|
||||||
fontWeight: '500',
|
|
||||||
py: 2,
|
|
||||||
px: 4,
|
|
||||||
mb: 4,
|
|
||||||
borderRadius: 'base',
|
|
||||||
width: 'max-content',
|
|
||||||
fontSize: 'sm',
|
|
||||||
bg: 'base.750',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
{t('modelManager.diffusersModels')}
|
|
||||||
</Text>
|
|
||||||
{diffusersModelListItemsToRender}
|
|
||||||
</Box>
|
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{isSelectedFilter === 'ckpt' && (
|
|
||||||
<Flex flexDirection="column" marginTop={4}>
|
|
||||||
{ckptModelListItemsToRender}
|
|
||||||
</Flex>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{isSelectedFilter === 'diffusers' && (
|
{isSelectedFilter === 'diffusers' && (
|
||||||
<Flex flexDirection="column" marginTop={4}>
|
<Flex flexDirection="column" marginTop={4}>
|
||||||
{diffusersModelListItemsToRender}
|
{diffusersModelListItemsToRender}
|
||||||
</Flex>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
{isSelectedFilter === 'ckpt' && (
|
||||||
|
<Flex flexDirection="column" marginTop={4}>
|
||||||
|
{ckptModelListItemsToRender}
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
}, [models, searchText, t, isSelectedFilter]);
|
}, [pipelineModels, searchText, t, isSelectedFilter]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
|
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
|
||||||
@ -211,7 +199,7 @@ const ModelList = () => {
|
|||||||
|
|
||||||
<Flex
|
<Flex
|
||||||
flexDirection="column"
|
flexDirection="column"
|
||||||
gap={1}
|
gap={4}
|
||||||
maxHeight={window.innerHeight - 240}
|
maxHeight={window.innerHeight - 240}
|
||||||
overflow="scroll"
|
overflow="scroll"
|
||||||
paddingInlineEnd={4}
|
paddingInlineEnd={4}
|
||||||
@ -222,16 +210,16 @@ const ModelList = () => {
|
|||||||
onClick={() => setIsSelectedFilter('all')}
|
onClick={() => setIsSelectedFilter('all')}
|
||||||
isActive={isSelectedFilter === 'all'}
|
isActive={isSelectedFilter === 'all'}
|
||||||
/>
|
/>
|
||||||
<ModelFilterButton
|
|
||||||
label={t('modelManager.checkpointModels')}
|
|
||||||
onClick={() => setIsSelectedFilter('ckpt')}
|
|
||||||
isActive={isSelectedFilter === 'ckpt'}
|
|
||||||
/>
|
|
||||||
<ModelFilterButton
|
<ModelFilterButton
|
||||||
label={t('modelManager.diffusersModels')}
|
label={t('modelManager.diffusersModels')}
|
||||||
onClick={() => setIsSelectedFilter('diffusers')}
|
onClick={() => setIsSelectedFilter('diffusers')}
|
||||||
isActive={isSelectedFilter === 'diffusers'}
|
isActive={isSelectedFilter === 'diffusers'}
|
||||||
/>
|
/>
|
||||||
|
<ModelFilterButton
|
||||||
|
label={t('modelManager.checkpointModels')}
|
||||||
|
onClick={() => setIsSelectedFilter('ckpt')}
|
||||||
|
isActive={isSelectedFilter === 'ckpt'}
|
||||||
|
/>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
{renderModelList ? (
|
{renderModelList ? (
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import { DeleteIcon, EditIcon } from '@chakra-ui/icons';
|
import { DeleteIcon, EditIcon } from '@chakra-ui/icons';
|
||||||
import { Box, Button, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react';
|
import { Box, Button, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react';
|
||||||
import { ModelStatus } from 'app/types/invokeai';
|
|
||||||
// import { deleteModel, requestModelChange } from 'app/socketio/actions';
|
// import { deleteModel, requestModelChange } from 'app/socketio/actions';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
@ -10,9 +10,9 @@ import { setOpenModel } from 'features/system/store/systemSlice';
|
|||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
type ModelListItemProps = {
|
type ModelListItemProps = {
|
||||||
|
modelKey: string;
|
||||||
name: string;
|
name: string;
|
||||||
status: ModelStatus;
|
description: string | undefined;
|
||||||
description: string;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export default function ModelListItem(props: ModelListItemProps) {
|
export default function ModelListItem(props: ModelListItemProps) {
|
||||||
@ -28,18 +28,18 @@ export default function ModelListItem(props: ModelListItemProps) {
|
|||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const { name, status, description } = props;
|
const { modelKey, name, description } = props;
|
||||||
|
|
||||||
const handleChangeModel = () => {
|
const handleChangeModel = () => {
|
||||||
dispatch(requestModelChange(name));
|
dispatch(requestModelChange(modelKey));
|
||||||
};
|
};
|
||||||
|
|
||||||
const openModelHandler = () => {
|
const openModelHandler = () => {
|
||||||
dispatch(setOpenModel(name));
|
dispatch(setOpenModel(modelKey));
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleModelDelete = () => {
|
const handleModelDelete = () => {
|
||||||
dispatch(deleteModel(name));
|
dispatch(deleteModel(modelKey));
|
||||||
dispatch(setOpenModel(null));
|
dispatch(setOpenModel(null));
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -60,7 +60,7 @@ export default function ModelListItem(props: ModelListItemProps) {
|
|||||||
p={2}
|
p={2}
|
||||||
borderRadius="base"
|
borderRadius="base"
|
||||||
sx={
|
sx={
|
||||||
name === openModel
|
modelKey === openModel
|
||||||
? {
|
? {
|
||||||
bg: 'accent.750',
|
bg: 'accent.750',
|
||||||
_hover: {
|
_hover: {
|
||||||
@ -81,7 +81,6 @@ export default function ModelListItem(props: ModelListItemProps) {
|
|||||||
</Box>
|
</Box>
|
||||||
<Spacer onClick={openModelHandler} cursor="pointer" />
|
<Spacer onClick={openModelHandler} cursor="pointer" />
|
||||||
<Flex gap={2} alignItems="center">
|
<Flex gap={2} alignItems="center">
|
||||||
<Text color={statusTextColor()}>{status}</Text>
|
|
||||||
<Button
|
<Button
|
||||||
size="sm"
|
size="sm"
|
||||||
onClick={handleChangeModel}
|
onClick={handleChangeModel}
|
||||||
|
@ -17,6 +17,7 @@ import { useTranslation } from 'react-i18next';
|
|||||||
|
|
||||||
import type { ReactElement } from 'react';
|
import type { ReactElement } from 'react';
|
||||||
|
|
||||||
|
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||||
import CheckpointModelEdit from './CheckpointModelEdit';
|
import CheckpointModelEdit from './CheckpointModelEdit';
|
||||||
import DiffusersModelEdit from './DiffusersModelEdit';
|
import DiffusersModelEdit from './DiffusersModelEdit';
|
||||||
import ModelList from './ModelList';
|
import ModelList from './ModelList';
|
||||||
@ -34,9 +35,9 @@ export default function ModelManagerModal({
|
|||||||
onClose: onModelManagerModalClose,
|
onClose: onModelManagerModalClose,
|
||||||
} = useDisclosure();
|
} = useDisclosure();
|
||||||
|
|
||||||
const model_list = useAppSelector(
|
const { data: pipelineModels } = useListModelsQuery({
|
||||||
(state: RootState) => state.system.model_list
|
model_type: 'pipeline',
|
||||||
);
|
});
|
||||||
|
|
||||||
const openModel = useAppSelector(
|
const openModel = useAppSelector(
|
||||||
(state: RootState) => state.system.openModel
|
(state: RootState) => state.system.openModel
|
||||||
@ -61,7 +62,10 @@ export default function ModelManagerModal({
|
|||||||
<ModalBody>
|
<ModalBody>
|
||||||
<Flex width="100%" columnGap={8}>
|
<Flex width="100%" columnGap={8}>
|
||||||
<ModelList />
|
<ModelList />
|
||||||
{openModel && model_list[openModel]['format'] === 'diffusers' ? (
|
{openModel &&
|
||||||
|
pipelineModels &&
|
||||||
|
pipelineModels['entities'][openModel]['model_format'] ===
|
||||||
|
'diffusers' ? (
|
||||||
<DiffusersModelEdit />
|
<DiffusersModelEdit />
|
||||||
) : (
|
) : (
|
||||||
<CheckpointModelEdit />
|
<CheckpointModelEdit />
|
||||||
|
Loading…
Reference in New Issue
Block a user