feat(ui): model manager UI tweaks

- Move image display to left
- Move description into model header
- Move model edit & convert buttons to top right of model header
- Tweak styles for model display component
This commit is contained in:
psychedelicious 2024-03-07 13:53:39 +11:00
parent ad70cdfe87
commit ed4e8624dd
7 changed files with 109 additions and 134 deletions

View File

@ -7,7 +7,7 @@ import { Model } from './ModelPanel/Model';
export const ModelPane = () => { export const ModelPane = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
return ( return (
<Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full"> <Box layerStyle="first" p={4} borderRadius="base" w="50%" h="full">
{selectedModelKey ? <Model key={selectedModelKey} /> : <InstallModels />} {selectedModelKey ? <Model key={selectedModelKey} /> : <InstallModels />}
</Box> </Box>
); );

View File

@ -1,4 +1,4 @@
import { Box, Button, IconButton, Image } from '@invoke-ai/ui-library'; import { Box, Button, Flex, Icon, IconButton, Image, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { typedMemo } from 'common/util/typedMemo'; import { typedMemo } from 'common/util/typedMemo';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
@ -95,7 +95,7 @@ const ModelImageUpload = ({ model_key, model_image }: Props) => {
if (image) { if (image) {
return ( return (
<Box position="relative"> <Box position="relative" flexShrink={0}>
<Image <Image
src={image} src={image}
objectFit="cover" objectFit="cover"
@ -107,15 +107,14 @@ const ModelImageUpload = ({ model_key, model_image }: Props) => {
/> />
<IconButton <IconButton
position="absolute" position="absolute"
top="1" insetInlineEnd={0}
right="1" insetBlockStart={0}
onClick={handleResetImage} onClick={handleResetImage}
aria-label={t('modelManager.deleteModelImage')} aria-label={t('modelManager.deleteModelImage')}
tooltip={t('modelManager.deleteModelImage')} tooltip={t('modelManager.deleteModelImage')}
icon={<PiArrowCounterClockwiseBold size={16} />} icon={<PiArrowCounterClockwiseBold />}
size="sm" size="md"
variant="link" variant="ghost"
_hover={{ color: 'base.100' }}
/> />
</Box> </Box>
); );
@ -123,9 +122,21 @@ const ModelImageUpload = ({ model_key, model_image }: Props) => {
return ( return (
<> <>
<Button leftIcon={<PiUploadSimpleBold />} {...getRootProps()} pointerEvents="auto"> <Tooltip label={t('modelManager.uploadImage')}>
{t('modelManager.uploadImage')} <Flex
</Button> as={Button}
w={100}
h={100}
opacity={0.3}
borderRadius="base"
alignItems="center"
justifyContent="center"
flexShrink={0}
{...getRootProps()}
>
<Icon as={PiUploadSimpleBold} w={16} h={16} />
</Flex>
</Tooltip>
<input {...getInputProps()} /> <input {...getInputProps()} />
</> </>
); );

View File

@ -1,12 +1,12 @@
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs, Text } from '@invoke-ai/ui-library'; import { Flex, Heading, Text } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query'; import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton';
import { ModelEditButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelEditButton';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery } from 'services/api/endpoints/models'; import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import ModelImageUpload from './Fields/ModelImageUpload'; import ModelImageUpload from './Fields/ModelImageUpload';
import { ModelMetadata } from './Metadata/ModelMetadata';
import { ModelAttrView } from './ModelAttrView';
import { ModelEdit } from './ModelEdit'; import { ModelEdit } from './ModelEdit';
import { ModelView } from './ModelView'; import { ModelView } from './ModelView';
@ -25,38 +25,28 @@ export const Model = () => {
} }
return ( return (
<> <Flex flexDir="column" gap={4}>
<Flex alignItems="center" justifyContent="space-between" gap="4" paddingRight="5"> <Flex alignItems="flex-start" gap={4}>
<Flex flexDir="column" gap={1} p={2}> <ModelImageUpload model_key={selectedModelKey} model_image={data.cover_image} />
<Heading as="h2" fontSize="lg"> <Flex flexDir="column" gap={1} flexGrow={1}>
{data.name} <Flex gap={2} position="relative">
</Heading> <Heading as="h2" fontSize="lg" w="full">
{data.name}
</Heading>
<Flex position="absolute" gap={2} right={0} top={0}>
<ModelEditButton />
<ModelConvertButton modelKey={selectedModelKey} />
</Flex>
</Flex>
{data.source && ( {data.source && (
<Text variant="subtext"> <Text variant="subtext">
{t('modelManager.source')}: {data?.source} {t('modelManager.source')}: {data?.source}
</Text> </Text>
)} )}
<Box mt="4"> <Text noOfLines={3}>{data.description}</Text>
<ModelAttrView label="Description" value={data.description} />
</Box>
</Flex> </Flex>
<ModelImageUpload model_key={selectedModelKey} model_image={data.cover_image} />
</Flex> </Flex>
{selectedModelMode === 'view' ? <ModelView /> : <ModelEdit />}
<Tabs mt="4" h="100%"> </Flex>
<TabList>
<Tab>{t('modelManager.settings')}</Tab>
<Tab>{t('modelManager.metadata')}</Tab>
</TabList>
<TabPanels h="100%">
<TabPanel>{selectedModelMode === 'view' ? <ModelView /> : <ModelEdit />}</TabPanel>
<TabPanel h="full">
<ModelMetadata />
</TabPanel>
</TabPanels>
</Tabs>
</>
); );
}; };

View File

@ -8,42 +8,47 @@ import {
UnorderedList, UnorderedList,
useDisclosure, useDisclosure,
} from '@invoke-ai/ui-library'; } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useConvertModelMutation } from 'services/api/endpoints/models'; import { useConvertModelMutation, useGetModelConfigQuery } from 'services/api/endpoints/models';
import type { CheckpointModelConfig } from 'services/api/types';
interface ModelConvertProps { interface ModelConvertProps {
model: CheckpointModelConfig; modelKey: string | null;
} }
export const ModelConvert = (props: ModelConvertProps) => { export const ModelConvertButton = (props: ModelConvertProps) => {
const { model } = props; const { modelKey } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { data } = useGetModelConfigQuery(modelKey ?? skipToken);
const [convertModel, { isLoading }] = useConvertModelMutation(); const [convertModel, { isLoading }] = useConvertModelMutation();
const { isOpen, onOpen, onClose } = useDisclosure(); const { isOpen, onOpen, onClose } = useDisclosure();
const modelConvertHandler = useCallback(() => { const modelConvertHandler = useCallback(() => {
if (!data || isLoading) {
return;
}
dispatch( dispatch(
addToast( addToast(
makeToast({ makeToast({
title: `${t('modelManager.convertingModelBegin')}: ${model.name}`, title: `${t('modelManager.convertingModelBegin')}: ${data?.name}`,
status: 'info', status: 'info',
}) })
) )
); );
convertModel(model.key) convertModel(data?.key)
.unwrap() .unwrap()
.then(() => { .then(() => {
dispatch( dispatch(
addToast( addToast(
makeToast({ makeToast({
title: `${t('modelManager.modelConverted')}: ${model.name}`, title: `${t('modelManager.modelConverted')}: ${data?.name}`,
status: 'success', status: 'success',
}) })
) )
@ -53,13 +58,13 @@ export const ModelConvert = (props: ModelConvertProps) => {
dispatch( dispatch(
addToast( addToast(
makeToast({ makeToast({
title: `${t('modelManager.modelConversionFailed')}: ${model.name}`, title: `${t('modelManager.modelConversionFailed')}: ${data?.name}`,
status: 'error', status: 'error',
}) })
) )
); );
}); });
}, [convertModel, dispatch, model.key, model.name, t]); }, [data, isLoading, dispatch, t, convertModel]);
return ( return (
<> <>
@ -69,11 +74,12 @@ export const ModelConvert = (props: ModelConvertProps) => {
aria-label={t('modelManager.convertToDiffusers')} aria-label={t('modelManager.convertToDiffusers')}
className=" modal-close-btn" className=" modal-close-btn"
isLoading={isLoading} isLoading={isLoading}
flexShrink={0}
> >
🧨 {t('modelManager.convertToDiffusers')} 🧨 {t('modelManager.convertToDiffusers')}
</Button> </Button>
<ConfirmationAlertDialog <ConfirmationAlertDialog
title={`${t('modelManager.convert')} ${model.name}`} title={`${t('modelManager.convert')} ${data?.name}`}
acceptCallback={modelConvertHandler} acceptCallback={modelConvertHandler}
acceptButtonText={`${t('modelManager.convert')}`} acceptButtonText={`${t('modelManager.convert')}`}
isOpen={isOpen} isOpen={isOpen}

View File

@ -104,7 +104,7 @@ export const ModelEdit = () => {
<form onSubmit={handleSubmit(onSubmit)}> <form onSubmit={handleSubmit(onSubmit)}>
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center"> <Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}> <FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}>
<FormLabel hidden={true}>{t('modelManager.modelName')}</FormLabel> <FormLabel>{t('modelManager.modelName')}</FormLabel>
<Input <Input
{...register('name', { {...register('name', {
validate: (value) => (value && value.trim().length > 3) || 'Must be at least 3 characters', validate: (value) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
@ -132,7 +132,7 @@ export const ModelEdit = () => {
<Flex gap="4" alignItems="center"> <Flex gap="4" alignItems="center">
<FormControl flexDir="column" alignItems="flex-start" gap={1}> <FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.description')}</FormLabel> <FormLabel>{t('modelManager.description')}</FormLabel>
<Textarea fontSize="md" resize="none" {...register('description')} /> <Textarea fontSize="md" {...register('description')} />
</FormControl> </FormControl>
</Flex> </Flex>
<Heading as="h3" fontSize="md" mt="4"> <Heading as="h3" fontSize="md" mt="4">

View File

@ -0,0 +1,21 @@
import { Button } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { IoPencil } from 'react-icons/io5';
export const ModelEditButton = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const handleEditModel = useCallback(() => {
dispatch(setSelectedModelMode('edit'));
}, [dispatch]);
return (
<Button size="sm" leftIcon={<IoPencil />} colorScheme="invokeYellow" onClick={handleEditModel} flexShrink={0}>
{t('modelManager.edit')}
</Button>
);
};

View File

@ -1,127 +1,74 @@
import { Box, Button, Flex, Text } from '@invoke-ai/ui-library'; import { Box, Flex, Text } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query'; import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice'; import { TriggerPhrases } from 'features/modelManagerV2/subpanels/ModelPanel/Metadata/TriggerPhrases';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { IoPencil } from 'react-icons/io5';
import { useGetModelConfigQuery } from 'services/api/endpoints/models'; import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import type {
CheckpointModelConfig,
ControlNetModelConfig,
DiffusersModelConfig,
IPAdapterModelConfig,
LoRAModelConfig,
T2IAdapterModelConfig,
TextualInversionModelConfig,
VAEModelConfig,
} from 'services/api/types';
import { DefaultSettings } from './DefaultSettings'; import { DefaultSettings } from './DefaultSettings';
import { ModelAttrView } from './ModelAttrView'; import { ModelAttrView } from './ModelAttrView';
import { ModelConvert } from './ModelConvert';
export const ModelView = () => { export const ModelView = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken); const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const modelData = useMemo(() => {
if (!data) {
return null;
}
const modelFormat = data.format;
const modelType = data.type;
if (modelType === 'main') {
if (modelFormat === 'diffusers') {
return data as DiffusersModelConfig;
} else if (modelFormat === 'checkpoint') {
return data as CheckpointModelConfig;
}
}
switch (modelType) {
case 'lora':
return data as LoRAModelConfig;
case 'embedding':
return data as TextualInversionModelConfig;
case 't2i_adapter':
return data as T2IAdapterModelConfig;
case 'ip_adapter':
return data as IPAdapterModelConfig;
case 'controlnet':
return data as ControlNetModelConfig;
case 'vae':
return data as VAEModelConfig;
default:
return null;
}
}, [data]);
const handleEditModel = useCallback(() => {
dispatch(setSelectedModelMode('edit'));
}, [dispatch]);
if (isLoading) { if (isLoading) {
return <Text>{t('common.loading')}</Text>; return <Text>{t('common.loading')}</Text>;
} }
if (!modelData) { if (!data) {
return <Text>{t('common.somethingWentWrong')}</Text>; return <Text>{t('common.somethingWentWrong')}</Text>;
} }
return ( return (
<Flex flexDir="column" h="full" gap="2"> <Flex flexDir="column" h="full" gap={4}>
<Box layerStyle="second" borderRadius="base" p={3}> <Box layerStyle="second" borderRadius="base" p={4}>
<Flex gap="2" justifyContent="flex-end" w="full"> <Flex flexDir="column" gap={4}>
<Button size="sm" leftIcon={<IoPencil />} colorScheme="invokeYellow" onClick={handleEditModel}>
{t('modelManager.edit')}
</Button>
{modelData.type === 'main' && modelData.format === 'checkpoint' && <ModelConvert model={modelData} />}
</Flex>
<Flex flexDir="column" gap={3}>
<Flex gap={2}> <Flex gap={2}>
<ModelAttrView label={t('modelManager.baseModel')} value={modelData.base} /> <ModelAttrView label={t('modelManager.baseModel')} value={data.base} />
<ModelAttrView label={t('modelManager.modelType')} value={modelData.type} /> <ModelAttrView label={t('modelManager.modelType')} value={data.type} />
</Flex> </Flex>
<Flex gap={2}> <Flex gap={2}>
<ModelAttrView label={t('common.format')} value={modelData.format} /> <ModelAttrView label={t('common.format')} value={data.format} />
<ModelAttrView label={t('modelManager.path')} value={modelData.path} /> <ModelAttrView label={t('modelManager.path')} value={data.path} />
</Flex> </Flex>
{modelData.type === 'main' && modelData.format === 'diffusers' && modelData.repo_variant && ( {data.type === 'main' && data.format === 'diffusers' && data.repo_variant && (
<Flex gap={2}> <Flex gap={2}>
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} /> <ModelAttrView label={t('modelManager.repoVariant')} value={data.repo_variant} />
</Flex> </Flex>
)} )}
{modelData.type === 'main' && modelData.format === 'checkpoint' && ( {data.type === 'main' && data.format === 'checkpoint' && (
<> <>
<Flex gap={2}> <Flex gap={2}>
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config_path} /> <ModelAttrView label={t('modelManager.pathToConfig')} value={data.config_path} />
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} /> <ModelAttrView label={t('modelManager.variant')} value={data.variant} />
</Flex> </Flex>
<Flex gap={2}> <Flex gap={2}>
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} /> <ModelAttrView label={t('modelManager.predictionType')} value={data.prediction_type} />
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} /> <ModelAttrView label={t('modelManager.upcastAttention')} value={`${data.upcast_attention}`} />
</Flex> </Flex>
</> </>
)} )}
{modelData.type === 'ip_adapter' && ( {data.type === 'ip_adapter' && (
<Flex gap={2}> <Flex gap={2}>
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelData.image_encoder_model_id} /> <ModelAttrView label={t('modelManager.imageEncoderModelId')} value={data.image_encoder_model_id} />
</Flex> </Flex>
)} )}
</Flex> </Flex>
</Box> </Box>
{modelData.type === 'main' && ( {data.type === 'main' && (
<Box layerStyle="second" borderRadius="base" p={3}> <Box layerStyle="second" borderRadius="base" p={4}>
<DefaultSettings /> <DefaultSettings />
</Box> </Box>
)} )}
{data.type === 'lora' && (
<Box layerStyle="second" borderRadius="base" p={4}>
<TriggerPhrases />
</Box>
)}
</Flex> </Flex>
); );
}; };