UI in MM to create trigger phrases

This commit is contained in:
Mary Hipp 2024-02-27 15:33:11 -05:00
parent 4418c118db
commit efb5f2d202
8 changed files with 1791 additions and 1045 deletions

View File

@ -78,6 +78,7 @@
"aboutDesc": "Using Invoke for work? Check out:", "aboutDesc": "Using Invoke for work? Check out:",
"aboutHeading": "Own Your Creative Power", "aboutHeading": "Own Your Creative Power",
"accept": "Accept", "accept": "Accept",
"add": "Add",
"advanced": "Advanced", "advanced": "Advanced",
"advancedOptions": "Advanced Options", "advancedOptions": "Advanced Options",
"ai": "ai", "ai": "ai",
@ -768,6 +769,7 @@
"mergedModelName": "Merged Model Name", "mergedModelName": "Merged Model Name",
"mergedModelSaveLocation": "Save Location", "mergedModelSaveLocation": "Save Location",
"mergeModels": "Merge Models", "mergeModels": "Merge Models",
"metadata": "Metadata",
"model": "Model", "model": "Model",
"modelAdded": "Model Added", "modelAdded": "Model Added",
"modelConversionFailed": "Model Conversion Failed", "modelConversionFailed": "Model Conversion Failed",
@ -839,6 +841,8 @@
"statusConverting": "Converting", "statusConverting": "Converting",
"syncModels": "Sync Models", "syncModels": "Sync Models",
"syncModelsDesc": "If your models are out of sync with the backend, you can refresh them up using this option. This is generally handy in cases where you add models to the InvokeAI root folder or autoimport directory after the application has booted.", "syncModelsDesc": "If your models are out of sync with the backend, you can refresh them up using this option. This is generally handy in cases where you add models to the InvokeAI root folder or autoimport directory after the application has booted.",
"triggerPhrases": "Trigger Phrases",
"typePhraseHere": "Type phrase here",
"upcastAttention": "Upcast Attention", "upcastAttention": "Upcast Attention",
"updateModel": "Update Model", "updateModel": "Update Model",
"useCustomConfig": "Use Custom Config", "useCustomConfig": "Use Custom Config",

View File

@ -8,7 +8,7 @@ export const ModelPane = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
return ( return (
<Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full"> <Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full">
{selectedModelKey ? <Model /> : <ImportModels />} {selectedModelKey ? <Model key={selectedModelKey} /> : <ImportModels />}
</Box> </Box>
); );
}; };

View File

@ -0,0 +1,22 @@
import { Flex, Box } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from '../../../../../app/store/storeHooks';
import { useGetModelMetadataQuery } from '../../../../../services/api/endpoints/models';
import DataViewer from '../../../../gallery/components/ImageMetadataViewer/DataViewer';
import { TriggerPhrases } from './TriggerPhrases';
export const ModelMetadata = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
return (
<>
<Flex flexDir="column" height="full" gap="3">
<Box layerStyle="second" borderRadius="base" p={3}>
<TriggerPhrases />
</Box>
<DataViewer label="metadata" data={metadata || {}} />
</Flex>
</>
);
};

View File

@ -0,0 +1,106 @@
import {
Button,
Flex,
FormControl,
FormErrorMessage,
Input,
Tag,
TagCloseButton,
TagLabel,
} from '@invoke-ai/ui-library';
import { useState, useMemo, useCallback } from 'react';
import type { ChangeEvent } from 'react';
import { ModelListHeader } from '../../ModelManagerPanel/ModelListHeader';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from '../../../../../app/store/storeHooks';
import { useGetModelMetadataQuery, useUpdateModelMetadataMutation } from '../../../../../services/api/endpoints/models';
import { useTranslation } from 'react-i18next';
export const TriggerPhrases = () => {
const { t } = useTranslation();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
const [phrase, setPhrase] = useState('');
const [editModelMetadata, { isLoading }] = useUpdateModelMetadataMutation();
const handlePhraseChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setPhrase(e.target.value);
}, []);
const triggerPhrases = useMemo(() => {
return metadata?.trigger_phrases || [];
}, [metadata?.trigger_phrases]);
const errors = useMemo(() => {
const errors = [];
if (phrase.length && triggerPhrases.includes(phrase)) {
errors.push('Phrase is already in list');
}
return errors;
}, [phrase, triggerPhrases]);
const addTriggerPhrase = useCallback(async () => {
if (!selectedModelKey) {
return;
}
if (!phrase.length || triggerPhrases.includes(phrase)) {
return;
}
await editModelMetadata({
key: selectedModelKey,
body: { trigger_phrases: [...triggerPhrases, phrase] },
}).unwrap();
setPhrase('');
}, [editModelMetadata, selectedModelKey, phrase, triggerPhrases]);
const removeTriggerPhrase = useCallback(
async (phraseToRemove: string) => {
if (!selectedModelKey) {
return;
}
const filteredPhrases = triggerPhrases.filter((p) => p !== phraseToRemove);
await editModelMetadata({ key: selectedModelKey, body: { trigger_phrases: filteredPhrases } }).unwrap();
},
[editModelMetadata, selectedModelKey, triggerPhrases]
);
return (
<Flex flexDir="column" w="full" gap="5">
<ModelListHeader title={t('modelManager.triggerPhrases')} />
<form>
<FormControl w="full" isInvalid={Boolean(errors.length)}>
<Flex flexDir="column" w="full">
<Flex gap="3" alignItems="center" w="full">
<Input value={phrase} onChange={handlePhraseChange} placeholder={t('modelManager.typePhraseHere')} />
<Button
type="submit"
onClick={addTriggerPhrase}
isDisabled={Boolean(errors.length)}
isLoading={isLoading}
>
{t('common.add')}
</Button>
</Flex>
{!!errors.length && errors.map((error) => <FormErrorMessage key={error}>{error}</FormErrorMessage>)}
</Flex>
</FormControl>
</form>
<Flex gap="4" flexWrap="wrap" mt="3" mb="3">
{triggerPhrases.map((phrase, index) => (
<Tag size="md" key={index}>
<TagLabel>{phrase}</TagLabel>
<TagCloseButton onClick={removeTriggerPhrase.bind(null, phrase)} isDisabled={isLoading} />
</Tag>
))}
</Flex>
</Flex>
);
};

View File

@ -2,8 +2,57 @@ import { useAppSelector } from 'app/store/storeHooks';
import { ModelEdit } from './ModelEdit'; import { ModelEdit } from './ModelEdit';
import { ModelView } from './ModelView'; import { ModelView } from './ModelView';
import { Tabs, TabList, Tab, TabPanels, TabPanel, Flex, Heading, Text, Box } from '@invoke-ai/ui-library';
import { ModelMetadata } from './Metadata/ModelMetadata';
import { skipToken } from '@reduxjs/toolkit/query';
import { useGetModelConfigQuery } from '../../../../services/api/endpoints/models';
import { ModelAttrView } from './ModelAttrView';
import { useTranslation } from 'react-i18next';
export const Model = () => { export const Model = () => {
const { t } = useTranslation();
const selectedModelMode = useAppSelector((s) => s.modelmanagerV2.selectedModelMode); const selectedModelMode = useAppSelector((s) => s.modelmanagerV2.selectedModelMode);
return selectedModelMode === 'view' ? <ModelView /> : <ModelEdit />; const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
if (isLoading) {
return <Text>{t('common.loading')}</Text>;
}
if (!data) {
return <Text>{t('common.somethingWentWrong')}</Text>;
}
return (
<>
<Flex flexDir="column" gap={1} p={2}>
<Heading as="h2" fontSize="lg">
{data.name}
</Heading>
{data.source && (
<Text variant="subtext">
{t('modelManager.source')}: {data?.source}
</Text>
)}
<Box mt="4">
<ModelAttrView label="Description" value={data.description} />
</Box>
</Flex>
<Tabs mt="4" h="100%">
<TabList>
<Tab>{t('modelManager.settings')}</Tab>
<Tab>{t('modelManager.metadata')}</Tab>
</TabList>
<TabPanels h="100%">
<TabPanel>{selectedModelMode === 'view' ? <ModelView /> : <ModelEdit />}</TabPanel>
<TabPanel h="full">
<ModelMetadata />
</TabPanel>
</TabPanels>
</Tabs>
</>
);
}; };

View File

@ -1,12 +1,11 @@
import { Box, Button, Flex, Heading, Text } from '@invoke-ai/ui-library'; import { Box, Button, Flex, Text } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query'; import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice'; import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { IoPencil } from 'react-icons/io5'; import { IoPencil } from 'react-icons/io5';
import { useGetModelConfigQuery, useGetModelMetadataQuery } from 'services/api/endpoints/models'; import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import type { import type {
CheckpointModelConfig, CheckpointModelConfig,
ControlNetModelConfig, ControlNetModelConfig,
@ -26,7 +25,6 @@ export const ModelView = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken); const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
const modelData = useMemo(() => { const modelData = useMemo(() => {
if (!data) { if (!data) {
@ -74,84 +72,52 @@ export const ModelView = () => {
} }
return ( return (
<Flex flexDir="column" h="full"> <Flex flexDir="column" h="full">
<Flex w="full" justifyContent="space-between"> <Box layerStyle="second" borderRadius="base" p={3}>
<Flex flexDir="column" gap={1} p={2}> <Flex gap="2" justifyContent="flex-end" w="full">
<Heading as="h2" fontSize="lg">
{modelData.name}
</Heading>
{modelData.source && (
<Text variant="subtext">
{t('modelManager.source')}: {modelData.source}
</Text>
)}
</Flex>
<Flex gap={2}>
<Button size="sm" leftIcon={<IoPencil />} colorScheme="invokeYellow" onClick={handleEditModel}> <Button size="sm" leftIcon={<IoPencil />} colorScheme="invokeYellow" onClick={handleEditModel}>
{t('modelManager.edit')} {t('modelManager.edit')}
</Button> </Button>
{modelData.type === 'main' && modelData.format === 'checkpoint' && <ModelConvert model={modelData} />} {modelData.type === 'main' && modelData.format === 'checkpoint' && <ModelConvert model={modelData} />}
</Flex> </Flex>
</Flex> <Flex flexDir="column" gap={3}>
<Flex gap={2}>
<Flex flexDir="column" p={2} gap={3}> <ModelAttrView label={t('modelManager.baseModel')} value={modelData.base} />
<Flex> <ModelAttrView label={t('modelManager.modelType')} value={modelData.type} />
<ModelAttrView label="Description" value={modelData.description} /> </Flex>
</Flex> <Flex gap={2}>
<Heading as="h3" fontSize="md" mt="4"> <ModelAttrView label={t('common.format')} value={modelData.format} />
{t('modelManager.modelSettings')} <ModelAttrView label={t('modelManager.path')} value={modelData.path} />
</Heading> </Flex>
<Box layerStyle="second" borderRadius="base" p={3}> {modelData.type === 'main' && (
<Flex flexDir="column" gap={3}> <>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.baseModel')} value={modelData.base} />
<ModelAttrView label={t('modelManager.modelType')} value={modelData.type} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('common.format')} value={modelData.format} />
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
</Flex>
{modelData.type === 'main' && (
<>
<Flex gap={2}>
{modelData.format === 'diffusers' && (
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
)}
{modelData.format === 'checkpoint' && (
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config} />
)}
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.ztsnrTraining')} value={`${modelData.ztsnr_training}`} />
<ModelAttrView label={t('modelManager.vae')} value={modelData.vae} />
</Flex>
</>
)}
{modelData.type === 'ip_adapter' && (
<Flex gap={2}> <Flex gap={2}>
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelData.image_encoder_model_id} /> {modelData.format === 'diffusers' && (
</Flex> <ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
)} )}
</Flex> {modelData.format === 'checkpoint' && (
</Box> <ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config} />
</Flex> )}
{metadata && ( <ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
<> </Flex>
<Heading as="h3" fontSize="md" mt="4"> <Flex gap={2}>
{t('modelManager.modelMetadata')} <ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
</Heading> <ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
<Flex h="full" w="full" p={2}> </Flex>
<DataViewer label="metadata" data={metadata} /> <Flex gap={2}>
</Flex> <ModelAttrView label={t('modelManager.ztsnrTraining')} value={`${modelData.ztsnr_training}`} />
</> <ModelAttrView label={t('modelManager.vae')} value={modelData.vae} />
)} </Flex>
</>
)}
{modelData.type === 'ip_adapter' && (
<Flex gap={2}>
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelData.image_encoder_model_id} />
</Flex>
)}
</Flex>
</Box>
</Flex> </Flex>
); );
}; };

View File

@ -24,7 +24,15 @@ export type UpdateModelArg = {
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json']; body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
}; };
export type UpdateModelMetadataArg = {
key: paths['/api/v2/models/i/{key}/metadata']['patch']['parameters']['path']['key'];
body: paths['/api/v2/models/i/{key}/metadata']['patch']['requestBody']['content']['application/json'];
};
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json']; type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
type UpdateModelMetadataResponse =
paths['/api/v2/models/i/{key}/metadata']['patch']['responses']['200']['content']['application/json'];
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json']; type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
type GetModelMetadataResponse = type GetModelMetadataResponse =
@ -172,6 +180,16 @@ export const modelsApi = api.injectEndpoints({
}, },
invalidatesTags: ['Model'], invalidatesTags: ['Model'],
}), }),
updateModelMetadata: build.mutation<UpdateModelMetadataResponse, UpdateModelMetadataArg>({
query: ({ key, body }) => {
return {
url: buildModelsUrl(`i/${key}/metadata`),
method: 'PATCH',
body: body,
};
},
invalidatesTags: ['Model'],
}),
installModel: build.mutation<InstallModelResponse, InstallModelArg>({ installModel: build.mutation<InstallModelResponse, InstallModelArg>({
query: ({ source, config, access_token }) => { query: ({ source, config, access_token }) => {
return { return {
@ -351,6 +369,7 @@ export const {
useGetModelMetadataQuery, useGetModelMetadataQuery,
useDeleteModelImportMutation, useDeleteModelImportMutation,
usePruneModelImportsMutation, usePruneModelImportsMutation,
useUpdateModelMetadataMutation,
} = modelsApi; } = modelsApi;
const upsertModelConfigs = ( const upsertModelConfigs = (

File diff suppressed because one or more lines are too long